diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 409c41fa4585bf52875bd73b7473deb58a413e35..a8e807bef55837698b00e70d6b0ae6653282fd62 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -66,26 +66,25 @@ let create_new_nn env input_vars outputs : string = | Ok { nier = Ok g; _ } -> g in let input () = - let o = + let node = IR.Node.create (Concat { inputs = List.map ~f:convert_term old_nn_args; axis = 0 }) in - let o = - if Ir.Nier_simple.Shape.equal o.shape (IR.input_shape old_nn) - then o + let node = + if Ir.Nier_simple.Shape.equal node.shape (IR.input_shape old_nn) + then node else IR.Node.create (Reshape { - input = o; + input = node; shape = IR.Node.constant_int_array (Ir.Nier_simple.Shape.to_array (IR.input_shape old_nn)); allowzero = 0; }) in - - o + node in let out = IR.Node.replace_input input (IR.output old_nn) in Ir.Nier_simple.Node.gather_int out old_index diff --git a/tests/acasxu.t b/tests/acasxu.t index 2d563ff1c6cd3c283bb35e95c80a7d93e9fd726d..6741acdef5fc405ba6fe2615314b391a12f7d7eb 100644 --- a/tests/acasxu.t +++ b/tests/acasxu.t @@ -96,7 +96,7 @@ Test verify on acasxu > ensures { let o1, o2 = result in o1 .>= o2 /\ o2 .>= o1 } = > let i = normalize_input j in > let o1 = (nn@@i)[clear_of_conflict] in - > let o2 = (nn@@i)[clear_of_conflict] in + > let o2 = (nn@@i)[weak_left] in > o1, o2 > > (* goal P1: