From e4fd53c328ef59f062d5a55c89db3883d9151dce Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Thu, 25 May 2023 16:10:50 +0200
Subject: [PATCH] [language] Rework neural networks interface.

---
 src/interpretation.ml                   |  4 ++--
 src/language.ml                         | 16 ++++++++--------
 src/language.mli                        |  8 ++++----
 src/transformations/native_nn_prover.ml |  4 ++--
 4 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/src/interpretation.ml b/src/interpretation.ml
index b0a68d5..e5dd6c9 100644
--- a/src/interpretation.ml
+++ b/src/interpretation.ml
@@ -253,8 +253,8 @@ let caisar_builtins : caisar_env CRE.built_in_theories list =
         let filename = Caml.Filename.concat cwd neural_network in
         let nn =
           match id_string with
-          | "NNet" -> NNet (Language.create_nnet_nn env filename)
-          | "ONNX" -> ONNX (Language.create_onnx_nn env filename)
+          | "NNet" -> NNet (Language.create_nn_nnet env filename)
+          | "ONNX" -> ONNX (Language.create_nn_onnx env filename)
           | _ ->
             failwith (Fmt.str "Unrecognized neural network format %s" id_string)
         in
diff --git a/src/language.ml b/src/language.ml
index d525b1f..39afa6c 100644
--- a/src/language.ml
+++ b/src/language.ml
@@ -186,8 +186,8 @@ let mem_vector = Term.Hls.mem vectors
 (* -- Classifier *)
 
 type nn = {
-  nn_inputs : int;
-  nn_outputs : int;
+  nn_nb_inputs : int;
+  nn_nb_outputs : int;
   nn_ty_elt : Ty.ty; [@printer fun fmt ty -> Fmt.pf fmt "%a" Pretty.print_ty ty]
   nn_filename : string;
   nn_nier : Onnx.G.t option; [@opaque]
@@ -204,7 +204,7 @@ let fresh_nn_ls env name =
   let id = Ident.id_fresh name in
   Term.create_fsymbol id [] ty
 
-let create_nnet_nn =
+let create_nn_nnet =
   Env.Wenv.memoize 13 (fun env ->
     let h = Hashtbl.create (module String) in
     let ty_elt =
@@ -219,8 +219,8 @@ let create_nnet_nn =
         | Error s -> Loc.errorm "%s" s
         | Ok { n_inputs; n_outputs; _ } ->
           {
-            nn_inputs = n_inputs;
-            nn_outputs = n_outputs;
+            nn_nb_inputs = n_inputs;
+            nn_nb_outputs = n_outputs;
             nn_ty_elt = ty_elt;
             nn_filename = filename;
             nn_nier = None;
@@ -229,7 +229,7 @@ let create_nnet_nn =
       Term.Hls.add nets ls nn;
       ls))
 
-let create_onnx_nn =
+let create_nn_onnx =
   Env.Wenv.memoize 13 (fun env ->
     let h = Hashtbl.create (module String) in
     let ty_elt = vector_elt_ty env in
@@ -249,8 +249,8 @@ let create_onnx_nn =
             | Ok nier -> Some nier
           in
           {
-            nn_inputs = n_inputs;
-            nn_outputs = n_outputs;
+            nn_nb_inputs = n_inputs;
+            nn_nb_outputs = n_outputs;
             nn_ty_elt = ty_elt;
             nn_filename = filename;
             nn_nier = nier;
diff --git a/src/language.mli b/src/language.mli
index 3804657..579255e 100644
--- a/src/language.mli
+++ b/src/language.mli
@@ -72,15 +72,15 @@ val mem_vector : Term.lsymbol -> bool
 (** -- Neural Network *)
 
 type nn = private {
-  nn_inputs : int;
-  nn_outputs : int;
+  nn_nb_inputs : int;
+  nn_nb_outputs : int;
   nn_ty_elt : Ty.ty;
   nn_filename : string;
   nn_nier : Onnx.G.t option;
 }
 [@@deriving show]
 
-val create_nnet_nn : Env.env -> string -> Term.lsymbol
-val create_onnx_nn : Env.env -> string -> Term.lsymbol
+val create_nn_nnet : Env.env -> string -> Term.lsymbol
+val create_nn_onnx : Env.env -> string -> Term.lsymbol
 val lookup_nn : Term.lsymbol -> nn option
 val mem_nn : Term.lsymbol -> bool
diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index 8a3816f..41b4568 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -37,8 +37,8 @@ let get_input_variables =
           [ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] )
       when String.equal ls_name.id_string (Ident.op_infix "@@") -> (
       match (Language.lookup_nn ls1, Language.lookup_vector ls2) with
-      | Some { nn_inputs; _ }, Some n ->
-        assert (nn_inputs = n && n = List.length args);
+      | Some { nn_nb_inputs; _ }, Some n ->
+        assert (nn_nb_inputs = n && n = List.length args);
         List.foldi ~init:acc ~f:add args
       | _ -> acc)
     | _ -> Term.t_fold aux acc term
-- 
GitLab