Skip to content
Snippets Groups Projects
Commit 5a1be3d1 authored by Julien Girard-Satabin's avatar Julien Girard-Satabin
Browse files

[Nnet]Correct data handling

parent 8c482227
No related branches found
No related tags found
No related merge requests found
(library (library
(name nnet) (name nnet)
(public_name caisar.nnet) (public_name caisar.nnet)
(libraries base csv caisar.nir caisar_logging) (libraries base csv caisar.nir caisar_logging logs)
(synopsis "NNet parser for CAISAR")) (synopsis "NNet parser for CAISAR"))
...@@ -37,8 +37,81 @@ type t = { ...@@ -37,8 +37,81 @@ type t = {
mean_values : (float list * float) option; mean_values : (float list * float) option;
range_values : (float list * float) option; range_values : (float list * float) option;
weights_biases : float list list; weights_biases : float list list;
nir : Nir.Ngraph.t;
} }
(* [to_nir weights_biases n_inputs layer_sizes] builds a [Nir.Ngraph.t] from
   the flat weight/bias rows parsed out of a NNet file. The resulting graph is
   a chain: Input -> (Matmul -> Add -> ReLu) repeated once per entry of
   [layer_sizes] after the first. NOTE(review): this encodes every layer as
   fully connected with a ReLU activation — consistent with the NNet format,
   but confirm against https://github.com/sisl/NNet for output-layer
   activation. *)
let to_nir weights_biases n_inputs layer_sizes =
  let open Nir in
  (* Leaf of the graph: a single input node carrying the given shape. *)
  let create_input_node in_shape = Node.create (Input { shape = in_shape }) in
  (* weights_biases is a list of list describing the weight and biases of the NN.
   * Each inner list contains either layer_size element (weight) or one element
   * (bias) *)
  (* Regroups the flat rows into one [(weights, biases, shape)] triple per
     layer. Because each triple is prepended to [acc], the returned list is in
     REVERSE layer order (last layer first) — exactly what [traverse_wb] below
     expects. *)
  let aggregated_wb (weights_biases : float list list) layer_sizes =
    (* [slice l ~start ~stop] returns the elements of [l] at indices
       [start..stop], INCLUSIVE of both bounds: an element is kept once
       [start] has decremented to <= 0, and recursion stops when [stop]
       reaches 0 (the element where [stop] = 0 is still consed). NOTE(review):
       the second assert bounds [stop] against the outer [weights_biases], not
       against [l] — valid here since [slice] is only applied to the full row
       list, but misleading if this helper is ever reused elsewhere. *)
    let rec slice l ~start ~stop =
      assert (stop > start);
      assert (stop < List.length weights_biases);
      match l with
      | [] -> failwith "Cannot take a slice from an empty list"
      | h :: t ->
        let tail =
          if stop = 0 then [] else slice t ~start:(start - 1) ~stop:(stop - 1)
        in
        if start > 0 then tail else h :: tail
    in
    (* For each layer size [x]: consume [x] weight rows starting at [acc_idx],
       then [x] bias rows, flatten each group with [List.concat], and record
       the layer's weight shape [| prev_size; x |]. [b_idx] is where the next
       layer's rows begin. *)
    let rec aggregs full_wb l_sizes acc acc_idx prev_size =
      match l_sizes with
      | [] -> acc
      | x :: y ->
        (*TODO: a much more efficient approach would be to consume full_wb at
         * the same time instead of slicing the full list everytime. *)
        let w_idx = acc_idx + x in
        let b_idx = w_idx + x in
        let w = List.concat @@ slice full_wb ~start:acc_idx ~stop:(w_idx - 1)
        and b = List.concat @@ slice full_wb ~start:w_idx ~stop:(b_idx - 1)
        and sh = Shape.of_array [| prev_size; x |] in
        aggregs full_wb y ((w, b, sh) :: acc) b_idx x
    in
    (* First element of layer_sizes is input size, skipping it *)
    aggregs weights_biases (List.drop layer_sizes 1) [] 0
      (List.nth_exn layer_sizes 0)
  in
  let rec traverse_wb wb acc =
    match wb with
    (* Recursively traverse weights and biases. Builds the necessary nodes and
       return the last node of a simple neural network consisting of a Matmul,
       Add and ReLU. *)
    (* Expectations: wb is a list of size num_layer containing tuple whose first
       element is the flattened list of weights and second element is the list
       of biases for each layer. *)
    (* [wb] arrives in reverse layer order (see [aggregated_wb]), so the head
       of the list is the LAST layer: recursing on [rest] builds everything
       upstream of the current layer, bottoming out at the input node. *)
    | [] -> create_input_node acc
    | (weights, biases, sh_w) :: rest ->
      (* recursion will happen in the creation of the input1 node to the current
         node *)
      let input_node = traverse_wb rest acc in
      (* NOTE(review): weights are flattened in NNet row order but reshaped to
         [| prev_size; x |]; whether that matches Nir's Matmul layout
         expectation (row- vs column-major) cannot be checked from here —
         confirm against Nir.Gentensor/Node semantics. *)
      let weights_tensor =
        Nir.Gentensor.of_float_array ~shape:sh_w (Array.of_list weights)
      in
      let weights_node =
        Node.create (Node.Constant { data = weights_tensor })
      in
      let matmul_node =
        Node.create (Node.Matmul { input1 = input_node; input2 = weights_node })
      in
      (* Biases keep Gentensor's default shape (no ~shape argument);
         NOTE(review): relies on Nir's Add shape/broadcast rules — verify. *)
      let biases_tensor = Nir.Gentensor.of_float_array (Array.of_list biases) in
      let biases_node = Node.create (Node.Constant { data = biases_tensor }) in
      let add_node =
        Node.create (Add { input1 = matmul_node; input2 = biases_node })
      in
      let relu_node = Node.create (Node.ReLu { input = add_node }) in
      relu_node
  in
  (* The network input is a rank-1 tensor with [n_inputs] elements. *)
  let in_sh = Shape.of_list [ n_inputs ] in
  let g =
    Nir.Ngraph.create
      (traverse_wb (aggregated_wb weights_biases layer_sizes) in_sh)
  in
  g
(* NNet format handling. *) (* NNet format handling. *)
let nnet_format_error s = let nnet_format_error s =
...@@ -175,6 +248,7 @@ let parse_in_channel ?(permissive = false) filename in_channel = ...@@ -175,6 +248,7 @@ let parse_in_channel ?(permissive = false) filename in_channel =
ok_opt (handle_nnet_range_values n_is in_channel) >>= fun range_values -> ok_opt (handle_nnet_range_values n_is in_channel) >>= fun range_values ->
let weights_biases = handle_nnet_weights_and_biases in_channel in let weights_biases = handle_nnet_weights_and_biases in_channel in
Csv.close_in in_channel; Csv.close_in in_channel;
let nir = to_nir weights_biases n_is layer_sizes in
Ok Ok
{ {
n_layers = n_ls; n_layers = n_ls;
...@@ -187,6 +261,7 @@ let parse_in_channel ?(permissive = false) filename in_channel = ...@@ -187,6 +261,7 @@ let parse_in_channel ?(permissive = false) filename in_channel =
mean_values; mean_values;
range_values; range_values;
weights_biases; weights_biases;
nir;
} }
with with
| Csv.Failure (_nrecord, _nfield, msg) -> Error msg | Csv.Failure (_nrecord, _nfield, msg) -> Error msg
......
...@@ -20,6 +20,9 @@ ...@@ -20,6 +20,9 @@
(* *) (* *)
(**************************************************************************) (**************************************************************************)
(** Module to parse neural networks written in the NNet format
https://github.com/sisl/NNet *)
type t = private { type t = private {
n_layers : int; (** Number of layers. *) n_layers : int; (** Number of layers. *)
n_inputs : int; (** Number of inputs. *) n_inputs : int; (** Number of inputs. *)
...@@ -32,7 +35,9 @@ type t = private { ...@@ -32,7 +35,9 @@ type t = private {
(** Mean values of inputs and one value for all outputs. *) (** Mean values of inputs and one value for all outputs. *)
range_values : (float list * float) option; range_values : (float list * float) option;
(** Range values of inputs and one value for all outputs. *) (** Range values of inputs and one value for all outputs. *)
weights_biases : float list list; (** All weights and biases of NNet model. *) weights_biases : float list list;
(** All weights and biases of NNet model. *)
nir : Nir.Ngraph.t;
} }
(** NNet model metadata. *) (** NNet model metadata. *)
......
...@@ -204,7 +204,7 @@ type nn = { ...@@ -204,7 +204,7 @@ type nn = {
[@@deriving show] [@@deriving show]
and nn_format = and nn_format =
| NNet | NNet of Nir.Ngraph.t option [@printer fun fmt _ -> Fmt.pf fmt "<nir>"]
| ONNX of Nir.Ngraph.t option [@printer fun fmt _ -> Fmt.pf fmt "<nir>"] | ONNX of Nir.Ngraph.t option [@printer fun fmt _ -> Fmt.pf fmt "<nir>"]
[@@deriving show] [@@deriving show]
...@@ -226,13 +226,13 @@ let create_nn_nnet env filename = ...@@ -226,13 +226,13 @@ let create_nn_nnet env filename =
let model = Nnet.parse ~permissive:true filename in let model = Nnet.parse ~permissive:true filename in
match model with match model with
| Error s -> Loc.errorm "%s" s | Error s -> Loc.errorm "%s" s
| Ok { n_inputs; n_outputs; _ } -> | Ok { n_inputs; n_outputs; nir; _ } ->
{ {
nn_nb_inputs = n_inputs; nn_nb_inputs = n_inputs;
nn_nb_outputs = n_outputs; nn_nb_outputs = n_outputs;
nn_ty_elt = ty_float64_t env; nn_ty_elt = ty_float64_t env;
nn_filename = filename; nn_filename = filename;
nn_format = NNet; nn_format = NNet (Some nir);
} }
let create_nn_onnx env filename = let create_nn_onnx env filename =
......
...@@ -83,7 +83,7 @@ type nn = private { ...@@ -83,7 +83,7 @@ type nn = private {
[@@deriving show] [@@deriving show]
and nn_format = and nn_format =
| NNet | NNet of Nir.Ngraph.t option
| ONNX of Nir.Ngraph.t option | ONNX of Nir.Ngraph.t option
[@@deriving show] [@@deriving show]
......
...@@ -162,9 +162,9 @@ end) ...@@ -162,9 +162,9 @@ end)
let app_terms_of_nir_output m d (nn : Language.nn) env index tl = let app_terms_of_nir_output m d (nn : Language.nn) env index tl =
match nn.nn_format with match nn.nn_format with
| NNet -> Logging.not_implemented_yet (fun f -> f "NNet to SMT conversion") | ONNX None | NNet None ->
| ONNX None -> Logging.code_error ~src (fun f -> f "No ONNX to convert") Logging.code_error ~src (fun f -> f "No ONNX to convert")
| ONNX (Some g) -> | ONNX (Some g) | NNet (Some g) ->
let vtl = List.fold tl ~init:Why3.Term.Mvs.empty ~f:Why3.Term.t_freevars in let vtl = List.fold tl ~init:Why3.Term.Mvs.empty ~f:Why3.Term.t_freevars in
let m' = ref (MTermL.find_def (Map.empty (module Nir.Node)) tl !m) in let m' = ref (MTermL.find_def (Map.empty (module Nir.Node)) tl !m) in
let t = let t =
......
...@@ -93,7 +93,7 @@ let create_env loadpath = ...@@ -93,7 +93,7 @@ let create_env loadpath =
let write_nir_as_onnx onnx_out_dir = let write_nir_as_onnx onnx_out_dir =
Language.iter_nn (fun ls nn -> Language.iter_nn (fun ls nn ->
match nn.nn_format with match nn.nn_format with
| ONNX (Some nn_nir) -> ( | ONNX (Some nn_nir) | NNet (Some nn_nir) -> (
try try
if not (Stdlib.Sys.file_exists onnx_out_dir) if not (Stdlib.Sys.file_exists onnx_out_dir)
then Stdlib.Sys.mkdir onnx_out_dir 0o755; then Stdlib.Sys.mkdir onnx_out_dir 0o755;
......
...@@ -647,7 +647,7 @@ Test verify on acasxu ...@@ -647,7 +647,7 @@ Test verify on acasxu
(Interpreter_types.NNet, { Language.nn_nb_inputs = 5; nn_nb_outputs = 5; (Interpreter_types.NNet, { Language.nn_nb_inputs = 5; nn_nb_outputs = 5;
nn_ty_elt = t; nn_ty_elt = t;
nn_filename = "./TestNetwork.nnet"; nn_filename = "./TestNetwork.nnet";
nn_format = Language.NNet })) nn_format = <nir> }))
[DEBUG]{ProverSpec} Prover-tailored specification: [DEBUG]{ProverSpec} Prover-tailored specification:
-0.328421367053318091766556108268559910356998443603515625 <= x0 -0.328421367053318091766556108268559910356998443603515625 <= x0
x0 <= 0.67985927880386987087746319957659579813480377197265625 x0 <= 0.67985927880386987087746319957659579813480377197265625
...@@ -721,7 +721,7 @@ Test verify on acasxu ...@@ -721,7 +721,7 @@ Test verify on acasxu
(Interpreter_types.NNet, { Language.nn_nb_inputs = 5; nn_nb_outputs = 5; (Interpreter_types.NNet, { Language.nn_nb_inputs = 5; nn_nb_outputs = 5;
nn_ty_elt = t; nn_ty_elt = t;
nn_filename = "./TestNetwork.nnet"; nn_filename = "./TestNetwork.nnet";
nn_format = Language.NNet })) nn_format = <nir> }))
[DEBUG]{ProverSpec} Prover-tailored specification: [DEBUG]{ProverSpec} Prover-tailored specification:
-0.328421367053318091766556108268559910356998443603515625 <= x0 -0.328421367053318091766556108268559910356998443603515625 <= x0
x0 <= 0.67985927880386987087746319957659579813480377197265625 x0 <= 0.67985927880386987087746319957659579813480377197265625
......
...@@ -23,7 +23,7 @@ Test verify ...@@ -23,7 +23,7 @@ Test verify
[DEBUG]{NIR} Wrote NIR as ONNX in file 'out/nn_onnx.nir.onnx' [DEBUG]{NIR} Wrote NIR as ONNX in file 'out/nn_onnx.nir.onnx'
Goal G: Unknown () Goal G: Unknown ()
Data should be 0.135 Input name should be 0
$ python3 bin/inspect_onnx.py $ python3 bin/inspect_onnx.py
out/nn_onnx.nir.onnx has 1 input nodes out/nn_onnx.nir.onnx has 1 input nodes
...@@ -36,9 +36,9 @@ Data should be 0.135 ...@@ -36,9 +36,9 @@ Data should be 0.135
> use caisar.types.Vector > use caisar.types.Vector
> use caisar.model.Model > use caisar.model.Model
> use caisar.model.NN > use caisar.model.NN
> >
> constant nn: model nn = read_model "TestNetwork.nnet" > constant nn: model nn = read_model "TestNetwork.nnet"
> >
> goal G: > goal G:
> forall i: vector t. > forall i: vector t.
> has_length i 5 -> > has_length i 5 ->
...@@ -46,10 +46,10 @@ Data should be 0.135 ...@@ -46,10 +46,10 @@ Data should be 0.135
> (0.5:t) .< (nn @@ i)[0] .< (0.5:t) > (0.5:t) .< (nn @@ i)[0] .< (0.5:t)
> end > end
> EOF > EOF
[DEBUG]{NIER} Wrote NIER as ONNX in file 'out_nnet/nn_onnx.nir.onnx' [DEBUG]{NIR} Wrote NIR as ONNX in file 'out_nnet/nn_nnet.nir.onnx'
Goal G: Unknown () Goal G: Unknown ()
Data should be 0.135 Input name should be 0
$ python3 bin/inspect_onnx.py $ python3 bin/inspect_onnx.py
out/nn_onnx.nir.onnx has 1 input nodes out/nn_onnx.nir.onnx has 1 input nodes
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment