Skip to content
Snippets Groups Projects
Commit be0a54f9 authored by Michele Alberti's avatar Michele Alberti Committed by François Bobot
Browse files

[ir] Revise API for creating a gather via matmul encoding.

parent 4b0a91c1
No related branches found
No related tags found
No related merge requests found
...@@ -408,12 +408,6 @@ module Node = struct ...@@ -408,12 +408,6 @@ module Node = struct
ty = compute_ty descr; ty = compute_ty descr;
} }
let gather_int input i =
let indices =
create (Constant { data = GenTensor.create_1_int64 (Int64.of_int i) })
in
create (Gather { input; indices; axis = 0 })
let constant_int_array a = let constant_int_array a =
create (Constant { data = GenTensor.of_int_array a }) create (Constant { data = GenTensor.of_int_array a })
...@@ -426,10 +420,10 @@ module Node = struct ...@@ -426,10 +420,10 @@ module Node = struct
{ input = node; shape = constant_int_array (Shape.to_array shape) }) { input = node; shape = constant_int_array (Shape.to_array shape) })
let gather_int_as_matmul input i = let gather_int_as_matmul input i =
let input = let input1 =
reshape (Shape.of_array [| 1; Shape.size input.shape |]) input reshape (Shape.of_array [| 1; Shape.size input.shape |]) input
in in
let selector = Array.create ~len:(Shape.size input.shape) Float.zero in let selector = Array.create ~len:(Shape.size input1.shape) Float.zero in
Array.set selector i Float.one; Array.set selector i Float.one;
let selector = let selector =
GenTensor.Float GenTensor.Float
...@@ -438,8 +432,17 @@ module Node = struct ...@@ -438,8 +432,17 @@ module Node = struct
(Bigarray.Array1.of_array Float64 C_layout selector)) (Bigarray.Array1.of_array Float64 C_layout selector))
in in
let input2 = create (Constant { data = selector }) in let input2 = create (Constant { data = selector }) in
let matmul = create (Matmul { input1 = input; input2 }) in let result = create (Matmul { input1; input2 }) in
reshape (Shape.of_array [| 1 |]) matmul reshape (Shape.of_array [| 1 |]) result
let gather_int ?(encode = true) input i =
if encode
then gather_int_as_matmul input i
else
let indices =
create (Constant { data = GenTensor.create_1_int64 (Int64.of_int i) })
in
create (Gather { input; indices; axis = 0 })
let concat_0 = function let concat_0 = function
| [ n ] -> n | [ n ] -> n
......
...@@ -148,8 +148,11 @@ module Node : sig ...@@ -148,8 +148,11 @@ module Node : sig
include Base.Comparator.S with type t := t include Base.Comparator.S with type t := t
val create : descr -> t val create : descr -> t
val gather_int : t -> int -> t
val gather_int_as_matmul : t -> int -> t val gather_int : ?encode:bool -> t -> int -> t
(** create a node by selection at a given index. *)
(* Implemented via a [Matmul] if [encode]. TODO: [encode] should be not be a
parameter, rather depend on prover. *)
val constant_int_array : int array -> t val constant_int_array : int array -> t
(** create a node for a constant array *) (** create a node for a constant array *)
......
...@@ -54,7 +54,7 @@ let create_new_nn env input_vars outputs : string = ...@@ -54,7 +54,7 @@ let create_new_nn env input_vars outputs : string =
let get_input = let get_input =
Why3.Term.Hls.memo 10 (fun ls -> Why3.Term.Hls.memo 10 (fun ls ->
let i = Why3.Term.Mls.find_exn UnknownLogicSymbol ls input_vars in let i = Why3.Term.Mls.find_exn UnknownLogicSymbol ls input_vars in
Ir.Nier_simple.Node.gather_int_as_matmul input i) Ir.Nier_simple.Node.gather_int input i)
in in
let cache = Why3.Term.Hterm.create 17 in let cache = Why3.Term.Hterm.create 17 in
(* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed (* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment