From 243b2994e19d31fe54aabafbecf2f3ba067169dc Mon Sep 17 00:00:00 2001 From: Julien Girard <julien.girard2@cea.fr> Date: Tue, 30 Apr 2024 18:32:17 +0200 Subject: [PATCH] [Nnet] Nnet parser to Nir --- lib/nir/gentensor.ml | 22 +++++++++++++++++----- lib/nir/gentensor.mli | 12 +++++++++++- lib/nir/node.ml | 10 ++++++++-- lib/nnet/dune | 2 +- lib/nnet/nnet.ml | 43 +++++++++++++++++++++++++++++++++++++++++++ lib/nnet/nnet.mli | 3 +++ 6 files changed, 83 insertions(+), 9 deletions(-) diff --git a/lib/nir/gentensor.ml b/lib/nir/gentensor.ml index 5a2eae0..5f22aea 100644 --- a/lib/nir/gentensor.ml +++ b/lib/nir/gentensor.ml @@ -28,10 +28,22 @@ type t = let create_1_float f = Float (Tensor.create_1_float f) let create_1_int64 i = Int64 (Tensor.create_1_int64 i) -let of_int_array a = - Int64 - (Tensor.of_array1 - (Shape.of_array [| Array.length a |]) - (Bigarray.Array1.of_array Int64 C_layout (Array.map a ~f:Int64.of_int))) +let of_int64_array ?shape t = + let sh = + match shape with + | Some sh -> sh + | None -> Shape.of_array [| Array.length t |] + in + let a = Bigarray.Array1.of_array Bigarray.Int64 Bigarray.c_layout t in + Int64 (Tensor.of_array1 sh a) + +let of_float_array ?shape t = + let sh = + match shape with + | Some sh -> sh + | None -> Shape.of_array [| Array.length t |] + in + let a = Bigarray.Array1.of_array Bigarray.Float64 Bigarray.c_layout t in + Float (Tensor.of_array1 sh a) let shape = function Float f -> Tensor.shape f | Int64 i -> Tensor.shape i diff --git a/lib/nir/gentensor.mli b/lib/nir/gentensor.mli index 9c9e1c3..ae60260 100644 --- a/lib/nir/gentensor.mli +++ b/lib/nir/gentensor.mli @@ -26,5 +26,15 @@ type t = val create_1_float : float -> t val create_1_int64 : int64 -> t -val of_int_array : int array -> t + +val of_float_array : ?shape:Shape.t -> float array -> t +(** [of_float_array a shape] returns a Tensor with data contained in [l] and + shape [shape]. If no shape is given, the resulting tensor shape is + 1-dimensional, equals to the length of [l].*) + +val of_int64_array : ?shape:Shape.t -> int64 array -> t +(** [of_int64_array a shape] returns a Tensor with data contained in [l] and + shape [shape]. If no shape is given, the resulting tensor shape is + 1-dimensional, equals to the length of [l].*) + val shape : t -> Shape.t diff --git a/lib/nir/node.ml b/lib/nir/node.ml index 656d928..36d6ff1 100644 --- a/lib/nir/node.ml +++ b/lib/nir/node.ml @@ -313,7 +313,8 @@ let create = Int.incr c; { id = !c; descr; shape = compute_shape_descr descr; ty = compute_ty descr } -let constant_int_array a = create (Constant { data = Gentensor.of_int_array a }) +let constant_int_array a = + create (Constant { data = Gentensor.of_int64_array a }) let reshape shape node = if Shape.equal node.shape shape @@ -321,7 +322,12 @@ let reshape shape node = else create (Reshape - { input = node; shape = constant_int_array (Shape.to_array shape) }) + { + input = node; + shape = + constant_int_array + (Array.map ~f:Int64.of_int @@ Shape.to_array shape); + }) let gather_int_as_matmul input i = let input1 = reshape (Shape.of_array [| 1; Shape.size input.shape |]) input in diff --git a/lib/nnet/dune b/lib/nnet/dune index b6180d7..3825ea5 100644 --- a/lib/nnet/dune +++ b/lib/nnet/dune @@ -1,5 +1,5 @@ (library (name nnet) (public_name caisar.nnet) - (libraries base csv caisar_logging) + (libraries base csv caisar.nir caisar_logging) (synopsis "NNet parser for CAISAR")) diff --git a/lib/nnet/nnet.ml b/lib/nnet/nnet.ml index 16b237b..969593c 100644 --- a/lib/nnet/nnet.ml +++ b/lib/nnet/nnet.ml @@ -193,6 +193,49 @@ let parse_in_channel ?(permissive = false) filename in_channel = | Sys_error s -> Error s | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg) +let to_nir t = + let open Nir in + let create_input_node in_shape = Node.create (Input { shape = in_shape }) in + let rec traverse_wb wb acc = + match wb with + (* Recursively traverse weights and biases. Builds the necessary nodes and + return the last node of a simple neural network consisting of Matmul, Add + and ReLU. *) + | [] -> create_input_node acc + | weights_biases -> ( + match weights_biases with + | [] -> failwith "Empty list" + | _ :: [] -> failwith "Weights or biases missing." + | weights :: biases :: rest -> + (* recursion will happen in the creation of the input1 node to the + current node *) + let input_node = traverse_wb rest acc in + let weights_tensor = + Nir.Gentensor.of_float_array (Array.of_list weights) + in + let weights_node = + Node.create (Node.Constant { data = weights_tensor }) + in + let matmul_node = + Node.create + (Node.Matmul { input1 = input_node; input2 = weights_node }) + in + let biases_tensor = + Nir.Gentensor.of_float_array (Array.of_list biases) + in + let biases_node = + Node.create (Node.Constant { data = biases_tensor }) + in + let add_node = + Node.create (Add { input1 = matmul_node; input2 = biases_node }) + in + let relu_node = Node.create (Node.ReLu { input = add_node }) in + relu_node) + in + let w = t.weights_biases and in_sh = Shape.of_list [ t.n_inputs ] in + let g = Nir.Ngraph.create (traverse_wb w in_sh) in + g + let parse ?(permissive = false) filename = let in_channel = Stdlib.open_in filename in Fun.protect diff --git a/lib/nnet/nnet.mli b/lib/nnet/nnet.mli index 61d58ef..ca35fdb 100644 --- a/lib/nnet/nnet.mli +++ b/lib/nnet/nnet.mli @@ -42,3 +42,6 @@ val parse : ?permissive:bool -> string -> (t, string) Result.t @param permissive [false] by default. When set, parsing does not fail on non available information, which are set to [None] instead. *) + +val to_nir : t -> Nir.Ngraph.t +(** Convert an well-formed NNet into a Nir. *) -- GitLab