From 06e8d80dc646a74b0237b57eeb0ef489327ea3fa Mon Sep 17 00:00:00 2001 From: Michele Alberti <michele.alberti@cea.fr> Date: Thu, 11 Mar 2021 21:00:43 +0100 Subject: [PATCH] Rework nnet model information retrieval, and use it for the nnet spec compatibility check. --- model.ml | 328 ++++++++++++++++++++++++++++-------------------------- model.mli | 48 ++++---- 2 files changed, 199 insertions(+), 177 deletions(-) diff --git a/model.ml b/model.ml index c31acd3..19ca382 100644 --- a/model.ml +++ b/model.ml @@ -4,171 +4,189 @@ (* *) (**************************************************************************) -(******************************) -(*This program verifies whether a file is a nnet file or not*) -(******************************) - - -(*This section verifies if the file has an nnet extension before verifying its content*) - -type format = Onnx | Nnet [@@deriving show { with_path = false}] +open Base +module Format = Caml.Format +module Sys = Caml.Sys +module Filename = Caml.Filename + +type nnet = { + n_layers : int; + n_inputs: int; + n_outputs: int; + max_layer_size: int; + layer_sizes: int list; + min_input_values: float list; + max_input_values: float list; + mean_values: float list * float; + range_values: float list * float; +} [@@deriving show { with_path = false }] + +type format = Onnx | Nnet of nnet [@@deriving show { with_path = false}] type t = { format: format; filename: string; } -let is_nnet = ref 0;; + +(* Nnet format handling. *) + +let nnet_format_error s = + Error (Format.sprintf "Nnet format error: %s condition not satisfied." s) + +let nnet_delimiter = Str.regexp "," + +(* Parse a single Nnet format line: split line using [nnet_delimiter] as + delimiter, and convert each string into a number by means of converter [f]. *) +let handle_nnet_line ~f line = + List.map + ~f:(fun s -> f (String.strip s)) + (Str.split nnet_delimiter line) + +(* Skip the header part, ie comments, of the Nnet format. 
*)
+(* Reads lines while they start with "//". Returns [Ok ()] with the channel
+   positioned on the first non-comment line, or [Error _] when the end of the
+   file is reached before any such line. *)
+let handle_nnet_header filename in_channel =
+  (* Compile the comment-prefix regexp once instead of once per line read. *)
+  let comment_prefix = Str.regexp "//" in
+  let exception End_of_header in
+  let pos_in = ref (Stdlib.pos_in in_channel) in
+  try
+    while true do
+      let line = Stdlib.input_line in_channel in
+      if not (Str.string_match comment_prefix line 0)
+      then raise End_of_header
+      else pos_in := Stdlib.pos_in in_channel
+    done;
+    assert false
+  with
+  | End_of_header ->
+    (* At this point we have read one line past the header part: seek back. *)
+    Stdlib.seek_in in_channel !pos_in;
+    Ok ()
+  | End_of_file ->
+    Error (Format.sprintf "Nnet model not found in file `%s'." filename)
+
+(* Retrieve number of layers, inputs, outputs and maximum layer size. The line
+   must hold exactly four integers; a missing line, a wrong number of values
+   or a value that is not an integer all yield [Error _]. *)
+let handle_nnet_basic_info in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    match handle_nnet_line ~f:Stdlib.int_of_string line with
+    | [ n_layers; n_inputs; n_outputs; max_layer_size ] ->
+      Ok (n_layers, n_inputs, n_outputs, max_layer_size)
+    | _ ->
+      nnet_format_error "second"
+  with End_of_file | Failure _ ->
+    (* [Failure _] is what [Stdlib.int_of_string] raises on a malformed
+       number: report it as a format error rather than letting it escape. *)
+    nnet_format_error "second"
+
+(* Retrieve size of each layer, including inputs and outputs: exactly
+   [n_layers + 1] integers are expected, otherwise [Error _] is returned. *)
+let handle_nnet_layer_sizes n_layers in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let layer_sizes = handle_nnet_line ~f:Stdlib.int_of_string line in
+    if List.length layer_sizes = (n_layers + 1)
+    then Ok layer_sizes
+    else nnet_format_error "third"
+  with End_of_file | Failure _ ->
+    (* A token that is not an integer is a format error, not a crash. *)
+    nnet_format_error "third"
+
+(* Skip unused flag: the line only has to be present, its content is not
+   inspected. *)
+let handle_nnet_unused_flag in_channel =
+  try
+    let (_ : string) = Stdlib.input_line in_channel in
+    Ok ()
+  with End_of_file ->
+    nnet_format_error "fourth"
+
+(* Retrieve minimum values of inputs: exactly [n_inputs] floats are expected,
+   otherwise [Error _] is returned. *)
+let handle_nnet_min_input_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let min_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length min_input_values = n_inputs
+    then Ok min_input_values
+    else nnet_format_error "fifth"
+  with End_of_file | Failure _ ->
+    (* [Failure _] is what [Stdlib.float_of_string] raises on a malformed
+       number. *)
+    nnet_format_error "fifth"
+
+(* Retrieve maximum values of inputs. 
*)
+(* Exactly [n_inputs] floats are expected on the line; a missing line, a wrong
+   number of values or a value that is not a float all yield [Error _]. *)
+let handle_nnet_max_input_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let max_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length max_input_values = n_inputs
+    then Ok max_input_values
+    else nnet_format_error "sixth"
+  with End_of_file | Failure _ ->
+    (* [Failure _] is what [Stdlib.float_of_string] raises on a malformed
+       number: report it as a format error rather than letting it escape. *)
+    nnet_format_error "sixth"
+
+(* Retrieve mean values of inputs and one value for all outputs: exactly
+   [n_inputs + 1] floats are expected, the last one being the output value. *)
+let handle_nnet_mean_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let mean_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    (* Accept only [n_inputs] input values followed by a single output value.
+       Matching on the split result keeps the function total: no [hd_exn]. *)
+    match List.split_n mean_values n_inputs with
+    | mean_input_values, [ mean_output_value ]
+      when List.length mean_input_values = n_inputs ->
+      Ok (mean_input_values, mean_output_value)
+    | _ ->
+      nnet_format_error "seventh"
+  with End_of_file | Failure _ ->
+    nnet_format_error "seventh"
+
+(* Retrieve range values of inputs and one value for all outputs: exactly
+   [n_inputs + 1] floats are expected, the last one being the output value. *)
+let handle_nnet_range_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let range_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    (* Same shape check as for the mean values line. *)
+    match List.split_n range_values n_inputs with
+    | range_input_values, [ range_output_value ]
+      when List.length range_input_values = n_inputs ->
+      Ok (range_input_values, range_output_value)
+    | _ ->
+      nnet_format_error "eighth"
+  with End_of_file | Failure _ ->
+    nnet_format_error "eighth"
+
+(* Retrieves [filename] Nnet model metadata wrt Nnet format specification (see
+   https://github.com/sisl/NNet for details.) 
*)
+(* Returns [Error _] as soon as one of the format checks fails. The input
+   channel is closed on every exit path. *)
+let retrieve_nnet_metadata filename =
+  let open Result in
+  let in_channel = Stdlib.open_in filename in
+  (* An [Error] short-circuits the [>>=] chain and a handler may raise, so the
+     channel cannot be closed at the end of the chain only: [Exn.protect] runs
+     [finally] whether [f] returns normally or raises. *)
+  Exn.protect
+    ~finally:(fun () -> Stdlib.close_in in_channel)
+    ~f:(fun () ->
+      handle_nnet_header filename in_channel >>= fun () ->
+      handle_nnet_basic_info in_channel
+      >>= fun (n_ls, n_ins, n_outs, max_l_size) ->
+      handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes ->
+      handle_nnet_unused_flag in_channel >>= fun () ->
+      handle_nnet_min_input_values n_ins in_channel
+      >>= fun min_input_values ->
+      handle_nnet_max_input_values n_ins in_channel
+      >>= fun max_input_values ->
+      handle_nnet_mean_values n_ins in_channel >>= fun mean_values ->
+      handle_nnet_range_values n_ins in_channel >>= fun range_values ->
+      Ok
+        { n_layers = n_ls;
+          n_inputs = n_ins;
+          n_outputs = n_outs;
+          max_layer_size = max_l_size;
+          layer_sizes;
+          min_input_values;
+          max_input_values;
+          mean_values;
+          range_values; })
+
+
+(* Generic model. *)
 
 let build ~filename =
+  let open Result in
   if Sys.file_exists filename
   then
-    let format =
-      if Filename.check_suffix filename "nnet"
-      then begin
-        is_nnet := 1;
-        Nnet
-      end
-      else begin
-        is_nnet := 0;
-        print_endline "This is not an nnet file -extension";
-        Onnx
-      end
-    in
-    print_newline();
+    begin
+      if Filename.check_suffix filename "onnx"
+      then Ok Onnx
+      else
+        retrieve_nnet_metadata filename >>= fun nnet ->
+        Ok (Nnet nnet)
+    end >>= fun format ->
     Ok { format; filename }
   else
     Error (Format.sprintf "No such file `%s'." filename)
-;;
-
-let flag = ref 0;;
-(*The type "record" is created to collect some parameters of the nnet file *)
-type record = {
-  number_of_layers : int;
-  number_of_inputs: int;
-  number_of_outputs:int;
-  maximum_layer_size:int
-  };;
-
-
-(******************************)
-(*This function helps verify whether some conditions are satisfied to determine
-whether the file has a nnet format. 
-The conditions are : -- The file starts with header lines, it can be any number of lines so long as they begin with "//" -- The following line contains four values: Number of layers, number of inputs, number of outputs, and maximum layer size -- The number of inputs in that line (second element) is equal to that of the following line (first element) -- The flag line contains only one element -- The minimum values line contains as many elements as there are inputs -- The maximum values line contains as many elements as there are inputs -- The mean values line contains as many elements as there are inputs plus an extra value for all outputs -- The range values line contains as many elements as there are inputs plus an extra value for all outputs -*) -(******************************) - - -let content_file filename = - if !is_nnet = 1 then begin - let in_channel = open_in filename in - let c = ref 0 in - let flag1 = ref 0 in - let flag2 = ref 0 in - let flag3 = ref 0 in - let flag4 = ref 0 in - let flag5 = ref 0 in - let element1 = ref "" in - let element2 = ref "" in - let a = ref 0 in - let cnt = ref 0 in - try - while true do - let line = input_line in_channel in - cnt := !cnt+1; - (*Verifying the first condition*) - let cmp = Str.string_match (Str.regexp "//") line 0 in - if cmp then begin - flag := !flag +1; - a := !flag; - end - else - try - (*Verifying the second condition*) - let b = Str.split (Str.regexp ",") line in - let tmp= List.nth b 1 in - let num_layers = List.nth b 0 in - let num_outputs = List.nth b 2 in - let max_layer_size = List.nth b 3 in - element2 := tmp; - let record2 = {number_of_layers = int_of_string(num_layers); number_of_inputs=int_of_string !element2;number_of_outputs=int_of_string num_outputs;maximum_layer_size=int_of_string max_layer_size} in - Format.printf "Number of layers: %d\n" record2.number_of_layers; - Format.printf "Number of inputs: %d\n" record2.number_of_inputs; - Format.printf "Number of outputs: %d\n" 
record2.number_of_outputs; - Format.printf "Maximum layer size : %d\n" record2.maximum_layer_size ; - c := List.length b; - raise Exit; - with Exit -> - let line2 = input_line in_channel in - let d = Str.split (Str.regexp ",") line2 in - let tmp2 = List.nth d 0 in - element1 := tmp2; - try - (*Verifying the fourth condition using variable flag1*) - let line3 = input_line in_channel in - let e = Str.split (Str.regexp ",") line3 in - flag1 := List.length e; - raise Exit; - with Exit -> - print_newline(); - try - (*Verifying the fifth condition using variable flag2*) - let line4 = input_line in_channel in - let f = Str.split (Str.regexp ",") line4 in - flag2 := List.length f; - raise Exit; - with Exit -> - (*Verifying the sixth condition using variable flag3*) - let line5 = input_line in_channel in - let g = Str.split (Str.regexp ",") line5 in - flag3 := List.length g; - try - (*Verifying the seventh condition using variable flag4*) - let line6 = input_line in_channel in - let h = Str.split (Str.regexp ",") line6 in - flag4 := List.length h; - raise Exit; - with Exit -> - (*Verifying the eighth condition using variable flag5*) - let line7 = input_line in_channel in - let i = Str.split (Str.regexp ",") line7 in - flag5 := List.length i; - raise End_of_file; - done - with End_of_file -> - if !flag = 0 then print_endline "not an nnet file" - else - (*Verifying the third condition using the variable t*) - let t = compare element1 element2 in - (*Verifying all of the conditions using the if branch*) - if (!c = 4 && t=0 && !flag1 = 1 && flag2=ref(int_of_string !element2) && flag3=ref(int_of_string !element2) && flag4=ref((int_of_string !element2)+1) && flag5=ref((int_of_string !element2)+1)) then - print_endline "nnet file" - else print_endline "it is not an nnet file"; - close_in in_channel; - end -;; - -let my_fun filename = - build ~filename;; -let my_fun2 filename = - content_file filename; -;; - - -(*This command is used to create an interface with the user such that 
they can insert the name of the file
-to be tested.
-In order to compile the code, use :
-ocamlfind ocamlopt -package str -o (name_of_the_executable) (name_of_the_program.ml) -linkpkg
-To run the executable, use :
-./(name_of_the_executable) (name_of_the_nnet_file.nnet) *)
-
-my_fun Sys.argv.(1);;
-my_fun2 Sys.argv.(1);;
diff --git a/model.mli b/model.mli
index 26993da..fdc61c7 100644
--- a/model.mli
+++ b/model.mli
@@ -4,32 +4,36 @@
 (* *)
 (**************************************************************************)
 
-type format = Onnx | Nnet [@@deriving show { with_path = false }]
+(** Nnet model metadata. *)
+type nnet = private {
+  n_layers : int;
+  (** Number of layers. *)
+  n_inputs: int;
+  (** Number of inputs. *)
+  n_outputs: int;
+  (** Number of outputs. *)
+  max_layer_size: int;
+  (** Maximum layer size. *)
+  layer_sizes: int list;
+  (** Size of each layer. *)
+  min_input_values: float list;
+  (** Minimum values of inputs. *)
+  max_input_values: float list;
+  (** Maximum values of inputs. *)
+  mean_values: float list * float;
+  (** Mean values of inputs and one value for all outputs. *)
+  range_values: float list * float;
+  (** Range values of inputs and one value for all outputs. *)
+} [@@deriving show { with_path = false }]
+
+type format = Onnx | Nnet of nnet [@@deriving show { with_path = false }]
 
 type t = private {
   format: format;
   filename: string;
 }
 
-type record = {
-  number_of_layers : int;
-  number_of_inputs: int;
-  number_of_outputs:int;
-  maximum_layer_size:int;
-  }
-
-(** Builds a model out of the given [filename], if possible.
-
-    The model is inferred from the [filename] extension.
-*)
+(** Builds a model out of the given [filename], if possible. The model is
+    inferred from the [filename] extension for the Onnx case, while Nnet models
+    are parsed for metadata retrieval and conformity checks. *)
 val build: filename:string -> (t, string) Result.t
-
-(** Verifies the content of the given [filename] and checks
-    if the conditions are satisfied. 
-*) -val content_file:string -> unit - -(** The main function of the program *) -val my_fun:string -> (t, string) Result.t -val my_fun2:string -> unit - -- GitLab