diff --git a/lib/ir/nier_simple.ml b/lib/ir/nier_simple.ml index c30b962f7cab660fa4c8386e6a5ce54372611c1c..cf8a5eefba0e61d7879a49c96e74d54ae995c449 100644 --- a/lib/ir/nier_simple.ml +++ b/lib/ir/nier_simple.ml @@ -425,6 +425,22 @@ module Node = struct (Reshape { input = node; shape = constant_int_array (Shape.to_array shape) }) + let gather_int_as_matmul input i = + let input = + reshape (Shape.of_array [| 1; Shape.size input.shape |]) input + in + let selector = Array.create ~len:(Shape.size input.shape) Float.zero in + Array.set selector i Float.one; + let selector = + GenTensor.Float + (Tensor.of_array1 + (Shape.of_array [| Array.length selector; 1 |]) + (Bigarray.Array1.of_array Float64 C_layout selector)) + in + let input2 = create (Constant { data = selector }) in + let matmul = create (Matmul { input1 = input; input2 }) in + reshape (Shape.of_array [| 1 |]) matmul + let concat_0 = function | [ n ] -> n | [] -> failwith "empty concat" diff --git a/lib/ir/nier_simple.mli b/lib/ir/nier_simple.mli index 1824367fc08517c6c75f0f77ae80c8365ff50ded..ee7e2cc415fa7aa6292c26968e72aa9a7ec3cea0 100644 --- a/lib/ir/nier_simple.mli +++ b/lib/ir/nier_simple.mli @@ -143,15 +143,16 @@ and ty = module Node : sig type t = node [@@deriving show] - val equal : node -> node -> bool + val equal : t -> t -> bool include Base.Hashtbl.Key.S with type t := t include Base.Comparator.S with type t := t - val create : descr -> node - val gather_int : node -> int -> node + val create : descr -> t + val gather_int : t -> int -> t + val gather_int_as_matmul : t -> int -> t - val constant_int_array : int array -> node + val constant_int_array : int array -> t (** create a node for a constant array *) val reshape : Shape.t -> t -> t @@ -160,24 +161,24 @@ module Node : sig val concat_0 : t list -> t (** create if necessary a concat node for the first axis *) - val map : (node -> node) -> node -> node + val map : (t -> t) -> t -> t (** [map f n] replace the direct inputs [i] of n by [f i] *) - val map_rec : (node -> node) -> node -> node + val map_rec : (t -> t) -> t -> t (** [map_rec f n] replace top-bottom the nodes [i] accessible from [n] by [f i] *) - val replace_input : (unit -> node) -> node -> node + val replace_input : (unit -> t) -> t -> t (** [replace_input f n] replace the input in [n] by [f ()] *) - val preds : node -> node list + val preds : t -> t list (** Direct predecessors of a node *) val iter_rec : (t -> unit) -> t -> unit (** Iterate on the predecessors of a node and itself. Repect topological order. *) - val compute_shape : node -> Shape.t + val compute_shape : t -> Shape.t end type t diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 5acb447cc4cc74deece6be9a05e1066fd43de519..e6c1459ab4bca3b9d242fc0632c9e2044e4ff244 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -54,7 +54,7 @@ let create_new_nn env input_vars outputs : string = let get_input = Why3.Term.Hls.memo 10 (fun ls -> let i = Why3.Term.Mls.find_exn UnknownLogicSymbol ls input_vars in - Ir.Nier_simple.Node.gather_int input i) + Ir.Nier_simple.Node.gather_int_as_matmul input i) in let cache = Why3.Term.Hterm.create 17 in (* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed diff --git a/tests/acasxu.t b/tests/acasxu.t index 6747505c0fc0c3c1bd2cdd068d7df802efa454ac..8256f664151c8af3c1c91fc60b7974007226c75a 100644 --- a/tests/acasxu.t +++ b/tests/acasxu.t @@ -899,9 +899,9 @@ Test verify on acasxu caisar_0.onnx has 1 input nodes {'name': '38', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}} caisar_1.onnx has 1 input nodes - {'name': '134', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}} + {'name': '154', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}} caisar_2.onnx has 1 input nodes - {'name': '298', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}} + {'name': '338', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}} caisar_3.onnx has 1 input nodes - {'name': '467', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}} + {'name': '531', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}} 4 files checked