From 5cd8729430f35fd2e05c92911654a7645f62deb5 Mon Sep 17 00:00:00 2001
From: Julien Girard <julien.girard2@cea.fr>
Date: Wed, 21 Sep 2022 16:48:20 +0200
Subject: [PATCH] Added a NIER to the nnshape type.

---
 src/language.ml                         | 11 +++++++----
 src/language.mli                        |  1 +
 src/transformations/actual_net_apply.ml | 10 ++++------
 3 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/src/language.ml b/src/language.ml
index ca792ee..6cae58d 100644
--- a/src/language.ml
+++ b/src/language.ml
@@ -32,6 +32,7 @@ type nn_shape = {
   nb_outputs : int;
   ty_data : Ty.ty;
   filename : string;
+  nier : Onnx.G.t option;
 }
 
 type svm_shape = { nb_inputs : int; nb_classes : int; filename : string }
@@ -41,7 +42,7 @@ let loaded_svms = Term.Hls.create 10
 let lookup_loaded_nets = Term.Hls.find_opt loaded_nets
 let lookup_loaded_svms = Term.Hls.find_opt loaded_svms
 
-let register_nn_as_tuple nb_inputs nb_outputs filename env =
+let register_nn_as_tuple nb_inputs nb_outputs filename nier env =
   let net = Pmodule.read_module env [ "caisar" ] "NN" in
   let input_type =
     Ty.ty_app Theory.(ns_find_ts net.mod_theory.th_export [ "input_type" ]) []
@@ -57,7 +58,7 @@ let register_nn_as_tuple nb_inputs nb_outputs filename env =
       (Ty.ty_tuple (List.init nb_outputs ~f))
   in
   Term.Hls.add loaded_nets ls_net_apply
-    { filename; nb_inputs; nb_outputs; ty_data = input_type };
+  { filename; nb_inputs; nb_outputs; ty_data = input_type; nier };
   let th_uc =
     Pmodule.add_pdecl ~vc:false th_uc
       (Pdecl.create_pure_decl (Decl.create_param_decl ls_net_apply))
@@ -86,13 +87,15 @@ let nnet_parser env _ filename _ =
   let model = Nnet.parse filename in
   match model with
   | Error s -> Loc.errorm "%s" s
-  | Ok model -> register_nn_as_tuple model.n_inputs model.n_outputs filename env
+  | Ok model ->
+    register_nn_as_tuple model.n_inputs model.n_outputs filename None env
 
 let onnx_parser env _ filename _ =
   let model = Onnx.parse filename in
   match model with
   | Error s -> Loc.errorm "%s" s
-  | Ok (model,_nier) -> register_nn_as_tuple model.n_inputs model.n_outputs filename env
+  | Ok (model, nier) ->
+    register_nn_as_tuple model.n_inputs model.n_outputs filename (Some nier) env
 
 let ovo_parser env _ filename _ =
   let model = Ovo.parse filename in
diff --git a/src/language.mli b/src/language.mli
index 4d0a1ca..90c8cc3 100644
--- a/src/language.mli
+++ b/src/language.mli
@@ -27,6 +27,7 @@ type nn_shape = {
   nb_outputs : int;
   ty_data : Ty.ty;
   filename : string;
+  nier : Onnx.G.t option;
 }
 
 type svm_shape = { nb_inputs : int; nb_classes : int; filename : string }
diff --git a/src/transformations/actual_net_apply.ml b/src/transformations/actual_net_apply.ml
index d1acce8..d188b5c 100644
--- a/src/transformations/actual_net_apply.ml
+++ b/src/transformations/actual_net_apply.ml
@@ -346,14 +346,12 @@ let actual_nn_flow env =
       match Language.lookup_loaded_nets ls with
       | None -> Term.t_map aux term
       | Some nn ->
-        let nn_file = Unix.realpath nn.filename in
-        let ty_inputs = nn.ty_data in
         let g =
-          let p = Onnx.parse nn_file in
-          match p with
-          | Error s -> Loc.errorm "%s" s
-          | Ok (_model, nier) -> nier
+          match nn.nier with
+          | Some g -> g
+          | None -> failwith "Error, call this transform only on an ONNX NN."
         in
+        let ty_inputs = nn.ty_data in
         let cfg_term =
           terms_of_nier g ty_inputs env
             (Term.t_var @@ create_var "dummy" 0 ty_inputs vars)
-- 
GitLab