From bdc669b66eb44f930bb3ff55c95b687060c36b5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr> Date: Fri, 5 Apr 2024 16:04:27 +0200 Subject: [PATCH] [Ir] add smart constructors and use them --- lib/ir/nier_simple.ml | 13 ++++++++++++ lib/ir/nier_simple.mli | 6 ++++++ src/transformations/native_nn_prover.ml | 27 +++++-------------------- tests/acasxu.t | 6 +++--- 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/lib/ir/nier_simple.ml b/lib/ir/nier_simple.ml index 3a774a1..c30b962 100644 --- a/lib/ir/nier_simple.ml +++ b/lib/ir/nier_simple.ml @@ -417,6 +417,19 @@ module Node = struct let constant_int_array a = create (Constant { data = GenTensor.of_int_array a }) + let reshape shape node = + if Shape.equal node.shape shape + then node + else + create + (Reshape + { input = node; shape = constant_int_array (Shape.to_array shape) }) + + let concat_0 = function + | [ n ] -> n + | [] -> failwith "empty concat" + | inputs -> create (Concat { inputs; axis = 0 }) + let preds node = match node.descr with | Constant _ | Input _ -> [] diff --git a/lib/ir/nier_simple.mli b/lib/ir/nier_simple.mli index a500fa9..eba29db 100644 --- a/lib/ir/nier_simple.mli +++ b/lib/ir/nier_simple.mli @@ -154,6 +154,12 @@ module Node : sig val constant_int_array : int array -> node (** create a node for a constant array *) + val reshape: Shape.t -> t -> t + (** create if necessary a reshape node *) + + val concat_0: t list -> t + (** create if necessary a concat node for the first axis *) + val map : (node -> node) -> node -> node (** [map f n] replace the direct inputs [i] of n by [f i] *) diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index bde96b6..5d731d7 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -79,30 +79,13 @@ let create_new_nn env input_vars outputs : string = IR.Node.create (Concat { inputs = List.map ~f:convert_term old_nn_args; axis = 0 }) in - let node = - if Ir.Nier_simple.Shape.equal node.shape (IR.input_shape old_nn_nier) - then node - else - IR.Node.create - (Reshape - { - input = node; - shape = - IR.Node.constant_int_array - (Ir.Nier_simple.Shape.to_array - (IR.input_shape old_nn_nier)); - }) - in - node + Ir.Nier_simple.Node.reshape (IR.input_shape old_nn_nier) node in let out = IR.Node.replace_input input (IR.output old_nn_nier) in let out = - IR.Node.create - (Reshape - { - input = out; - shape = IR.Node.constant_int_array [| old_nn.nn_nb_outputs |]; - }) + Ir.Nier_simple.Node.reshape + (Ir.Nier_simple.Shape.of_array [| old_nn.nn_nb_outputs |]) + out in Ir.Nier_simple.Node.gather_int out old_index and convert_term term : IR.GFloat.Node.t = @@ -194,7 +177,7 @@ let create_new_nn env input_vars outputs : string = assert (i = j); n) in - let output = IR.Node.create (Concat { inputs = outputs; axis = 0 }) in + let output = IR.Node.concat_0 outputs in assert ( IR.Shape.equal output.shape (IR.Shape.of_array [| List.length outputs |])); let nn = IR.create output in diff --git a/tests/acasxu.t b/tests/acasxu.t index 3f6f158..5011942 100644 --- a/tests/acasxu.t +++ b/tests/acasxu.t @@ -1187,9 +1187,9 @@ Test verify on acasxu caisar_0.onnx has 1 input nodes {'name': '38', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}} caisar_1.onnx has 1 input nodes - {'name': '135', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}} + {'name': '134', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}} caisar_2.onnx has 1 input nodes - {'name': '299', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '7'}]}}}} + {'name': '298', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '7'}]}}}} caisar_3.onnx has 1 input nodes - {'name': '468', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}} + {'name': '467', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}} 4 files checked -- GitLab