diff --git a/lib/nir/gentensor.ml b/lib/nir/gentensor.ml index 5a2eae003d00794fab5a6b7283f1692a7a177f21..5f22aea5e9d843cec2f8a7a73254bc9d3b2060af 100644 --- a/lib/nir/gentensor.ml +++ b/lib/nir/gentensor.ml @@ -28,10 +28,22 @@ type t = let create_1_float f = Float (Tensor.create_1_float f) let create_1_int64 i = Int64 (Tensor.create_1_int64 i) -let of_int_array a = - Int64 - (Tensor.of_array1 - (Shape.of_array [| Array.length a |]) - (Bigarray.Array1.of_array Int64 C_layout (Array.map a ~f:Int64.of_int))) +let of_int64_array ?shape t = + let sh = + match shape with + | Some sh -> sh + | None -> Shape.of_array [| Array.length t |] + in + let a = Bigarray.Array1.of_array Bigarray.Int64 Bigarray.c_layout t in + Int64 (Tensor.of_array1 sh a) + +let of_float_array ?shape t = + let sh = + match shape with + | Some sh -> sh + | None -> Shape.of_array [| Array.length t |] + in + let a = Bigarray.Array1.of_array Bigarray.Float64 Bigarray.c_layout t in + Float (Tensor.of_array1 sh a) let shape = function Float f -> Tensor.shape f | Int64 i -> Tensor.shape i diff --git a/lib/nir/gentensor.mli b/lib/nir/gentensor.mli index 9c9e1c31e4406063d4bb7767b9f2daa6742cee72..ae602600a4346e601470cdb0e00562274bc28ddc 100644 --- a/lib/nir/gentensor.mli +++ b/lib/nir/gentensor.mli @@ -26,5 +26,15 @@ type t = val create_1_float : float -> t val create_1_int64 : int64 -> t -val of_int_array : int array -> t + +val of_float_array : ?shape:Shape.t -> float array -> t +(** [of_float_array a shape] returns a Tensor with data contained in [l] and + shape [shape]. If no shape is given, the resulting tensor shape is + 1-dimensional, equals to the length of [l].*) + +val of_int64_array : ?shape:Shape.t -> int64 array -> t +(** [of_int64_array a shape] returns a Tensor with data contained in [l] and + shape [shape]. If no shape is given, the resulting tensor shape is + 1-dimensional, equals to the length of [l].*) + val shape : t -> Shape.t diff --git a/lib/nir/node.ml b/lib/nir/node.ml index 656d928f12fe52ed3a5d1d5c0b804622fb148b4b..36d6ff1beba8b958252661898e14a2808b1a2dd8 100644 --- a/lib/nir/node.ml +++ b/lib/nir/node.ml @@ -313,7 +313,8 @@ let create = Int.incr c; { id = !c; descr; shape = compute_shape_descr descr; ty = compute_ty descr } -let constant_int_array a = create (Constant { data = Gentensor.of_int_array a }) +let constant_int_array a = + create (Constant { data = Gentensor.of_int64_array a }) let reshape shape node = if Shape.equal node.shape shape @@ -321,7 +322,12 @@ let reshape shape node = else create (Reshape - { input = node; shape = constant_int_array (Shape.to_array shape) }) + { + input = node; + shape = + constant_int_array + (Array.map ~f:Int64.of_int @@ Shape.to_array shape); + }) let gather_int_as_matmul input i = let input1 = reshape (Shape.of_array [| 1; Shape.size input.shape |]) input in diff --git a/lib/nnet/dune b/lib/nnet/dune index b6180d7c1db4d993685be8a0962965b65928005e..3825ea5fc488d60a929a86e046e7417feffb298e 100644 --- a/lib/nnet/dune +++ b/lib/nnet/dune @@ -1,5 +1,5 @@ (library (name nnet) (public_name caisar.nnet) - (libraries base csv caisar_logging) + (libraries base csv caisar.nir caisar_logging) (synopsis "NNet parser for CAISAR")) diff --git a/lib/nnet/nnet.ml b/lib/nnet/nnet.ml index 16b237b79521d393fd2165939085de4b58b72ba2..969593cae347b9c3a0f8691a25e80eda1c9c9291 100644 --- a/lib/nnet/nnet.ml +++ b/lib/nnet/nnet.ml @@ -193,6 +193,49 @@ let parse_in_channel ?(permissive = false) filename in_channel = | Sys_error s -> Error s | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg) +let to_nir t = + let open Nir in + let create_input_node in_shape = Node.create (Input { shape = in_shape }) in + let rec traverse_wb wb acc = + match wb with + (* Recursively traverse weights and biases. Builds the necessary nodes and + return the last node of a simple neural network consisting of Matmul, Add + and ReLU. *) + | [] -> create_input_node acc + | weights_biases -> ( + match weights_biases with + | [] -> failwith "Empty list" + | _ :: [] -> failwith "Weights or biases missing." + | weights :: biases :: rest -> + (* recursion will happen in the creation of the input1 node to the + current node *) + let input_node = traverse_wb rest acc in + let weights_tensor = + Nir.Gentensor.of_float_array (Array.of_list weights) + in + let weights_node = + Node.create (Node.Constant { data = weights_tensor }) + in + let matmul_node = + Node.create + (Node.Matmul { input1 = input_node; input2 = weights_node }) + in + let biases_tensor = + Nir.Gentensor.of_float_array (Array.of_list biases) + in + let biases_node = + Node.create (Node.Constant { data = biases_tensor }) + in + let add_node = + Node.create (Add { input1 = matmul_node; input2 = biases_node }) + in + let relu_node = Node.create (Node.ReLu { input = add_node }) in + relu_node) + in + let w = t.weights_biases and in_sh = Shape.of_list [ t.n_inputs ] in + let g = Nir.Ngraph.create (traverse_wb w in_sh) in + g + let parse ?(permissive = false) filename = let in_channel = Stdlib.open_in filename in Fun.protect diff --git a/lib/nnet/nnet.mli b/lib/nnet/nnet.mli index 61d58ef6cf393fc23c041171162cdc670e172f5d..ca35fdbaa5f536fec7421443dd0e408e6cc17ffb 100644 --- a/lib/nnet/nnet.mli +++ b/lib/nnet/nnet.mli @@ -42,3 +42,6 @@ val parse : ?permissive:bool -> string -> (t, string) Result.t @param permissive [false] by default. When set, parsing does not fail on non available information, which are set to [None] instead. *) + +val to_nir : t -> Nir.Ngraph.t +(** Convert an well-formed NNet into a Nir. *)