From 14de28dc898d65f5b61fd93d47076590c6f4e79c Mon Sep 17 00:00:00 2001 From: Michele Alberti <michele.alberti@cea.fr> Date: Thu, 4 Apr 2024 15:26:54 +0200 Subject: [PATCH] [tests] Same output value from two different NN calls. --- src/transformations/native_nn_prover.ml | 13 ++++++------- tests/acasxu.t | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 409c41f..a8e807b 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -66,26 +66,25 @@ let create_new_nn env input_vars outputs : string = | Ok { nier = Ok g; _ } -> g in let input () = - let o = + let node = IR.Node.create (Concat { inputs = List.map ~f:convert_term old_nn_args; axis = 0 }) in - let o = - if Ir.Nier_simple.Shape.equal o.shape (IR.input_shape old_nn) - then o + let node = + if Ir.Nier_simple.Shape.equal node.shape (IR.input_shape old_nn) + then node else IR.Node.create (Reshape { - input = o; + input = node; shape = IR.Node.constant_int_array (Ir.Nier_simple.Shape.to_array (IR.input_shape old_nn)); allowzero = 0; }) in - - o + node in let out = IR.Node.replace_input input (IR.output old_nn) in Ir.Nier_simple.Node.gather_int out old_index diff --git a/tests/acasxu.t b/tests/acasxu.t index 2d563ff..6741acd 100644 --- a/tests/acasxu.t +++ b/tests/acasxu.t @@ -96,7 +96,7 @@ Test verify on acasxu > ensures { let o1, o2 = result in o1 .>= o2 /\ o2 .>= o1 } = > let i = normalize_input j in > let o1 = (nn@@i)[clear_of_conflict] in - > let o2 = (nn@@i)[clear_of_conflict] in + > let o2 = (nn@@i)[weak_left] in > o1, o2 > > (* goal P1: -- GitLab