diff --git a/lib/ir/nier_simple.ml b/lib/ir/nier_simple.ml index 43e33697c2cb193a66ab6b4c492271023661c6e9..fc198456515c1fb290c209b4422cfe1b4d3e6ba7 100644 --- a/lib/ir/nier_simple.ml +++ b/lib/ir/nier_simple.ml @@ -423,7 +423,7 @@ module Node = struct | Gemm { inputA; inputB; inputC = Some x; _ } -> [ inputA; inputB; x ] | Gemm { inputA; inputB; inputC = None; _ } -> [ inputA; inputB ] | Squeeze { data; _ } -> [ data ] - | Reshape { input; _ } -> [ input ] + | Reshape { input; shape; _ } -> [ input; shape ] | LogSoftmax | MaxPool | Conv | RW_Linearized_ReLu -> [] let map f n = diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index a8e807bef55837698b00e70d6b0ae6653282fd62..6594884bd152e0937b9f8d616c54beeee3eccaaf 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -55,7 +55,7 @@ let create_new_nn env input_vars outputs : string = let rec convert_old_nn (old_nn : Language.nn) old_index old_nn_args : IR.GFloat.Node.t = let old_index = Why3.Number.to_small_integer old_index in - let old_nn = + let old_nn' = match Onnx.Simple.parse old_nn.nn_filename with | Error s -> Logging.code_error ~src (fun m -> @@ -71,7 +71,7 @@ let create_new_nn env input_vars outputs : string = (Concat { inputs = List.map ~f:convert_term old_nn_args; axis = 0 }) in let node = - if Ir.Nier_simple.Shape.equal node.shape (IR.input_shape old_nn) + if Ir.Nier_simple.Shape.equal node.shape (IR.input_shape old_nn') then node else IR.Node.create @@ -80,13 +80,22 @@ let create_new_nn env input_vars outputs : string = input = node; shape = IR.Node.constant_int_array - (Ir.Nier_simple.Shape.to_array (IR.input_shape old_nn)); + (Ir.Nier_simple.Shape.to_array (IR.input_shape old_nn')); allowzero = 0; }) in node in - let out = IR.Node.replace_input input (IR.output old_nn) in + let out = IR.Node.replace_input input (IR.output old_nn') in + let out = + IR.Node.create + (Reshape + { + input = out; + shape = IR.Node.constant_int_array [| old_nn.nn_nb_outputs |]; + allowzero = 0; + }) + in Ir.Nier_simple.Node.gather_int out old_index and convert_term term : IR.GFloat.Node.t = if not (Why3.Ty.ty_equal (Option.value_exn term.Why3.Term.t_ty) th_f64.ty) @@ -170,8 +179,10 @@ let create_new_nn env input_vars outputs : string = assert (!r = o.new_index); convert_old_nn o.old_nn o.old_index o.old_nn_args) in - let outputs = IR.Node.create (Concat { inputs = outputs; axis = 0 }) in - let nn = IR.create outputs in + let output = IR.Node.create (Concat { inputs = outputs; axis = 0 }) in + assert ( + IR.Shape.equal output.shape (IR.Shape.of_array [| List.length outputs |])); + let nn = IR.create output in let filename = Stdlib.Filename.temp_file "caisar" ".onnx" in Onnx.Simple.write nn filename; filename