diff --git a/lib/ir/nier_simple.ml b/lib/ir/nier_simple.ml
index cf8a5eefba0e61d7879a49c96e74d54ae995c449..69d6e9a382b9c9fc46e0089d4599ef59850626f2 100644
--- a/lib/ir/nier_simple.ml
+++ b/lib/ir/nier_simple.ml
@@ -408,12 +408,6 @@ module Node = struct
         ty = compute_ty descr;
       }
 
-  let gather_int input i =
-    let indices =
-      create (Constant { data = GenTensor.create_1_int64 (Int64.of_int i) })
-    in
-    create (Gather { input; indices; axis = 0 })
-
   let constant_int_array a =
     create (Constant { data = GenTensor.of_int_array a })
 
@@ -426,10 +420,10 @@ module Node = struct
            { input = node; shape = constant_int_array (Shape.to_array shape) })
 
   let gather_int_as_matmul input i =
-    let input =
+    let input1 =
       reshape (Shape.of_array [| 1; Shape.size input.shape |]) input
     in
-    let selector = Array.create ~len:(Shape.size input.shape) Float.zero in
+    let selector = Array.create ~len:(Shape.size input1.shape) Float.zero in
     Array.set selector i Float.one;
     let selector =
       GenTensor.Float
@@ -438,8 +432,17 @@ module Node = struct
            (Bigarray.Array1.of_array Float64 C_layout selector))
     in
     let input2 = create (Constant { data = selector }) in
-    let matmul = create (Matmul { input1 = input; input2 }) in
-    reshape (Shape.of_array [| 1 |]) matmul
+    let result = create (Matmul { input1; input2 }) in
+    reshape (Shape.of_array [| 1 |]) result
+
+  let gather_int ?(encode = true) input i =
+    if encode
+    then gather_int_as_matmul input i
+    else
+      let indices =
+        create (Constant { data = GenTensor.create_1_int64 (Int64.of_int i) })
+      in
+      create (Gather { input; indices; axis = 0 })
 
   let concat_0 = function
     | [ n ] -> n
diff --git a/lib/ir/nier_simple.mli b/lib/ir/nier_simple.mli
index 7c7da27cf3630bfcd32db185f4d0fa10c8969a11..c2a8bcfbe8279cd7e6fb04d4eff3623860997ed1 100644
--- a/lib/ir/nier_simple.mli
+++ b/lib/ir/nier_simple.mli
@@ -148,8 +148,11 @@ module Node : sig
   include Base.Comparator.S with type t := t
 
   val create : descr -> t
-  val gather_int : t -> int -> t
-  val gather_int_as_matmul : t -> int -> t
+
+  val gather_int : ?encode:bool -> t -> int -> t
+  (** create a node by selection at a given index. *)
+  (* Implemented via a [Matmul] if [encode]. TODO: [encode] should be not be a
+     parameter, rather depend on prover. *)
 
   val constant_int_array : int array -> t
   (** create a node for a constant array *)
diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index c40e112ba75d8ad3885ce9f19877b5e9f8c9575e..d672f2be62250037882442789eeb11091ac71b0a 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -54,7 +54,7 @@ let create_new_nn env input_vars outputs : string =
   let get_input =
     Why3.Term.Hls.memo 10 (fun ls ->
       let i = Why3.Term.Mls.find_exn UnknownLogicSymbol ls input_vars in
-      Ir.Nier_simple.Node.gather_int_as_matmul input i)
+      Ir.Nier_simple.Node.gather_int input i)
   in
   let cache = Why3.Term.Hterm.create 17 in
   (* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed