From bdc669b66eb44f930bb3ff55c95b687060c36b5c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr>
Date: Fri, 5 Apr 2024 16:04:27 +0200
Subject: [PATCH] [Ir] add smart constructors and use them

---
 lib/ir/nier_simple.ml                   | 13 ++++++++++++
 lib/ir/nier_simple.mli                  |  6 ++++++
 src/transformations/native_nn_prover.ml | 27 +++++--------------------
 tests/acasxu.t                          |  6 +++---
 4 files changed, 27 insertions(+), 25 deletions(-)

diff --git a/lib/ir/nier_simple.ml b/lib/ir/nier_simple.ml
index 3a774a1..c30b962 100644
--- a/lib/ir/nier_simple.ml
+++ b/lib/ir/nier_simple.ml
@@ -417,6 +417,19 @@ module Node = struct
   let constant_int_array a =
     create (Constant { data = GenTensor.of_int_array a })
 
+  let reshape shape node =
+    if Shape.equal node.shape shape
+    then node
+    else
+      create
+        (Reshape
+           { input = node; shape = constant_int_array (Shape.to_array shape) })
+
+  let concat_0 = function
+    | [ n ] -> n
+    | [] -> failwith "empty concat"
+    | inputs -> create (Concat { inputs; axis = 0 })
+
   let preds node =
     match node.descr with
     | Constant _ | Input _ -> []
diff --git a/lib/ir/nier_simple.mli b/lib/ir/nier_simple.mli
index a500fa9..eba29db 100644
--- a/lib/ir/nier_simple.mli
+++ b/lib/ir/nier_simple.mli
@@ -154,6 +154,12 @@ module Node : sig
   val constant_int_array : int array -> node
   (** create a node for a constant array *)
 
+  val reshape: Shape.t -> t -> t
+  (** create if necessary a reshape node *)
+
+  val concat_0: t list -> t
+  (** create if necessary a concat node for the first axis *)
+  
   val map : (node -> node) -> node -> node
   (** [map f n] replace the direct inputs [i] of n by [f i] *)
 
diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index bde96b6..5d731d7 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -79,30 +79,13 @@ let create_new_nn env input_vars outputs : string =
         IR.Node.create
           (Concat { inputs = List.map ~f:convert_term old_nn_args; axis = 0 })
       in
-      let node =
-        if Ir.Nier_simple.Shape.equal node.shape (IR.input_shape old_nn_nier)
-        then node
-        else
-          IR.Node.create
-            (Reshape
-               {
-                 input = node;
-                 shape =
-                   IR.Node.constant_int_array
-                     (Ir.Nier_simple.Shape.to_array
-                        (IR.input_shape old_nn_nier));
-               })
-      in
-      node
+      Ir.Nier_simple.Node.reshape (IR.input_shape old_nn_nier) node
     in
     let out = IR.Node.replace_input input (IR.output old_nn_nier) in
     let out =
-      IR.Node.create
-        (Reshape
-           {
-             input = out;
-             shape = IR.Node.constant_int_array [| old_nn.nn_nb_outputs |];
-           })
+      Ir.Nier_simple.Node.reshape
+        (Ir.Nier_simple.Shape.of_array [| old_nn.nn_nb_outputs |])
+        out
     in
     Ir.Nier_simple.Node.gather_int out old_index
   and convert_term term : IR.GFloat.Node.t =
@@ -194,7 +177,7 @@ let create_new_nn env input_vars outputs : string =
          assert (i = j);
          n)
   in
-  let output = IR.Node.create (Concat { inputs = outputs; axis = 0 }) in
+  let output = IR.Node.concat_0 outputs in
   assert (
     IR.Shape.equal output.shape (IR.Shape.of_array [| List.length outputs |]));
   let nn = IR.create output in
diff --git a/tests/acasxu.t b/tests/acasxu.t
index 3f6f158..5011942 100644
--- a/tests/acasxu.t
+++ b/tests/acasxu.t
@@ -1187,9 +1187,9 @@ Test verify on acasxu
   caisar_0.onnx has 1 input nodes
   {'name': '38', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}}
   caisar_1.onnx has 1 input nodes
-  {'name': '135', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}}
+  {'name': '134', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}}
   caisar_2.onnx has 1 input nodes
-  {'name': '299', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '7'}]}}}}
+  {'name': '298', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '7'}]}}}}
   caisar_3.onnx has 1 input nodes
-  {'name': '468', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}}
+  {'name': '467', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}}
   4 files checked
-- 
GitLab