From 4a14fc9d90dca935d21315edda85f970ffcdc332 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr>
Date: Thu, 11 Apr 2024 12:44:21 +0200
Subject: [PATCH] [NN_native] create only one node by identical nn application

---
 src/transformations/native_nn_prover.ml | 25 ++++++++++++++++++++-----
 tests/acasxu.t                          |  4 ++--
 2 files changed, 22 insertions(+), 7 deletions(-)

diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index 9a5f1b8..ceeabb2 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -57,10 +57,22 @@ let create_new_nn env input_vars outputs : string =
       Ir.Nier_simple.Node.gather_int input i)
   in
   let cache = Why3.Term.Hterm.create 17 in
+  let nn_cache = Stdlib.Hashtbl.create 17 in
   (* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed
      into nodes. *)
-  let rec convert_old_nn old_nn old_index old_nn_args =
-    let old_index = Why3.Number.to_small_integer old_index in
+  let rec convert_old_nn old_nn old_nn_args =
+    let converted_args = List.map ~f:convert_term old_nn_args in
+    let id =
+      ( old_nn.Language.nn_filename,
+        List.map converted_args ~f:(fun n -> n.Ir.Nier_simple.id) )
+    in
+    match Stdlib.Hashtbl.find_opt nn_cache id with
+    | None ->
+      let n = convert_old_nn_aux old_nn converted_args in
+      Stdlib.Hashtbl.add nn_cache id n;
+      n
+    | Some n -> n
+  and convert_old_nn_aux old_nn converted_args =
     let old_nn_nier =
       match Onnx.Simple.parse old_nn.Language.nn_filename with
       | Error s ->
@@ -75,8 +87,7 @@ let create_new_nn env input_vars outputs : string =
     let input () =
       (* Regroup the terms into one node *)
       let node =
-        IR.Node.create
-          (Concat { inputs = List.map ~f:convert_term old_nn_args; axis = 0 })
+        IR.Node.create (Concat { inputs = converted_args; axis = 0 })
       in
       Ir.Nier_simple.Node.reshape (IR.input_shape old_nn_nier) node
     in
@@ -86,6 +97,10 @@ let create_new_nn env input_vars outputs : string =
         (Ir.Nier_simple.Shape.of_array [| old_nn.nn_nb_outputs |])
         out
     in
+    out
+  and convert_old_nn_at_old_index old_nn old_index old_nn_args =
+    let out = convert_old_nn old_nn old_nn_args in
+    let old_index = Why3.Number.to_small_integer old_index in
     Ir.Nier_simple.Node.gather_int out old_index
   and convert_term term =
     match Why3.Term.Hterm.find_opt cache term with
@@ -148,7 +163,7 @@ let create_new_nn env input_vars outputs : string =
       match (Language.lookup_nn ls_nn, Language.lookup_vector ls) with
       | Some ({ nn_nb_inputs; _ } as old_nn), Some vector_length ->
         assert (nn_nb_inputs = vector_length && vector_length = List.length tl);
-        convert_old_nn old_nn old_index tl
+        convert_old_nn_at_old_index old_nn old_index tl
       | _, _ ->
         Logging.code_error ~src (fun m ->
           m "Neural network application without fixed NN or arguments: %a"
diff --git a/tests/acasxu.t b/tests/acasxu.t
index f7fa6bd..2c22b3c 100644
--- a/tests/acasxu.t
+++ b/tests/acasxu.t
@@ -901,7 +901,7 @@ Test verify on acasxu
   caisar_1.onnx has 1 input nodes
   {'name': '179', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
   caisar_2.onnx has 1 input nodes
-  {'name': '392', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}}
+  {'name': '327', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}}
   caisar_3.onnx has 1 input nodes
-  {'name': '619', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
+  {'name': '554', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
   4 files checked
-- 
GitLab