From 32cb41eb40b46cc78a6f9dbf9ba86f933afc41e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr> Date: Tue, 9 Apr 2024 17:41:43 +0200 Subject: [PATCH] [NN_printer] Remove invariant on the position of meta --- src/meta.ml | 18 ++++++++++++++++++ src/meta.mli | 3 +++ src/printers/marabou.ml | 10 ++-------- src/printers/pyrat.ml | 10 ++-------- src/printers/vnnlib.ml | 12 ++---------- 5 files changed, 27 insertions(+), 26 deletions(-) diff --git a/src/meta.ml b/src/meta.ml index ae6eb30..b556090 100644 --- a/src/meta.ml +++ b/src/meta.ml @@ -49,3 +49,21 @@ let meta_dataset_filename = Why3.Theory.( register_meta_excl "caisar_dataset" ~desc:"Indicates the filename of the dataset" [ MTstring ]) + +let rec get_io_meta ~input_name ~output_name info task = + match task with + | None -> () + | Some { Why3.Task.task_prev; task_decl; _ } -> ( + get_io_meta ~input_name ~output_name info task_prev; + match task_decl.Why3.Theory.td_node with + | Use _ | Clone _ -> () + | Meta (meta, l) when Why3.Theory.meta_equal meta meta_input -> ( + match l with + | [ MAls ls; MAint i ] -> Why3.Term.Hls.add info ls (input_name i) + | _ -> assert false) + | Meta (meta, l) when Why3.Theory.meta_equal meta meta_output -> ( + match l with + | [ MAls ls; MAint i ] -> Why3.Term.Hls.add info ls (output_name i) + | _ -> assert false) + | Meta _ -> () + | Decl _ -> ()) diff --git a/src/meta.mli b/src/meta.mli index 0ed5f6a..c15dd4c 100644 --- a/src/meta.mli +++ b/src/meta.mli @@ -34,3 +34,6 @@ val meta_svm_filename : Why3.Theory.meta val meta_dataset_filename : Why3.Theory.meta (** The filename of the dataset. *) + +val get_io_meta: input_name:(int -> 'a) -> output_name:(int -> 'a) -> 'a Why3.Term.Hls.t -> Why3.Task.task -> unit +(** Add to the given hashtbl all the name for the inputs and outputs *) \ No newline at end of file diff --git a/src/printers/marabou.ml b/src/printers/marabou.ml index 052964a..1f0a06b 100644 --- a/src/printers/marabou.ml +++ b/src/printers/marabou.ml @@ -169,14 +169,6 @@ let rec print_tdecl info fmt task = print_tdecl info fmt task_prev; match task_decl.Theory.td_node with | Use _ | Clone _ -> () - | Meta (meta, l) when Theory.meta_equal meta Meta.meta_input -> ( - match l with - | [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "x%i" i) - | _ -> assert false) - | Meta (meta, l) when Theory.meta_equal meta Meta.meta_output -> ( - match l with - | [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "y%i" i) - | _ -> assert false) | Meta (_, _) -> () | Decl d -> print_decl info fmt d) @@ -198,6 +190,8 @@ let print_task args ?old:_ fmt task = } in Printer.print_prelude fmt args.Printer.prelude; + Meta.get_io_meta info.variables task ~input_name:(Fmt.str "x%i") + ~output_name:(Fmt.str "y%i"); print_tdecl info fmt task let init () = diff --git a/src/printers/pyrat.ml b/src/printers/pyrat.ml index 162d91e..ef66f0d 100644 --- a/src/printers/pyrat.ml +++ b/src/printers/pyrat.ml @@ -155,14 +155,6 @@ let rec print_tdecl info fmt task = print_tdecl info fmt task_prev; match task_decl.Theory.td_node with | Use _ | Clone _ -> () - | Meta (meta, l) when Theory.meta_equal meta Meta.meta_input -> ( - match l with - | [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "x%i" i) - | _ -> assert false) - | Meta (meta, l) when Theory.meta_equal meta Meta.meta_output -> ( - match l with - | [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "y%i" i) - | _ -> assert false) | Meta (_, _) -> () | Decl d -> print_decl info fmt d) @@ -184,6 +176,8 @@ let print_task args ?old:_ fmt task = } in Printer.print_prelude fmt args.Printer.prelude; + Meta.get_io_meta info.variables task ~input_name:(Fmt.str "x%i") + ~output_name:(Fmt.str "y%i"); print_tdecl info fmt task let init () = diff --git a/src/printers/vnnlib.ml b/src/printers/vnnlib.ml index 2bad9de..072339d 100644 --- a/src/printers/vnnlib.ml +++ b/src/printers/vnnlib.ml @@ -502,16 +502,6 @@ let rec print_tdecl info fmt task = print_tdecl info fmt task_prev; match task_decl.Theory.td_node with | Use _ | Clone _ -> () - | Meta (meta, l) when Theory.meta_equal meta Meta.meta_input -> ( - match l with - | [ MAls ls; MAint i ] -> - Term.Hls.add info.variables ls (Fmt.str "X_%i" i) - | _ -> assert false) - | Meta (meta, l) when Theory.meta_equal meta Meta.meta_output -> ( - match l with - | [ MAls ls; MAint i ] -> - Term.Hls.add info.variables ls (Fmt.str "Y_%i" i) - | _ -> assert false) | Meta _ -> () | Decl d -> print_decl info fmt d) @@ -534,6 +524,8 @@ let print_task args ?old:_ fmt task = } in Printer.print_prelude fmt args.Printer.prelude; + Meta.get_io_meta info.variables task ~input_name:(Fmt.str "X_%i") + ~output_name:(Fmt.str "Y_%i"); print_tdecl info fmt task let init () = -- GitLab