Skip to content
Snippets Groups Projects
Commit fff06e7e authored by Michele Alberti's avatar Michele Alberti Committed by François Bobot
Browse files

[ir] Encode gather via matmul.

parent f674765a
No related branches found
No related tags found
No related merge requests found
......@@ -425,6 +425,22 @@ module Node = struct
(Reshape
{ input = node; shape = constant_int_array (Shape.to_array shape) })
let gather_int_as_matmul input i =
let input =
reshape (Shape.of_array [| 1; Shape.size input.shape |]) input
in
let selector = Array.create ~len:(Shape.size input.shape) Float.zero in
Array.set selector i Float.one;
let selector =
GenTensor.Float
(Tensor.of_array1
(Shape.of_array [| Array.length selector; 1 |])
(Bigarray.Array1.of_array Float64 C_layout selector))
in
let input2 = create (Constant { data = selector }) in
let matmul = create (Matmul { input1 = input; input2 }) in
reshape (Shape.of_array [| 1 |]) matmul
let concat_0 = function
| [ n ] -> n
| [] -> failwith "empty concat"
......
......@@ -143,15 +143,16 @@ and ty =
module Node : sig
type t = node [@@deriving show]
val equal : node -> node -> bool
val equal : t -> t -> bool
include Base.Hashtbl.Key.S with type t := t
include Base.Comparator.S with type t := t
val create : descr -> node
val gather_int : node -> int -> node
val create : descr -> t
val gather_int : t -> int -> t
val gather_int_as_matmul : t -> int -> t
val constant_int_array : int array -> node
val constant_int_array : int array -> t
(** create a node for a constant array *)
val reshape : Shape.t -> t -> t
......@@ -160,24 +161,24 @@ module Node : sig
val concat_0 : t list -> t
(** create if necessary a concat node for the first axis *)
val map : (node -> node) -> node -> node
val map : (t -> t) -> t -> t
(** [map f n] replace the direct inputs [i] of n by [f i] *)
val map_rec : (node -> node) -> node -> node
val map_rec : (t -> t) -> t -> t
(** [map_rec f n] replace top-bottom the nodes [i] accessible from [n] by
[f i] *)
val replace_input : (unit -> node) -> node -> node
val replace_input : (unit -> t) -> t -> t
(** [replace_input f n] replace the input in [n] by [f ()] *)
val preds : node -> node list
val preds : t -> t list
(** Direct predecessors of a node *)
val iter_rec : (t -> unit) -> t -> unit
(** Iterate on the predecessors of a node and itself. Repect topological
order. *)
val compute_shape : node -> Shape.t
val compute_shape : t -> Shape.t
end
type t
......
......@@ -54,7 +54,7 @@ let create_new_nn env input_vars outputs : string =
let get_input =
Why3.Term.Hls.memo 10 (fun ls ->
let i = Why3.Term.Mls.find_exn UnknownLogicSymbol ls input_vars in
Ir.Nier_simple.Node.gather_int input i)
Ir.Nier_simple.Node.gather_int_as_matmul input i)
in
let cache = Why3.Term.Hterm.create 17 in
(* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed
......
......@@ -899,9 +899,9 @@ Test verify on acasxu
caisar_0.onnx has 1 input nodes
{'name': '38', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
caisar_1.onnx has 1 input nodes
{'name': '134', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
{'name': '154', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
caisar_2.onnx has 1 input nodes
{'name': '298', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}}
{'name': '338', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}}
caisar_3.onnx has 1 input nodes
{'name': '467', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
{'name': '531', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
4 files checked
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment