diff --git a/lib/ir/nier_simple.ml b/lib/ir/nier_simple.ml
index c30b962f7cab660fa4c8386e6a5ce54372611c1c..cf8a5eefba0e61d7879a49c96e74d54ae995c449 100644
--- a/lib/ir/nier_simple.ml
+++ b/lib/ir/nier_simple.ml
@@ -425,6 +425,22 @@ module Node = struct
         (Reshape
            { input = node; shape = constant_int_array (Shape.to_array shape) })
 
+  let gather_int_as_matmul input i =
+    let input =
+      reshape (Shape.of_array [| 1; Shape.size input.shape |]) input
+    in
+    let selector = Array.create ~len:(Shape.size input.shape) Float.zero in
+    Array.set selector i Float.one;
+    let selector =
+      GenTensor.Float
+        (Tensor.of_array1
+           (Shape.of_array [| Array.length selector; 1 |])
+           (Bigarray.Array1.of_array Float64 C_layout selector))
+    in
+    let input2 = create (Constant { data = selector }) in
+    let matmul = create (Matmul { input1 = input; input2 }) in
+    reshape (Shape.of_array [| 1 |]) matmul
+
   let concat_0 = function
     | [ n ] -> n
     | [] -> failwith "empty concat"
diff --git a/lib/ir/nier_simple.mli b/lib/ir/nier_simple.mli
index 1824367fc08517c6c75f0f77ae80c8365ff50ded..ee7e2cc415fa7aa6292c26968e72aa9a7ec3cea0 100644
--- a/lib/ir/nier_simple.mli
+++ b/lib/ir/nier_simple.mli
@@ -143,15 +143,16 @@ and ty =
 module Node : sig
   type t = node [@@deriving show]
 
-  val equal : node -> node -> bool
+  val equal : t -> t -> bool
 
   include Base.Hashtbl.Key.S with type t := t
   include Base.Comparator.S with type t := t
 
-  val create : descr -> node
-  val gather_int : node -> int -> node
+  val create : descr -> t
+  val gather_int : t -> int -> t
+  val gather_int_as_matmul : t -> int -> t
 
-  val constant_int_array : int array -> node
+  val constant_int_array : int array -> t
   (** create a node for a constant array *)
 
   val reshape : Shape.t -> t -> t
@@ -160,24 +161,24 @@ module Node : sig
   val concat_0 : t list -> t
   (** create if necessary a concat node for the first axis *)
 
-  val map : (node -> node) -> node -> node
+  val map : (t -> t) -> t -> t
   (** [map f n] replace the direct inputs [i] of n by [f i] *)
 
-  val map_rec : (node -> node) -> node -> node
+  val map_rec : (t -> t) -> t -> t
   (** [map_rec f n] replace top-bottom the nodes [i] accessible from [n] by
       [f i] *)
 
-  val replace_input : (unit -> node) -> node -> node
+  val replace_input : (unit -> t) -> t -> t
   (** [replace_input f n] replace the input in [n] by [f ()] *)
 
-  val preds : node -> node list
+  val preds : t -> t list
   (** Direct predecessors of a node *)
 
   val iter_rec : (t -> unit) -> t -> unit
   (** Iterate on the predecessors of a node and itself. Repect topological
       order. *)
 
-  val compute_shape : node -> Shape.t
+  val compute_shape : t -> Shape.t
 end
 
 type t
diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index 5acb447cc4cc74deece6be9a05e1066fd43de519..e6c1459ab4bca3b9d242fc0632c9e2044e4ff244 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -54,7 +54,7 @@ let create_new_nn env input_vars outputs : string =
   let get_input =
     Why3.Term.Hls.memo 10 (fun ls ->
       let i = Why3.Term.Mls.find_exn UnknownLogicSymbol ls input_vars in
-      Ir.Nier_simple.Node.gather_int input i)
+      Ir.Nier_simple.Node.gather_int_as_matmul input i)
   in
   let cache = Why3.Term.Hterm.create 17 in
   (* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed
diff --git a/tests/acasxu.t b/tests/acasxu.t
index 6747505c0fc0c3c1bd2cdd068d7df802efa454ac..8256f664151c8af3c1c91fc60b7974007226c75a 100644
--- a/tests/acasxu.t
+++ b/tests/acasxu.t
@@ -899,9 +899,9 @@ Test verify on acasxu
   caisar_0.onnx has 1 input nodes
   {'name': '38', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
   caisar_1.onnx has 1 input nodes
-  {'name': '134', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
+  {'name': '154', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
   caisar_2.onnx has 1 input nodes
-  {'name': '298', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}}
+  {'name': '338', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}}
   caisar_3.onnx has 1 input nodes
-  {'name': '467', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
+  {'name': '531', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
   4 files checked