From 32cb41eb40b46cc78a6f9dbf9ba86f933afc41e1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr>
Date: Tue, 9 Apr 2024 17:41:43 +0200
Subject: [PATCH] [NN_printer] Remove invariant on the position of meta

---
 src/meta.ml             | 18 ++++++++++++++++++
 src/meta.mli            |  3 +++
 src/printers/marabou.ml | 10 ++--------
 src/printers/pyrat.ml   | 10 ++--------
 src/printers/vnnlib.ml  | 12 ++----------
 5 files changed, 27 insertions(+), 26 deletions(-)

diff --git a/src/meta.ml b/src/meta.ml
index ae6eb30..b556090 100644
--- a/src/meta.ml
+++ b/src/meta.ml
@@ -49,3 +49,21 @@ let meta_dataset_filename =
   Why3.Theory.(
     register_meta_excl "caisar_dataset"
       ~desc:"Indicates the filename of the dataset" [ MTstring ])
+
+let rec get_io_meta ~input_name ~output_name info task =
+  match task with
+  | None -> ()
+  | Some { Why3.Task.task_prev; task_decl; _ } -> (
+    get_io_meta ~input_name ~output_name info task_prev;
+    match task_decl.Why3.Theory.td_node with
+    | Use _ | Clone _ -> ()
+    | Meta (meta, l) when Why3.Theory.meta_equal meta meta_input -> (
+      match l with
+      | [ MAls ls; MAint i ] -> Why3.Term.Hls.add info ls (input_name i)
+      | _ -> assert false)
+    | Meta (meta, l) when Why3.Theory.meta_equal meta meta_output -> (
+      match l with
+      | [ MAls ls; MAint i ] -> Why3.Term.Hls.add info ls (output_name i)
+      | _ -> assert false)
+    | Meta _ -> ()
+    | Decl _ -> ())
diff --git a/src/meta.mli b/src/meta.mli
index 0ed5f6a..c15dd4c 100644
--- a/src/meta.mli
+++ b/src/meta.mli
@@ -34,3 +34,6 @@ val meta_svm_filename : Why3.Theory.meta
 
 val meta_dataset_filename : Why3.Theory.meta
 (** The filename of the dataset. *)
+
+val get_io_meta: input_name:(int -> 'a) -> output_name:(int -> 'a) -> 'a Why3.Term.Hls.t -> Why3.Task.task -> unit
+(** Add to the given hashtbl all the name for the inputs and outputs *)
\ No newline at end of file
diff --git a/src/printers/marabou.ml b/src/printers/marabou.ml
index 052964a..1f0a06b 100644
--- a/src/printers/marabou.ml
+++ b/src/printers/marabou.ml
@@ -169,14 +169,6 @@ let rec print_tdecl info fmt task =
     print_tdecl info fmt task_prev;
     match task_decl.Theory.td_node with
     | Use _ | Clone _ -> ()
-    | Meta (meta, l) when Theory.meta_equal meta Meta.meta_input -> (
-      match l with
-      | [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "x%i" i)
-      | _ -> assert false)
-    | Meta (meta, l) when Theory.meta_equal meta Meta.meta_output -> (
-      match l with
-      | [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "y%i" i)
-      | _ -> assert false)
     | Meta (_, _) -> ()
     | Decl d -> print_decl info fmt d)
 
@@ -198,6 +190,8 @@ let print_task args ?old:_ fmt task =
     }
   in
   Printer.print_prelude fmt args.Printer.prelude;
+  Meta.get_io_meta info.variables task ~input_name:(Fmt.str "x%i")
+    ~output_name:(Fmt.str "y%i");
   print_tdecl info fmt task
 
 let init () =
diff --git a/src/printers/pyrat.ml b/src/printers/pyrat.ml
index 162d91e..ef66f0d 100644
--- a/src/printers/pyrat.ml
+++ b/src/printers/pyrat.ml
@@ -155,14 +155,6 @@ let rec print_tdecl info fmt task =
     print_tdecl info fmt task_prev;
     match task_decl.Theory.td_node with
     | Use _ | Clone _ -> ()
-    | Meta (meta, l) when Theory.meta_equal meta Meta.meta_input -> (
-      match l with
-      | [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "x%i" i)
-      | _ -> assert false)
-    | Meta (meta, l) when Theory.meta_equal meta Meta.meta_output -> (
-      match l with
-      | [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "y%i" i)
-      | _ -> assert false)
     | Meta (_, _) -> ()
     | Decl d -> print_decl info fmt d)
 
@@ -184,6 +176,8 @@ let print_task args ?old:_ fmt task =
     }
   in
   Printer.print_prelude fmt args.Printer.prelude;
+  Meta.get_io_meta info.variables task ~input_name:(Fmt.str "x%i")
+    ~output_name:(Fmt.str "y%i");
   print_tdecl info fmt task
 
 let init () =
diff --git a/src/printers/vnnlib.ml b/src/printers/vnnlib.ml
index 2bad9de..072339d 100644
--- a/src/printers/vnnlib.ml
+++ b/src/printers/vnnlib.ml
@@ -502,16 +502,6 @@ let rec print_tdecl info fmt task =
     print_tdecl info fmt task_prev;
     match task_decl.Theory.td_node with
     | Use _ | Clone _ -> ()
-    | Meta (meta, l) when Theory.meta_equal meta Meta.meta_input -> (
-      match l with
-      | [ MAls ls; MAint i ] ->
-        Term.Hls.add info.variables ls (Fmt.str "X_%i" i)
-      | _ -> assert false)
-    | Meta (meta, l) when Theory.meta_equal meta Meta.meta_output -> (
-      match l with
-      | [ MAls ls; MAint i ] ->
-        Term.Hls.add info.variables ls (Fmt.str "Y_%i" i)
-      | _ -> assert false)
     | Meta _ -> ()
     | Decl d -> print_decl info fmt d)
 
@@ -534,6 +524,8 @@ let print_task args ?old:_ fmt task =
     }
   in
   Printer.print_prelude fmt args.Printer.prelude;
+  Meta.get_io_meta info.variables task ~input_name:(Fmt.str "X_%i")
+    ~output_name:(Fmt.str "Y_%i");
   print_tdecl info fmt task
 
 let init () =
-- 
GitLab