From be0a54f95a126a7501e5b16875e6c20f6234cc95 Mon Sep 17 00:00:00 2001 From: Michele Alberti <michele.alberti@cea.fr> Date: Wed, 10 Apr 2024 22:07:01 +0200 Subject: [PATCH] [ir] Revise API for creating a gather via matmul encoding. --- lib/ir/nier_simple.ml | 23 +++++++++++++---------- lib/ir/nier_simple.mli | 7 +++++-- src/transformations/native_nn_prover.ml | 2 +- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/lib/ir/nier_simple.ml b/lib/ir/nier_simple.ml index cf8a5ee..69d6e9a 100644 --- a/lib/ir/nier_simple.ml +++ b/lib/ir/nier_simple.ml @@ -408,12 +408,6 @@ module Node = struct ty = compute_ty descr; } - let gather_int input i = - let indices = - create (Constant { data = GenTensor.create_1_int64 (Int64.of_int i) }) - in - create (Gather { input; indices; axis = 0 }) - let constant_int_array a = create (Constant { data = GenTensor.of_int_array a }) @@ -426,10 +420,10 @@ module Node = struct { input = node; shape = constant_int_array (Shape.to_array shape) }) let gather_int_as_matmul input i = - let input = + let input1 = reshape (Shape.of_array [| 1; Shape.size input.shape |]) input in - let selector = Array.create ~len:(Shape.size input.shape) Float.zero in + let selector = Array.create ~len:(Shape.size input1.shape) Float.zero in Array.set selector i Float.one; let selector = GenTensor.Float @@ -438,8 +432,17 @@ module Node = struct (Bigarray.Array1.of_array Float64 C_layout selector)) in let input2 = create (Constant { data = selector }) in - let matmul = create (Matmul { input1 = input; input2 }) in - reshape (Shape.of_array [| 1 |]) matmul + let result = create (Matmul { input1; input2 }) in + reshape (Shape.of_array [| 1 |]) result + + let gather_int ?(encode = true) input i = + if encode + then gather_int_as_matmul input i + else + let indices = + create (Constant { data = GenTensor.create_1_int64 (Int64.of_int i) }) + in + create (Gather { input; indices; axis = 0 }) let concat_0 = function | [ n ] -> n diff --git a/lib/ir/nier_simple.mli b/lib/ir/nier_simple.mli index 7c7da27..c2a8bcf 100644 --- a/lib/ir/nier_simple.mli +++ b/lib/ir/nier_simple.mli @@ -148,8 +148,11 @@ module Node : sig include Base.Comparator.S with type t := t val create : descr -> t - val gather_int : t -> int -> t - val gather_int_as_matmul : t -> int -> t + + val gather_int : ?encode:bool -> t -> int -> t + (** create a node by selection at a given index. *) + (* Implemented via a [Matmul] if [encode]. TODO: [encode] should be not be a + parameter, rather depend on prover. *) val constant_int_array : int array -> t (** create a node for a constant array *) diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index c40e112..d672f2b 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -54,7 +54,7 @@ let create_new_nn env input_vars outputs : string = let get_input = Why3.Term.Hls.memo 10 (fun ls -> let i = Why3.Term.Mls.find_exn UnknownLogicSymbol ls input_vars in - Ir.Nier_simple.Node.gather_int_as_matmul input i) + Ir.Nier_simple.Node.gather_int input i) in let cache = Why3.Term.Hterm.create 17 in (* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed -- GitLab