From b0002177e57a63f2c6353b47bcaf5a53e50f9603 Mon Sep 17 00:00:00 2001
From: Julien Girard <julien.girard2@cea.fr>
Date: Thu, 30 Nov 2023 11:03:57 +0100
Subject: [PATCH] Basic structure for ONNX output through command line.

---
 lib/onnx/onnx.ml     | 45 ++++++++++++++++++++++++++++++++++++++++++++
 lib/onnx/onnx.mli    |  5 ++++-
 src/language.ml      |  3 ++-
 src/main.ml          | 15 ++++++++++-----
 src/verification.ml  | 42 ++++++++++++++++++++++++++++++++++-------
 src/verification.mli |  1 +
 6 files changed, 97 insertions(+), 14 deletions(-)

diff --git a/lib/onnx/onnx.ml b/lib/onnx/onnx.ml
index 355aa43..a3b631e 100644
--- a/lib/onnx/onnx.ml
+++ b/lib/onnx/onnx.ml
@@ -504,6 +504,46 @@ let nier_of_onnx_protoc (model : Oprotom.t) =
   | Some g -> produce_cfg g
   | None -> raise (ParseError "No graph in ONNX input file found")
 
+let nier_to_onnx_protoc nier =
+  (* TODO: write a simple ONNX model from a dummy NIER *)
+  let vertices = G.vertex_list nier in
+  let protocs =
+    let vertex_to_protoc v =
+      let op_type = NCFG.Node.str_op (NCFG.Node.get_op v) in
+      let name = NCFG.Node.get_name v in
+      let domain = "" in
+      let input, output =
+        (NCFG.Node.get_pred_list v, NCFG.Node.get_succ_list v)
+      in
+      Oproto.Onnx.NodeProto.make ~input ~output ~name ~op_type ~domain
+        ~attribute:[] ~doc_string:"" ()
+    in
+    List.fold ~init:[] ~f:(fun acc v -> vertex_to_protoc v :: acc) vertices
+  in
+  let protog =
+    Oproto.Onnx.GraphProto.make ~name:"" ~node:protocs ~initializer':[]
+      ~sparse_initializer:[] ~doc_string:"" ~input:[] ~output:[] ~value_info:[]
+      ~quantization_annotation:[] ()
+  in
+  let protom =
+    Oproto.Onnx.ModelProto.make ~ir_version:13 ~opset_import:[]
+      ~producer_name:"CAISAR" ~producer_version:"1.0" ~domain:""
+      ~model_version:(-1) ~doc_string:"" ~graph:protog ~metadata_props:[]
+      ~training_info:[] ~functions:[] ()
+  in
+  let writer = Oprotom.to_proto protom in
+  Ocaml_protoc_plugin.Writer.contents writer
+
+let write_nier_to_onnx _nier out_channel =
+  let nier = G.init_cfg in
+  let n =
+    Ir.Nier_cfg.Node.create ~id:0 ~name:None ~sh:[||] ~op:Ir.Nier_cfg.Node.NO_OP
+      ~op_p:None ~pred:[] ~succ:[] ~tensor:None
+  in
+  G.add_vertex nier n;
+  let onnx = nier_to_onnx_protoc nier in
+  Stdio.Out_channel.output_string out_channel onnx
+
 let parse_in_channel in_channel =
   let open Result in
   try
@@ -528,3 +568,8 @@ let parse filename =
   Fun.protect
     ~finally:(fun () -> Stdlib.close_in in_channel)
     (fun () -> parse_in_channel in_channel)
+
+let write nier filename =
+  let out_chan = Stdlib.open_out filename in
+  write_nier_to_onnx nier out_chan;
+  Stdlib.close_out out_chan
diff --git a/lib/onnx/onnx.mli b/lib/onnx/onnx.mli
index e8d102c..0011ac1 100644
--- a/lib/onnx/onnx.mli
+++ b/lib/onnx/onnx.mli
@@ -30,4 +30,7 @@ type t = private {
 (** ONNX model metadata and intermediate representation. *)
 
 val parse : string -> (t, string) Result.t
-(** Parse an ONNX file. *)
+(** Parse an ONNX file into a NIER. *)
+
+val write : G.t -> string -> unit
+(** Write a NIER into an ONNX file. *)
diff --git a/src/language.ml b/src/language.ml
index 6d9a186..09c4813 100644
--- a/src/language.ml
+++ b/src/language.ml
@@ -244,7 +244,8 @@ let create_nn_onnx =
               Logs.warn (fun m ->
                 m "Cannot build network intermediate representation:@ %s" msg);
               None
-            | Ok nier -> Some nier
+            | Ok nier ->
+              Some nier
           in
           {
             nn_nb_inputs = n_inputs;
diff --git a/src/main.ml b/src/main.ml
index 73bcadc..e1b74da 100644
--- a/src/main.ml
+++ b/src/main.ml
@@ -123,14 +123,14 @@ let log_theory_answer =
           additional_info)))
 
 let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern
-  ?def_constants ?theories ?goals files =
+  ?def_constants ?theories ?goals ?onnx_out_file files =
   let memlimit = Option.map memlimit ~f:memlimit_of_string in
   let timelimit = Option.map timelimit ~f:timelimit_of_string in
   let theory_answers =
     List.map files
       ~f:
         (Verification.verify ?format ~loadpath ?memlimit ?timelimit ?dataset
-           prover ?prover_altern ?def_constants ?theories ?goals)
+           prover ?prover_altern ?def_constants ?theories ?goals ?onnx_out_file)
   in
   List.iter theory_answers ~f:log_theory_answer;
   theory_answers
@@ -256,6 +256,10 @@ let verify_cmd =
     let doc = "Dataset $(docv) (CSV format only)." in
     Arg.(value & opt (some file) None & info [ "dataset" ] ~doc ~docv:"FILE")
   in
+  let onnx_out_file =
+    let doc = "Path where to save the ONNX outputs from NIER." in
+    Arg.(value & opt (some string) None & info [ "onnx-out-file" ] ~doc)
+  in
   let define_constants =
     let doc = "Define a declared constant with the given value." in
     Arg.(
@@ -295,16 +299,17 @@ let verify_cmd =
   in
   let verify_term =
     let verify format loadpath memlimit timelimit prover prover_altern dataset
-      def_constants theories goals files () =
+      def_constants theories goals onnx_out_file files () =
       ignore
         (verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover
-           ?prover_altern ~def_constants ~theories ~goals files)
+           ?prover_altern ~def_constants ~theories ~goals ?onnx_out_file files)
     in
     Term.(
       const (fun _ -> exec_cmd cmdname)
       $ setup_logs
       $ (const verify $ format $ loadpath $ memlimit $ timelimit $ prover
-       $ prover_altern $ dataset $ define_constants $ theories $ goals $ files))
+       $ prover_altern $ dataset $ define_constants $ theories $ goals
+       $ onnx_out_file $ files))
   in
   Cmd.v info verify_term
 
diff --git a/src/verification.ml b/src/verification.ml
index e1ba5bf..e59c1bc 100644
--- a/src/verification.ml
+++ b/src/verification.ml
@@ -243,8 +243,8 @@ let answer_dataset limit config env prover config_prover driver dataset task =
   in
   (prover_answer, additional_info)
 
-let answer_generic limit config prover config_prover driver ~proof_strategy task
-    =
+let answer_generic limit config prover config_prover driver ~proof_strategy
+  onnx_out_file task =
   let tasks = proof_strategy task in
   let answers =
     List.concat_map tasks ~f:(fun task ->
@@ -255,6 +255,33 @@ let answer_generic limit config prover config_prover driver ~proof_strategy task
         | Some _ -> assert false (* By construction of the meta. *)
         | None -> invalid_arg "No neural network model found in task"
       in
+      let () =
+        let save_onnx_in ls f =
+          match Language.lookup_nn ls with
+          | Some { nn_nier = Some g; _ } -> (
+            try
+              Onnx.write g f;
+              Logs.info (fun m -> m "@[Wrote ONNX file at '%s'@]" f)
+            with Sys_error msg ->
+              Logs.err (fun m ->
+                m "@[System error: tried to write ONNX file a '%s', got '%s'@]"
+                  f msg))
+          | None -> ()
+          | _ -> ()
+        in
+        match onnx_out_file with
+        | Some f ->
+          Task.task_iter
+            (fun decl ->
+              match decl.td_node with
+              | Use _ | Clone _ | Meta _ -> ()
+              | Decl decl -> (
+                match decl.d_node with
+                | Dparam ls -> save_onnx_in ls f
+                | _ -> ()))
+            task
+        | _ -> ()
+      in
       let tasks =
         (* Turn [task] into a list (ie, conjunction) of disjunctions of
            tasks. *)
@@ -271,7 +298,7 @@ let answer_generic limit config prover config_prover driver ~proof_strategy task
   (prover_answer, additional_info)
 
 let call_prover ~cwd ~limit config env prover config_prover driver ?dataset
-  def_constants task =
+  def_constants onnx_out_file task =
   let prover_answer, additional_info =
     match prover with
     | Prover.Saver -> answer_saver limit config env config_prover dataset task
@@ -283,12 +310,13 @@ let call_prover ~cwd ~limit config env prover config_prover driver ?dataset
       let task = Interpretation.interpret_task ~cwd env ~def_constants task in
       let proof_strategy = Proof_strategy.apply_native_nn_prover in
       answer_generic limit config prover config_prover driver ~proof_strategy
-        task
+        onnx_out_file task
     | CVC5 ->
       let task = Interpretation.interpret_task ~cwd env ~def_constants task in
       let proof_strategy = Proof_strategy.apply_classic_prover env in
+      (* Not outputting ONNX for CVC5 as it does not handle such format*)
       answer_generic limit config prover config_prover driver ~proof_strategy
-        task
+        None task
   in
   let id = Task.task_goal task in
   { id; prover_answer; additional_info }
@@ -328,7 +356,7 @@ let tasks_of_theory ~theories ~goals theory =
         List.exists goals_of_theory ~f:(String.equal task_goal_id))
 
 let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern
-  ?(def_constants = []) ?(theories = []) ?(goals = []) file =
+  ?(def_constants = []) ?(theories = []) ?(goals = []) ?onnx_out_file file =
   let debug = Logging.(is_debug_level src_prover_call) in
   (if debug then Debug.(set_flag (lookup_flag "call_prover")));
   let env, config = create_env loadpath in
@@ -391,6 +419,6 @@ let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern
       List.map
         ~f:
           (call_prover ~cwd ~limit main env prover config_prover driver ?dataset
-             def_constants)
+             def_constants onnx_out_file)
         tasks)
     mstr_theory
diff --git a/src/verification.mli b/src/verification.mli
index 92bbcd3..87c9043 100644
--- a/src/verification.mli
+++ b/src/verification.mli
@@ -55,6 +55,7 @@ val verify :
   ?def_constants:(string * string) list ->
   ?theories:string list ->
   ?goals:(string * string list) list ->
+  ?onnx_out_file: string ->
   File.t ->
   verification_result
 (** Starts a verification of the given [file] with the provided [prover].
-- 
GitLab