From be0a54f95a126a7501e5b16875e6c20f6234cc95 Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Wed, 10 Apr 2024 22:07:01 +0200
Subject: [PATCH] [ir] Revise API for creating a gather via matmul encoding.

---
 lib/ir/nier_simple.ml                   | 23 +++++++++++++----------
 lib/ir/nier_simple.mli                  |  7 +++++--
 src/transformations/native_nn_prover.ml |  2 +-
 3 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/lib/ir/nier_simple.ml b/lib/ir/nier_simple.ml
index cf8a5ee..69d6e9a 100644
--- a/lib/ir/nier_simple.ml
+++ b/lib/ir/nier_simple.ml
@@ -408,12 +408,6 @@ module Node = struct
         ty = compute_ty descr;
       }
 
-  let gather_int input i =
-    let indices =
-      create (Constant { data = GenTensor.create_1_int64 (Int64.of_int i) })
-    in
-    create (Gather { input; indices; axis = 0 })
-
   let constant_int_array a =
     create (Constant { data = GenTensor.of_int_array a })
 
@@ -426,10 +420,10 @@ module Node = struct
            { input = node; shape = constant_int_array (Shape.to_array shape) })
 
   let gather_int_as_matmul input i =
-    let input =
+    let input1 =
       reshape (Shape.of_array [| 1; Shape.size input.shape |]) input
     in
-    let selector = Array.create ~len:(Shape.size input.shape) Float.zero in
+    let selector = Array.create ~len:(Shape.size input1.shape) Float.zero in
     Array.set selector i Float.one;
     let selector =
       GenTensor.Float
@@ -438,8 +432,17 @@ module Node = struct
            (Bigarray.Array1.of_array Float64 C_layout selector))
     in
     let input2 = create (Constant { data = selector }) in
-    let matmul = create (Matmul { input1 = input; input2 }) in
-    reshape (Shape.of_array [| 1 |]) matmul
+    let result = create (Matmul { input1; input2 }) in
+    reshape (Shape.of_array [| 1 |]) result
+
+  let gather_int ?(encode = true) input i =
+    if encode
+    then gather_int_as_matmul input i
+    else
+      let indices =
+        create (Constant { data = GenTensor.create_1_int64 (Int64.of_int i) })
+      in
+      create (Gather { input; indices; axis = 0 })
 
   let concat_0 = function
     | [ n ] -> n
diff --git a/lib/ir/nier_simple.mli b/lib/ir/nier_simple.mli
index 7c7da27..c2a8bcf 100644
--- a/lib/ir/nier_simple.mli
+++ b/lib/ir/nier_simple.mli
@@ -148,8 +148,11 @@ module Node : sig
   include Base.Comparator.S with type t := t
 
   val create : descr -> t
-  val gather_int : t -> int -> t
-  val gather_int_as_matmul : t -> int -> t
+
+  val gather_int : ?encode:bool -> t -> int -> t
+  (** create a node by selection at a given index. *)
+  (* Implemented via a [Matmul] if [encode]. TODO: [encode] should be not be a
+     parameter, rather depend on prover. *)
 
   val constant_int_array : int array -> t
   (** create a node for a constant array *)
diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index c40e112..d672f2b 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -54,7 +54,7 @@ let create_new_nn env input_vars outputs : string =
   let get_input =
     Why3.Term.Hls.memo 10 (fun ls ->
       let i = Why3.Term.Mls.find_exn UnknownLogicSymbol ls input_vars in
-      Ir.Nier_simple.Node.gather_int_as_matmul input i)
+      Ir.Nier_simple.Node.gather_int input i)
   in
   let cache = Why3.Term.Hterm.create 17 in
   (* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed
-- 
GitLab