From fd43351421e7f30b5f8c94ee7e67e44934c2a4fa Mon Sep 17 00:00:00 2001
From: Julien Girard <julien.girard2@cea.fr>
Date: Wed, 24 Jan 2024 16:58:36 +0100
Subject: [PATCH] Manage multiple registered NIER

---
 src/main.ml               | 14 +++++++-------
 src/verification.ml       | 22 +++++++++++++++-------
 src/verification.mli      |  6 +++---
 tests/bin/inspect_onnx.py | 13 ++++++++++---
 tests/nier_to_onnx.t      |  4 ++--
 5 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/src/main.ml b/src/main.ml
index e1b74da..94ef85b 100644
--- a/src/main.ml
+++ b/src/main.ml
@@ -123,14 +123,14 @@ let log_theory_answer =
           additional_info)))
 
 let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern
-  ?def_constants ?theories ?goals ?onnx_out_file files =
+  ?def_constants ?theories ?goals ?onnx_out_dir files =
   let memlimit = Option.map memlimit ~f:memlimit_of_string in
   let timelimit = Option.map timelimit ~f:timelimit_of_string in
   let theory_answers =
     List.map files
       ~f:
         (Verification.verify ?format ~loadpath ?memlimit ?timelimit ?dataset
-           prover ?prover_altern ?def_constants ?theories ?goals ?onnx_out_file)
+           prover ?prover_altern ?def_constants ?theories ?goals ?onnx_out_dir)
   in
   List.iter theory_answers ~f:log_theory_answer;
   theory_answers
@@ -256,9 +256,9 @@ let verify_cmd =
     let doc = "Dataset $(docv) (CSV format only)." in
     Arg.(value & opt (some file) None & info [ "dataset" ] ~doc ~docv:"FILE")
   in
-  let onnx_out_file =
+  let onnx_out_dir =
     let doc = "Path where to save the ONNX outputs from NIER." in
-    Arg.(value & opt (some string) None & info [ "onnx-out-file" ] ~doc)
+    Arg.(value & opt (some string) None & info [ "onnx-out-dir" ] ~doc)
   in
   let define_constants =
     let doc = "Define a declared constant with the given value." in
@@ -299,17 +299,17 @@ let verify_cmd =
   in
   let verify_term =
     let verify format loadpath memlimit timelimit prover prover_altern dataset
-      def_constants theories goals onnx_out_file files () =
+      def_constants theories goals onnx_out_dir files () =
       ignore
         (verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover
-           ?prover_altern ~def_constants ~theories ~goals ?onnx_out_file files)
+           ?prover_altern ~def_constants ~theories ~goals ?onnx_out_dir files)
     in
     Term.(
       const (fun _ -> exec_cmd cmdname)
       $ setup_logs
       $ (const verify $ format $ loadpath $ memlimit $ timelimit $ prover
        $ prover_altern $ dataset $ define_constants $ theories $ goals
-       $ onnx_out_file $ files))
+       $ onnx_out_dir $ files))
   in
   Cmd.v info verify_term
 
diff --git a/src/verification.ml b/src/verification.ml
index 96f41de..ce4aacc 100644
--- a/src/verification.ml
+++ b/src/verification.ml
@@ -90,18 +90,26 @@ let create_env loadpath =
       (loadpath @ stdlib @ Whyconf.loadpath (Whyconf.get_main config)),
     config )
 
-let save_onnx_in_task onnx_out_file =
-  match onnx_out_file with
+let save_onnx_in_task onnx_out_dir =
+  match onnx_out_dir with
   | Some f ->
+    let i = ref 0 in
     Language.iter_nn (fun ls nn ->
       match nn.nn_nier with
       | Some g -> (
         try
-          Onnx.write g f;
-          Logging.info (fun m -> m "@[Wrote ONNX file at '%s'@]" f)
+          let () =
+            if not (Stdlib.Sys.file_exists f)
+            then Stdlib.Sys.mkdir f 0o755
+            else ()
+          in
+          let filename = Stdlib.Filename.concat f (Int.to_string !i ^ ".onnx") in
+          Onnx.write g filename;
+          Logging.info (fun m -> m "@[Wrote ONNX file at '%s'@]" filename);
+          i := !i + 1
         with Sys_error msg ->
           Logging.user_error (fun m ->
-            m "@[Tried to write ONNX file at '%s', got '%s'@]" f msg))
+            m "@[Tried to write ONNX file in folder '%s', got '%s'@]" f msg))
       | None ->
         Logging.warn (fun m ->
           m
@@ -350,7 +358,7 @@ let tasks_of_theory ~theories ~goals theory =
         List.exists goals_of_theory ~f:(String.equal task_goal_id))
 
 let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern
-  ?(def_constants = []) ?(theories = []) ?(goals = []) ?onnx_out_file file =
+  ?(def_constants = []) ?(theories = []) ?(goals = []) ?onnx_out_dir file =
   let debug = Logging.(is_debug_level src_prover_call) in
   (if debug then Debug.(set_flag (lookup_flag "call_prover")));
   let env, config = create_env loadpath in
@@ -418,5 +426,5 @@ let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern
           tasks)
       mstr_theory
   in
-  save_onnx_in_task onnx_out_file;
+  save_onnx_in_task onnx_out_dir;
   m
diff --git a/src/verification.mli b/src/verification.mli
index 7c836c2..91629cb 100644
--- a/src/verification.mli
+++ b/src/verification.mli
@@ -55,7 +55,7 @@ val verify :
   ?def_constants:(string * string) list ->
   ?theories:string list ->
   ?goals:(string * string list) list ->
-  ?onnx_out_file: string ->
+  ?onnx_out_dir: string ->
   File.t ->
   verification_result
 (** Starts a verification of the given [file] with the provided [prover].
@@ -75,8 +75,8 @@ val verify :
     @param goals
       is a theory:goals list each identifying only the goals of a theory to
       verify.
-    @param onnx_out_file
-      is the filepath where ONNX files generated from the NIER will be saved.
+    @param onnx_out_dir
+      is the folder where ONNX files generated from the NIER will be saved.
     @return
       for each theory, an [answer] for each goal of the theory, respecting the
       order of appearance. *)
diff --git a/tests/bin/inspect_onnx.py b/tests/bin/inspect_onnx.py
index 49d5dbd..816892d 100644
--- a/tests/bin/inspect_onnx.py
+++ b/tests/bin/inspect_onnx.py
@@ -1,4 +1,11 @@
 import onnx
-m = onnx.load('out.onnx')
-initi=m.graph.initializer[0]
-print("Initializer name {name} has data {val:3.3f}".format(name=initi.name,val=initi.float_data[2]))
+import os
+
+for file in os.listdir("out"):
+    m = onnx.load(os.path.join("out", file))
+    initi = m.graph.initializer[0]
+    print(
+        "Initializer name {name} has data {val:3.3f}".format(
+            name=initi.name, val=initi.float_data[2]
+        )
+    )
diff --git a/tests/nier_to_onnx.t b/tests/nier_to_onnx.t
index 8cd9fb3..f80e4e3 100644
--- a/tests/nier_to_onnx.t
+++ b/tests/nier_to_onnx.t
@@ -4,7 +4,7 @@ Test verify
   $ bin/pyrat.py --version
   PyRAT 1.1
 
-  $ caisar verify --format whyml --prover=PyRAT -v --onnx-out-file="out.onnx" - 2>&1 <<EOF | ./filter_tmpdir.sh
+  $ caisar verify --format whyml --prover=PyRAT -v --onnx-out-dir="out" - 2>&1 <<EOF | ./filter_tmpdir.sh
   > theory NIER_to_ONNX
   >   use ieee_float.Float64
   >   use bool.Bool
@@ -21,7 +21,7 @@ Test verify
   >       (0.5:t) .< (nn@@i)[0] .< (0.5:t)
   > end
   > EOF
-  [INFO] Wrote ONNX file at 'out.onnx'
+  [INFO] Wrote ONNX file at 'out/0.onnx'
   [INFO] Verification results for theory 'NIER_to_ONNX'
   Goal G: Unknown ()
 
-- 
GitLab