diff --git a/lib/ir/nier_simple.ml b/lib/ir/nier_simple.ml index cf8a5eefba0e61d7879a49c96e74d54ae995c449..69d6e9a382b9c9fc46e0089d4599ef59850626f2 100644 --- a/lib/ir/nier_simple.ml +++ b/lib/ir/nier_simple.ml @@ -408,12 +408,6 @@ module Node = struct ty = compute_ty descr; } - let gather_int input i = - let indices = - create (Constant { data = GenTensor.create_1_int64 (Int64.of_int i) }) - in - create (Gather { input; indices; axis = 0 }) - let constant_int_array a = create (Constant { data = GenTensor.of_int_array a }) @@ -426,10 +420,10 @@ module Node = struct { input = node; shape = constant_int_array (Shape.to_array shape) }) let gather_int_as_matmul input i = - let input = + let input1 = reshape (Shape.of_array [| 1; Shape.size input.shape |]) input in - let selector = Array.create ~len:(Shape.size input.shape) Float.zero in + let selector = Array.create ~len:(Shape.size input1.shape) Float.zero in Array.set selector i Float.one; let selector = GenTensor.Float @@ -438,8 +432,17 @@ module Node = struct (Bigarray.Array1.of_array Float64 C_layout selector)) in let input2 = create (Constant { data = selector }) in - let matmul = create (Matmul { input1 = input; input2 }) in - reshape (Shape.of_array [| 1 |]) matmul + let result = create (Matmul { input1; input2 }) in + reshape (Shape.of_array [| 1 |]) result + + let gather_int ?(encode = true) input i = + if encode + then gather_int_as_matmul input i + else + let indices = + create (Constant { data = GenTensor.create_1_int64 (Int64.of_int i) }) + in + create (Gather { input; indices; axis = 0 }) let concat_0 = function | [ n ] -> n diff --git a/lib/ir/nier_simple.mli b/lib/ir/nier_simple.mli index 7c7da27cf3630bfcd32db185f4d0fa10c8969a11..c2a8bcfbe8279cd7e6fb04d4eff3623860997ed1 100644 --- a/lib/ir/nier_simple.mli +++ b/lib/ir/nier_simple.mli @@ -148,8 +148,11 @@ module Node : sig include Base.Comparator.S with type t := t val create : descr -> t - val gather_int : t -> int -> t - val gather_int_as_matmul : t -> int -> t + + val gather_int : ?encode:bool -> t -> int -> t + (** create a node by selection at a given index. *) + (* Implemented via a [Matmul] if [encode]. TODO: [encode] should be not be a + parameter, rather depend on prover. *) val constant_int_array : int array -> t (** create a node for a constant array *) diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index c40e112ba75d8ad3885ce9f19877b5e9f8c9575e..d672f2be62250037882442789eeb11091ac71b0a 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -54,7 +54,7 @@ let create_new_nn env input_vars outputs : string = let get_input = Why3.Term.Hls.memo 10 (fun ls -> let i = Why3.Term.Mls.find_exn UnknownLogicSymbol ls input_vars in - Ir.Nier_simple.Node.gather_int_as_matmul input i) + Ir.Nier_simple.Node.gather_int input i) in let cache = Why3.Term.Hterm.create 17 in (* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed