From 6247428896ad4c2a825440982526427cb87afa68 Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Thu, 4 Apr 2024 16:40:31 +0200
Subject: [PATCH] [ir] Remove allowzero attribute from Reshape.

---
 lib/ir/nier_simple.ml                   |  3 +--
 lib/ir/nier_simple.mli                  |  2 +-
 lib/onnx/simple.ml                      |  3 +--
 src/transformations/native_nn_prover.ml | 11 +++++------
 4 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/lib/ir/nier_simple.ml b/lib/ir/nier_simple.ml
index fc19845..3ea2dcf 100644
--- a/lib/ir/nier_simple.ml
+++ b/lib/ir/nier_simple.ml
@@ -167,7 +167,6 @@ type descr =
   | Reshape of {
       input : node;
       shape : node; (* data int64 *)
-      allowzero : int;
     }
   | Flatten of {
       input : node;
@@ -364,7 +363,7 @@ module Node = struct
         (check_matmul_size_ab
            ~a_sh:(Shape.to_list (compute_shape input1))
            ~b_sh:(Shape.to_list (compute_shape input2)))
-    | Reshape { input; shape; allowzero = _ } ->
+    | Reshape { input; shape; _ } ->
       let shape =
         match shape.descr with
         | Constant { data = Int64 a } ->
diff --git a/lib/ir/nier_simple.mli b/lib/ir/nier_simple.mli
index 62d7c88..11b2d78 100644
--- a/lib/ir/nier_simple.mli
+++ b/lib/ir/nier_simple.mli
@@ -90,7 +90,6 @@ type descr =
   | Reshape of {
       input : node;
       shape : node; (* int64 *)
-      allowzero : int;
     }
   | Flatten of {
       input : node;
@@ -145,6 +144,7 @@ module Node : sig
 
   val create : descr -> node
   val gather_int : node -> int -> node
+
   val constant_int_array : int array -> node
   (** create a node for a constant array *)
 
diff --git a/lib/onnx/simple.ml b/lib/onnx/simple.ml
index 86c230e..ccaa32f 100644
--- a/lib/onnx/simple.ml
+++ b/lib/onnx/simple.ml
@@ -420,8 +420,7 @@ let nier_simple_to_onnx_protoc (nier_simple : Ir.Nier_simple.GFloat.t) =
       | Identity _ | RW_Linearized_ReLu | GatherND _ | ReduceSum _ ->
         failwith (Fmt.str "Not implemented export: %a" Ir.Nier_simple.Node.pp v)
       | Flatten { axis; _ } -> make "Flatten" [ mk_int "axis" axis ]
-      | Reshape { allowzero; _ } ->
-        make "Reshape" [ mk_int "allowzero" allowzero ]
+      | Reshape _ -> make "Reshape" []
       | Constant { data } ->
         let data = convert_into_tensor data in
         make "Constant" [ mk_tensor "value" data ]
diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index 6594884..bfda7ff 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -55,7 +55,7 @@ let create_new_nn env input_vars outputs : string =
   let rec convert_old_nn (old_nn : Language.nn) old_index old_nn_args :
     IR.GFloat.Node.t =
     let old_index = Why3.Number.to_small_integer old_index in
-    let old_nn' =
+    let old_nn_nier =
       match Onnx.Simple.parse old_nn.nn_filename with
       | Error s ->
         Logging.code_error ~src (fun m ->
@@ -71,7 +71,7 @@ let create_new_nn env input_vars outputs : string =
           (Concat { inputs = List.map ~f:convert_term old_nn_args; axis = 0 })
       in
       let node =
-        if Ir.Nier_simple.Shape.equal node.shape (IR.input_shape old_nn')
+        if Ir.Nier_simple.Shape.equal node.shape (IR.input_shape old_nn_nier)
         then node
         else
           IR.Node.create
@@ -80,20 +80,19 @@ let create_new_nn env input_vars outputs : string =
                  input = node;
                  shape =
                    IR.Node.constant_int_array
-                     (Ir.Nier_simple.Shape.to_array (IR.input_shape old_nn'));
-                 allowzero = 0;
+                     (Ir.Nier_simple.Shape.to_array
+                        (IR.input_shape old_nn_nier));
                })
       in
       node
     in
-    let out = IR.Node.replace_input input (IR.output old_nn') in
+    let out = IR.Node.replace_input input (IR.output old_nn_nier) in
     let out =
       IR.Node.create
         (Reshape
            {
              input = out;
              shape = IR.Node.constant_int_array [| old_nn.nn_nb_outputs |];
-             allowzero = 0;
            })
     in
     Ir.Nier_simple.Node.gather_int out old_index
-- 
GitLab