From 6247428896ad4c2a825440982526427cb87afa68 Mon Sep 17 00:00:00 2001 From: Michele Alberti <michele.alberti@cea.fr> Date: Thu, 4 Apr 2024 16:40:31 +0200 Subject: [PATCH] [ir] Remove allowzero attribute from Reshape. --- lib/ir/nier_simple.ml | 3 +-- lib/ir/nier_simple.mli | 2 +- lib/onnx/simple.ml | 3 +-- src/transformations/native_nn_prover.ml | 11 +++++------ 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/lib/ir/nier_simple.ml b/lib/ir/nier_simple.ml index fc19845..3ea2dcf 100644 --- a/lib/ir/nier_simple.ml +++ b/lib/ir/nier_simple.ml @@ -167,7 +167,6 @@ type descr = | Reshape of { input : node; shape : node; (* data int64 *) - allowzero : int; } | Flatten of { input : node; @@ -364,7 +363,7 @@ module Node = struct (check_matmul_size_ab ~a_sh:(Shape.to_list (compute_shape input1)) ~b_sh:(Shape.to_list (compute_shape input2))) - | Reshape { input; shape; allowzero = _ } -> + | Reshape { input; shape; _ } -> let shape = match shape.descr with | Constant { data = Int64 a } -> diff --git a/lib/ir/nier_simple.mli b/lib/ir/nier_simple.mli index 62d7c88..11b2d78 100644 --- a/lib/ir/nier_simple.mli +++ b/lib/ir/nier_simple.mli @@ -90,7 +90,6 @@ type descr = | Reshape of { input : node; shape : node; (* int64 *) - allowzero : int; } | Flatten of { input : node; @@ -145,6 +144,7 @@ module Node : sig val create : descr -> node val gather_int : node -> int -> node + val constant_int_array : int array -> node (** create a node for a constant array *) diff --git a/lib/onnx/simple.ml b/lib/onnx/simple.ml index 86c230e..ccaa32f 100644 --- a/lib/onnx/simple.ml +++ b/lib/onnx/simple.ml @@ -420,8 +420,7 @@ let nier_simple_to_onnx_protoc (nier_simple : Ir.Nier_simple.GFloat.t) = | Identity _ | RW_Linearized_ReLu | GatherND _ | ReduceSum _ -> failwith (Fmt.str "Not implemented export: %a" Ir.Nier_simple.Node.pp v) | Flatten { axis; _ } -> make "Flatten" [ mk_int "axis" axis ] - | Reshape { allowzero; _ } -> - make "Reshape" [ mk_int "allowzero" allowzero ] + | Reshape _ -> make "Reshape" [] | Constant { data } -> let data = convert_into_tensor data in make "Constant" [ mk_tensor "value" data ] diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 6594884..bfda7ff 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -55,7 +55,7 @@ let create_new_nn env input_vars outputs : string = let rec convert_old_nn (old_nn : Language.nn) old_index old_nn_args : IR.GFloat.Node.t = let old_index = Why3.Number.to_small_integer old_index in - let old_nn' = + let old_nn_nier = match Onnx.Simple.parse old_nn.nn_filename with | Error s -> Logging.code_error ~src (fun m -> @@ -71,7 +71,7 @@ let create_new_nn env input_vars outputs : string = (Concat { inputs = List.map ~f:convert_term old_nn_args; axis = 0 }) in let node = - if Ir.Nier_simple.Shape.equal node.shape (IR.input_shape old_nn') + if Ir.Nier_simple.Shape.equal node.shape (IR.input_shape old_nn_nier) then node else IR.Node.create @@ -80,20 +80,19 @@ let create_new_nn env input_vars outputs : string = input = node; shape = IR.Node.constant_int_array - (Ir.Nier_simple.Shape.to_array (IR.input_shape old_nn')); - allowzero = 0; + (Ir.Nier_simple.Shape.to_array + (IR.input_shape old_nn_nier)); }) in node in - let out = IR.Node.replace_input input (IR.output old_nn') in + let out = IR.Node.replace_input input (IR.output old_nn_nier) in let out = IR.Node.create (Reshape { input = out; shape = IR.Node.constant_int_array [| old_nn.nn_nb_outputs |]; - allowzero = 0; }) in Ir.Nier_simple.Node.gather_int out old_index -- GitLab