From 06e8d80dc646a74b0237b57eeb0ef489327ea3fa Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Thu, 11 Mar 2021 21:00:43 +0100
Subject: [PATCH] Rework nnet model information retrieval, and use it for the
 nnet spec compatibility check.

---
 model.ml  | 328 ++++++++++++++++++++++++++++--------------------------
 model.mli |  48 ++++----
 2 files changed, 199 insertions(+), 177 deletions(-)

diff --git a/model.ml b/model.ml
index c31acd3..19ca382 100644
--- a/model.ml
+++ b/model.ml
@@ -4,171 +4,189 @@
 (*                                                                        *)
 (**************************************************************************)
 
-(******************************)
-(*This program verifies whether a file is a nnet file or not*)
-(******************************)
-
-
-(*This section verifies if the file has an nnet extension before verifying its content*)
-
-type format = Onnx | Nnet [@@deriving show { with_path = false}]
+open Base
+module Format = Caml.Format
+module Sys = Caml.Sys
+module Filename = Caml.Filename
+
+type nnet = {
+  n_layers : int;
+  n_inputs: int;
+  n_outputs: int;
+  max_layer_size: int;
+  layer_sizes: int list;
+  min_input_values: float list;
+  max_input_values: float list;
+  mean_values: float list * float;
+  range_values: float list * float;
+} [@@deriving show { with_path = false }]
+
+type format = Onnx | Nnet of nnet [@@deriving show { with_path = false}]
 
 type t = {
   format: format;
   filename: string;
 }
 
-let is_nnet = ref 0;;
+
+(* Nnet format handling. *)
+
+let nnet_format_error s =
+  Error (Format.sprintf "Nnet format error: %s condition not satisfied." s)
+
+let nnet_delimiter = Str.regexp ","
+
+(* Parse a single Nnet format line: split line using [nnet_delimiter] as
+   delimiter, and convert each string into a number by means of converter [f]. *)
+let handle_nnet_line ~f line =
+  List.map
+    ~f:(fun s -> f (String.strip s))
+    (Str.split nnet_delimiter line)
+
+(* Skip the header part, ie comments, of the Nnet format. *)
+let handle_nnet_header filename in_channel =
+  let exception End_of_header in
+  let pos_in = ref (Stdlib.pos_in in_channel) in
+  try
+    while true do
+      let line = Stdlib.input_line in_channel in
+      if not (Str.string_match (Str.regexp "//") line 0)
+      then raise End_of_header
+      else pos_in := Stdlib.pos_in in_channel
+    done;
+    assert false
+  with
+  | End_of_header ->
+    (* At this point we have read one line past the header part: seek back. *)
+    Stdlib.seek_in in_channel !pos_in;
+    Ok ()
+  | End_of_file ->
+    Error (Format.sprintf "Nnet model not found in file `%s'." filename)
+
+(* Retrieve number of layers, inputs, outputs and maximum layer size. *)
+let handle_nnet_basic_info in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    match handle_nnet_line ~f:Stdlib.int_of_string line with
+    | [ n_layers; n_inputs; n_outputs; max_layer_size ] ->
+      Ok (n_layers, n_inputs, n_outputs, max_layer_size)
+    | _ ->
+      nnet_format_error "second"
+  with End_of_file ->
+    nnet_format_error "second"
+
+(* Retrieve size of each layer, including inputs and outputs. *)
+let handle_nnet_layer_sizes n_layers in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let layer_sizes = handle_nnet_line ~f:Stdlib.int_of_string line in
+    if List.length layer_sizes = (n_layers + 1)
+    then Ok layer_sizes
+    else nnet_format_error "third"
+  with End_of_file ->
+    nnet_format_error "third"
+
+(* Skip unused flag. *)
+let handle_nnet_unused_flag in_channel =
+  try
+    let _ = Stdlib.input_line in_channel in
+    Ok ()
+  with End_of_file ->
+    nnet_format_error "forth"
+
+(* Retrive minimum values of inputs. *)
+let handle_nnet_min_input_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let min_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length min_input_values = n_inputs
+    then Ok min_input_values
+    else nnet_format_error "fifth"
+  with End_of_file ->
+    nnet_format_error "fifth"
+
+(* Retrive maximum values of inputs. *)
+let handle_nnet_max_input_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let max_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length max_input_values = n_inputs
+    then Ok max_input_values
+    else nnet_format_error "sixth"
+  with End_of_file ->
+    nnet_format_error "sixth"
+
+(* Retrieve mean values of inputs and one value for all outputs. *)
+let handle_nnet_mean_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let mean_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length mean_values = (n_inputs + 1)
+    then
+      let mean_input_values, mean_output_value =
+        List.split_n mean_values n_inputs
+      in
+      Ok (mean_input_values, List.hd_exn mean_output_value)
+    else
+      nnet_format_error "seventh"
+  with End_of_file ->
+    nnet_format_error "seventh"
+
+(* Retrieve range values of inputs and one value for all outputs. *)
+let handle_nnet_range_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let range_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length range_values = (n_inputs + 1)
+    then
+      let range_input_values, range_output_value =
+        List.split_n range_values n_inputs
+      in
+      Ok (range_input_values, List.hd_exn range_output_value)
+    else
+      nnet_format_error "eighth"
+  with End_of_file ->
+    nnet_format_error "eighth"
+
+(* Retrieves [filename] Nnet model metadata wrt Nnet format specification (see
+   https://github.com/sisl/NNet for details.) *)
+let retrieve_nnet_metadata filename =
+  let open Result in
+  let in_channel = Stdlib.open_in filename in
+  handle_nnet_header filename in_channel >>= fun () ->
+  handle_nnet_basic_info in_channel >>= fun (n_ls, n_ins, n_outs, max_l_size) ->
+  handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes ->
+  handle_nnet_unused_flag in_channel >>= fun () ->
+  handle_nnet_min_input_values n_ins in_channel >>= fun min_input_values ->
+  handle_nnet_max_input_values n_ins in_channel >>= fun max_input_values ->
+  handle_nnet_mean_values n_ins in_channel >>= fun mean_values ->
+  handle_nnet_range_values n_ins in_channel >>= fun range_values ->
+  Stdlib.close_in in_channel;
+  Ok
+    { n_layers = n_ls;
+      n_inputs = n_ins;
+      n_outputs = n_outs;
+      max_layer_size = max_l_size;
+      layer_sizes;
+      min_input_values;
+      max_input_values;
+      mean_values;
+      range_values; }
+
+
+(* Generic model. *)
 
 let build ~filename =
+  let open Result in
   if Sys.file_exists filename
   then
-    let format =
-      if Filename.check_suffix filename "nnet"
-      then begin
-       is_nnet := 1;
-       Nnet
-      end 
-      else begin
-       is_nnet := 0;
-       print_endline "This is not an nnet file -extension";
-       Onnx
-      end
-    in 
-    print_newline();
+    begin
+      if Filename.check_suffix filename "onnx"
+      then Ok Onnx
+      else
+        retrieve_nnet_metadata filename >>= fun nnet ->
+        Ok (Nnet nnet)
+    end >>= fun format ->
     Ok { format; filename }
   else
     Error (Format.sprintf "No such file `%s'." filename)
-;;
-
-let flag = ref 0;;
-(*The type "record" is created to collect some parameters of the nnet file *)
-type record = {
-    number_of_layers : int;
-    number_of_inputs: int;
-    number_of_outputs:int; 
-    maximum_layer_size:int
-    };;
-
-
-(******************************)
-(*This function helps verify whether some conditions are satisfied to determine 
-whether the file has a nnet format.
-The conditions are : 
-- The file starts with header lines, it can be any number of lines so long as they begin with "//"
-- The following line contains four values: Number of layers, number of inputs, number of outputs, and maximum layer size
-- The number of inputs in that line (second element) is equal to that of the following line (first element) 
-- The flag line contains only one element
-- The minimum values line contains as many elements as there are inputs
-- The maximum values line contains as many elements as there are inputs
-- The mean values line contains as many elements as there are inputs plus an extra value for all outputs
-- The range values line contains as many elements as there are inputs plus an extra value for all outputs
-*)
-(******************************)
-
-
-let content_file filename = 
-    if !is_nnet = 1 then begin 
-    let in_channel = open_in filename in 
-    let c = ref 0 in 
-    let flag1 = ref 0 in 
-    let flag2 = ref 0 in 
-    let flag3 = ref 0 in 
-    let flag4 = ref 0 in 
-    let flag5 = ref 0 in 
-    let element1 = ref "" in
-    let element2 = ref "" in
-    let a = ref 0 in  
-    let cnt = ref 0 in 
-    try
-        while true do 
-            let line = input_line in_channel in 
-                cnt := !cnt+1;
-                (*Verifying the first condition*)
-                let cmp = Str.string_match (Str.regexp "//") line 0 in
-                    if cmp then begin
-                        flag := !flag +1;
-                        a := !flag;
-                    end
-                    else 
-                        try 
-                             (*Verifying the second condition*)
-                            let b = Str.split (Str.regexp ",") line in 
-                            let tmp= List.nth b 1 in 
-                            let num_layers = List.nth b 0 in 
-                            let num_outputs = List.nth b 2 in 
-                            let max_layer_size = List.nth b 3 in 
-                            element2 := tmp;
-                            let record2 = {number_of_layers = int_of_string(num_layers); number_of_inputs=int_of_string !element2;number_of_outputs=int_of_string num_outputs;maximum_layer_size=int_of_string max_layer_size} in 
-                            Format.printf "Number of layers: %d\n" record2.number_of_layers;
-                            Format.printf "Number of inputs: %d\n" record2.number_of_inputs;
-                            Format.printf "Number of outputs: %d\n" record2.number_of_outputs;
-                            Format.printf "Maximum layer size : %d\n" record2.maximum_layer_size ;
-                            c := List.length b; 
-                            raise Exit;
-                        with Exit -> 
-                            let line2 = input_line in_channel in 
-                            let d = Str.split (Str.regexp ",") line2 in 
-                            let tmp2 = List.nth d 0 in
-                            element1 := tmp2; 
-                        try
-                           (*Verifying the fourth condition using variable flag1*)
-                           let line3 = input_line in_channel in 
-                           let e = Str.split (Str.regexp ",") line3 in 
-                           flag1 := List.length e; 
-                           raise Exit;
-                        with Exit ->
-                            print_newline();
-                        try
-                           (*Verifying the fifth condition using variable flag2*)
-                           let line4 = input_line in_channel in 
-                           let f = Str.split (Str.regexp ",") line4 in 
-                           flag2 := List.length f; 
-                           raise Exit;
-                        with Exit ->
-                            (*Verifying the sixth condition using variable flag3*)
-                            let line5 = input_line in_channel in 
-                            let g = Str.split (Str.regexp ",") line5 in 
-                            flag3 := List.length g; 
-                        try
-                           (*Verifying the seventh condition using variable flag4*)
-                           let line6 = input_line in_channel in 
-                           let h = Str.split (Str.regexp ",") line6 in 
-                           flag4 := List.length h; 
-                           raise Exit;
-                        with Exit ->
-                            (*Verifying the eighth condition using variable flag5*)
-                            let line7 = input_line in_channel in 
-                            let i = Str.split (Str.regexp ",") line7 in 
-                            flag5 := List.length i; 
-                            raise End_of_file;
-        done
-    with End_of_file ->
-        if !flag = 0 then print_endline "not an nnet file"
-        else 
-            (*Verifying the third condition using the variable t*)
-            let t = compare element1 element2 in 
-            (*Verifying all of the conditions using the if branch*)
-            if (!c = 4 && t=0 && !flag1 = 1 && flag2=ref(int_of_string !element2) && flag3=ref(int_of_string !element2) && flag4=ref((int_of_string !element2)+1) && flag5=ref((int_of_string !element2)+1)) then
-                print_endline "nnet file"
-            else print_endline "it is not an nnet file";
-    close_in in_channel;
-    end 
-;;
-
-let my_fun filename = 
-        build ~filename;;
-let my_fun2 filename =
-        content_file filename;
-;;
-
-
-(*This command is used to create an interface with the user such that they can insert the name of the file
-to be tested.
-In order to compile the code, use : 
-ocamlfind ocamlopt -package str -o (name_of_the_executable) (name_of_the_program.ml) -linkpkg 
-To run the executable, use : 
-./(name_of_the_executable) (name_of_the_nnet_file.nnet) *)
-
-my_fun Sys.argv.(1);;
-my_fun2 Sys.argv.(1);;
diff --git a/model.mli b/model.mli
index 26993da..fdc61c7 100644
--- a/model.mli
+++ b/model.mli
@@ -4,32 +4,36 @@
 (*                                                                        *)
 (**************************************************************************)
 
-type format = Onnx | Nnet [@@deriving show { with_path = false }]
+(** Nnet model metadata. *)
+type nnet = private {
+  (** Number of layers. *)
+  n_layers : int;
+  (** Number of inputs. *)
+  n_inputs: int;
+  (** Number of outputs. *)
+  n_outputs: int;
+  (** Maximum layer size. *)
+  max_layer_size: int;
+  (** Size of each layer. *)
+  layer_sizes: int list;
+  (** Minimum values of inputs. *)
+  min_input_values: float list;
+  (** Maximum values of inputs. *)
+  max_input_values: float list;
+  (** Mean values of inputs and one value for all outputs. *)
+  mean_values: float list * float;
+  (** Range values of inputs and one value for all outputs. *)
+  range_values: float list * float;
+} [@@deriving show { with_path = false }]
+
+type format = Onnx | Nnet of nnet [@@deriving show { with_path = false }]
 
 type t = private {
   format: format;
   filename: string;
 }
 
-type record = {
-    number_of_layers : int;
-    number_of_inputs: int;
-    number_of_outputs:int; 
-    maximum_layer_size:int;
-    }
-
-(** Builds a model out of the given [filename], if possible.
-
-    The model is inferred from the [filename] extension.
-*)
+(** Builds a model out of the given [filename], if possible. The model is
+    inferred from the [filename] extension for the Onnx case, while Nnet models
+    are parsed for metadata retrieval and conformity checks. *)
 val build: filename:string -> (t, string) Result.t
-
-(** Verifies the content of the given [filename] and checks 
-    if the conditions are satisfied.
-*)
-val content_file:string -> unit
-
-(** The main function of the program *)
-val my_fun:string -> (t, string) Result.t
-val my_fun2:string -> unit
-
-- 
GitLab