diff --git a/lib/ir/nier_simple.ml b/lib/ir/nier_simple.ml
index 43e33697c2cb193a66ab6b4c492271023661c6e9..fc198456515c1fb290c209b4422cfe1b4d3e6ba7 100644
--- a/lib/ir/nier_simple.ml
+++ b/lib/ir/nier_simple.ml
@@ -423,7 +423,7 @@ module Node = struct
     | Gemm { inputA; inputB; inputC = Some x; _ } -> [ inputA; inputB; x ]
     | Gemm { inputA; inputB; inputC = None; _ } -> [ inputA; inputB ]
     | Squeeze { data; _ } -> [ data ]
-    | Reshape { input; _ } -> [ input ]
+    | Reshape { input; shape; _ } -> [ input; shape ]
     | LogSoftmax | MaxPool | Conv | RW_Linearized_ReLu -> []
 
   let map f n =
diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index a8e807bef55837698b00e70d6b0ae6653282fd62..6594884bd152e0937b9f8d616c54beeee3eccaaf 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -55,7 +55,7 @@ let create_new_nn env input_vars outputs : string =
   let rec convert_old_nn (old_nn : Language.nn) old_index old_nn_args :
     IR.GFloat.Node.t =
     let old_index = Why3.Number.to_small_integer old_index in
-    let old_nn =
+    let old_nn' =
       match Onnx.Simple.parse old_nn.nn_filename with
       | Error s ->
         Logging.code_error ~src (fun m ->
@@ -71,7 +71,7 @@ let create_new_nn env input_vars outputs : string =
           (Concat { inputs = List.map ~f:convert_term old_nn_args; axis = 0 })
       in
       let node =
-        if Ir.Nier_simple.Shape.equal node.shape (IR.input_shape old_nn)
+        if Ir.Nier_simple.Shape.equal node.shape (IR.input_shape old_nn')
         then node
         else
           IR.Node.create
@@ -80,13 +80,22 @@ let create_new_nn env input_vars outputs : string =
                  input = node;
                  shape =
                    IR.Node.constant_int_array
-                     (Ir.Nier_simple.Shape.to_array (IR.input_shape old_nn));
+                     (Ir.Nier_simple.Shape.to_array (IR.input_shape old_nn'));
                  allowzero = 0;
                })
       in
       node
     in
-    let out = IR.Node.replace_input input (IR.output old_nn) in
+    let out = IR.Node.replace_input input (IR.output old_nn') in
+    let out =
+      IR.Node.create
+        (Reshape
+           {
+             input = out;
+             shape = IR.Node.constant_int_array [| old_nn.nn_nb_outputs |];
+             allowzero = 0;
+           })
+    in
     Ir.Nier_simple.Node.gather_int out old_index
   and convert_term term : IR.GFloat.Node.t =
     if not (Why3.Ty.ty_equal (Option.value_exn term.Why3.Term.t_ty) th_f64.ty)
@@ -170,8 +179,10 @@ let create_new_nn env input_vars outputs : string =
       assert (!r = o.new_index);
       convert_old_nn o.old_nn o.old_index o.old_nn_args)
   in
-  let outputs = IR.Node.create (Concat { inputs = outputs; axis = 0 }) in
-  let nn = IR.create outputs in
+  let output = IR.Node.create (Concat { inputs = outputs; axis = 0 }) in
+  assert (
+    IR.Shape.equal output.shape (IR.Shape.of_array [| List.length outputs |]));
+  let nn = IR.create output in
   let filename = Stdlib.Filename.temp_file "caisar" ".onnx" in
   Onnx.Simple.write nn filename;
   filename