Skip to content
Snippets Groups Projects
Commit 06e8d80d authored by Michele Alberti's avatar Michele Alberti
Browse files

Rework nnet model information retrieval, and use it for the nnet spec compatibility check.

parent dd34aff4
No related branches found
No related tags found
No related merge requests found
......@@ -4,171 +4,189 @@
(* *)
(**************************************************************************)
(******************************)
(*This program verifies whether a file is a nnet file or not*)
(******************************)
(*This section verifies if the file has an nnet extension before verifying its content*)
type format = Onnx | Nnet [@@deriving show { with_path = false}]
open Base
module Format = Caml.Format
module Sys = Caml.Sys
module Filename = Caml.Filename
type nnet = {
n_layers : int;
n_inputs: int;
n_outputs: int;
max_layer_size: int;
layer_sizes: int list;
min_input_values: float list;
max_input_values: float list;
mean_values: float list * float;
range_values: float list * float;
} [@@deriving show { with_path = false }]
type format = Onnx | Nnet of nnet [@@deriving show { with_path = false}]
type t = {
format: format;
filename: string;
}
let is_nnet = ref 0;;
(* Nnet format handling. *)
let nnet_format_error s =
Error (Format.sprintf "Nnet format error: %s condition not satisfied." s)
let nnet_delimiter = Str.regexp ","
(* Parse a single Nnet format line: split line using [nnet_delimiter] as
delimiter, and convert each string into a number by means of converter [f]. *)
let handle_nnet_line ~f line =
List.map
~f:(fun s -> f (String.strip s))
(Str.split nnet_delimiter line)
(* Skip the header part, ie comments, of the Nnet format. *)
let handle_nnet_header filename in_channel =
let exception End_of_header in
let pos_in = ref (Stdlib.pos_in in_channel) in
try
while true do
let line = Stdlib.input_line in_channel in
if not (Str.string_match (Str.regexp "//") line 0)
then raise End_of_header
else pos_in := Stdlib.pos_in in_channel
done;
assert false
with
| End_of_header ->
(* At this point we have read one line past the header part: seek back. *)
Stdlib.seek_in in_channel !pos_in;
Ok ()
| End_of_file ->
Error (Format.sprintf "Nnet model not found in file `%s'." filename)
(* Retrieve number of layers, inputs, outputs and maximum layer size. *)
let handle_nnet_basic_info in_channel =
try
let line = Stdlib.input_line in_channel in
match handle_nnet_line ~f:Stdlib.int_of_string line with
| [ n_layers; n_inputs; n_outputs; max_layer_size ] ->
Ok (n_layers, n_inputs, n_outputs, max_layer_size)
| _ ->
nnet_format_error "second"
with End_of_file ->
nnet_format_error "second"
(* Retrieve size of each layer, including inputs and outputs. *)
let handle_nnet_layer_sizes n_layers in_channel =
try
let line = Stdlib.input_line in_channel in
let layer_sizes = handle_nnet_line ~f:Stdlib.int_of_string line in
if List.length layer_sizes = (n_layers + 1)
then Ok layer_sizes
else nnet_format_error "third"
with End_of_file ->
nnet_format_error "third"
(* Skip unused flag. *)
let handle_nnet_unused_flag in_channel =
try
let _ = Stdlib.input_line in_channel in
Ok ()
with End_of_file ->
nnet_format_error "forth"
(* Retrive minimum values of inputs. *)
let handle_nnet_min_input_values n_inputs in_channel =
try
let line = Stdlib.input_line in_channel in
let min_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
if List.length min_input_values = n_inputs
then Ok min_input_values
else nnet_format_error "fifth"
with End_of_file ->
nnet_format_error "fifth"
(* Retrive maximum values of inputs. *)
let handle_nnet_max_input_values n_inputs in_channel =
try
let line = Stdlib.input_line in_channel in
let max_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
if List.length max_input_values = n_inputs
then Ok max_input_values
else nnet_format_error "sixth"
with End_of_file ->
nnet_format_error "sixth"
(* Retrieve mean values of inputs and one value for all outputs. *)
let handle_nnet_mean_values n_inputs in_channel =
try
let line = Stdlib.input_line in_channel in
let mean_values = handle_nnet_line ~f:Stdlib.float_of_string line in
if List.length mean_values = (n_inputs + 1)
then
let mean_input_values, mean_output_value =
List.split_n mean_values n_inputs
in
Ok (mean_input_values, List.hd_exn mean_output_value)
else
nnet_format_error "seventh"
with End_of_file ->
nnet_format_error "seventh"
(* Retrieve range values of inputs and one value for all outputs. *)
let handle_nnet_range_values n_inputs in_channel =
try
let line = Stdlib.input_line in_channel in
let range_values = handle_nnet_line ~f:Stdlib.float_of_string line in
if List.length range_values = (n_inputs + 1)
then
let range_input_values, range_output_value =
List.split_n range_values n_inputs
in
Ok (range_input_values, List.hd_exn range_output_value)
else
nnet_format_error "eighth"
with End_of_file ->
nnet_format_error "eighth"
(* Retrieves [filename] Nnet model metadata wrt Nnet format specification (see
https://github.com/sisl/NNet for details.) *)
let retrieve_nnet_metadata filename =
let open Result in
let in_channel = Stdlib.open_in filename in
handle_nnet_header filename in_channel >>= fun () ->
handle_nnet_basic_info in_channel >>= fun (n_ls, n_ins, n_outs, max_l_size) ->
handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes ->
handle_nnet_unused_flag in_channel >>= fun () ->
handle_nnet_min_input_values n_ins in_channel >>= fun min_input_values ->
handle_nnet_max_input_values n_ins in_channel >>= fun max_input_values ->
handle_nnet_mean_values n_ins in_channel >>= fun mean_values ->
handle_nnet_range_values n_ins in_channel >>= fun range_values ->
Stdlib.close_in in_channel;
Ok
{ n_layers = n_ls;
n_inputs = n_ins;
n_outputs = n_outs;
max_layer_size = max_l_size;
layer_sizes;
min_input_values;
max_input_values;
mean_values;
range_values; }
(* Generic model. *)
let build ~filename =
let open Result in
if Sys.file_exists filename
then
let format =
if Filename.check_suffix filename "nnet"
then begin
is_nnet := 1;
Nnet
end
else begin
is_nnet := 0;
print_endline "This is not an nnet file -extension";
Onnx
end
in
print_newline();
begin
if Filename.check_suffix filename "onnx"
then Ok Onnx
else
retrieve_nnet_metadata filename >>= fun nnet ->
Ok (Nnet nnet)
end >>= fun format ->
Ok { format; filename }
else
Error (Format.sprintf "No such file `%s'." filename)
;;
let flag = ref 0;;
(*The type "record" is created to collect some parameters of the nnet file *)
type record = {
number_of_layers : int;
number_of_inputs: int;
number_of_outputs:int;
maximum_layer_size:int
};;
(******************************)
(*This function helps verify whether some conditions are satisfied to determine
whether the file has a nnet format.
The conditions are :
- The file starts with header lines, it can be any number of lines so long as they begin with "//"
- The following line contains four values: Number of layers, number of inputs, number of outputs, and maximum layer size
- The number of inputs in that line (second element) is equal to that of the following line (first element)
- The flag line contains only one element
- The minimum values line contains as many elements as there are inputs
- The maximum values line contains as many elements as there are inputs
- The mean values line contains as many elements as there are inputs plus an extra value for all outputs
- The range values line contains as many elements as there are inputs plus an extra value for all outputs
*)
(******************************)
let content_file filename =
if !is_nnet = 1 then begin
let in_channel = open_in filename in
let c = ref 0 in
let flag1 = ref 0 in
let flag2 = ref 0 in
let flag3 = ref 0 in
let flag4 = ref 0 in
let flag5 = ref 0 in
let element1 = ref "" in
let element2 = ref "" in
let a = ref 0 in
let cnt = ref 0 in
try
while true do
let line = input_line in_channel in
cnt := !cnt+1;
(*Verifying the first condition*)
let cmp = Str.string_match (Str.regexp "//") line 0 in
if cmp then begin
flag := !flag +1;
a := !flag;
end
else
try
(*Verifying the second condition*)
let b = Str.split (Str.regexp ",") line in
let tmp= List.nth b 1 in
let num_layers = List.nth b 0 in
let num_outputs = List.nth b 2 in
let max_layer_size = List.nth b 3 in
element2 := tmp;
let record2 = {number_of_layers = int_of_string(num_layers); number_of_inputs=int_of_string !element2;number_of_outputs=int_of_string num_outputs;maximum_layer_size=int_of_string max_layer_size} in
Format.printf "Number of layers: %d\n" record2.number_of_layers;
Format.printf "Number of inputs: %d\n" record2.number_of_inputs;
Format.printf "Number of outputs: %d\n" record2.number_of_outputs;
Format.printf "Maximum layer size : %d\n" record2.maximum_layer_size ;
c := List.length b;
raise Exit;
with Exit ->
let line2 = input_line in_channel in
let d = Str.split (Str.regexp ",") line2 in
let tmp2 = List.nth d 0 in
element1 := tmp2;
try
(*Verifying the fourth condition using variable flag1*)
let line3 = input_line in_channel in
let e = Str.split (Str.regexp ",") line3 in
flag1 := List.length e;
raise Exit;
with Exit ->
print_newline();
try
(*Verifying the fifth condition using variable flag2*)
let line4 = input_line in_channel in
let f = Str.split (Str.regexp ",") line4 in
flag2 := List.length f;
raise Exit;
with Exit ->
(*Verifying the sixth condition using variable flag3*)
let line5 = input_line in_channel in
let g = Str.split (Str.regexp ",") line5 in
flag3 := List.length g;
try
(*Verifying the seventh condition using variable flag4*)
let line6 = input_line in_channel in
let h = Str.split (Str.regexp ",") line6 in
flag4 := List.length h;
raise Exit;
with Exit ->
(*Verifying the eighth condition using variable flag5*)
let line7 = input_line in_channel in
let i = Str.split (Str.regexp ",") line7 in
flag5 := List.length i;
raise End_of_file;
done
with End_of_file ->
if !flag = 0 then print_endline "not an nnet file"
else
(*Verifying the third condition using the variable t*)
let t = compare element1 element2 in
(*Verifying all of the conditions using the if branch*)
if (!c = 4 && t=0 && !flag1 = 1 && flag2=ref(int_of_string !element2) && flag3=ref(int_of_string !element2) && flag4=ref((int_of_string !element2)+1) && flag5=ref((int_of_string !element2)+1)) then
print_endline "nnet file"
else print_endline "it is not an nnet file";
close_in in_channel;
end
;;
let my_fun filename =
build ~filename;;
let my_fun2 filename =
content_file filename;
;;
(*This command is used to create an interface with the user such that they can insert the name of the file
to be tested.
In order to compile the code, use :
ocamlfind ocamlopt -package str -o (name_of_the_executable) (name_of_the_program.ml) -linkpkg
To run the executable, use :
./(name_of_the_executable) (name_of_the_nnet_file.nnet) *)
my_fun Sys.argv.(1);;
my_fun2 Sys.argv.(1);;
......@@ -4,32 +4,36 @@
(* *)
(**************************************************************************)
type format = Onnx | Nnet [@@deriving show { with_path = false }]
(** Nnet model metadata. *)
type nnet = private {
(** Number of layers. *)
n_layers : int;
(** Number of inputs. *)
n_inputs: int;
(** Number of outputs. *)
n_outputs: int;
(** Maximum layer size. *)
max_layer_size: int;
(** Size of each layer. *)
layer_sizes: int list;
(** Minimum values of inputs. *)
min_input_values: float list;
(** Maximum values of inputs. *)
max_input_values: float list;
(** Mean values of inputs and one value for all outputs. *)
mean_values: float list * float;
(** Range values of inputs and one value for all outputs. *)
range_values: float list * float;
} [@@deriving show { with_path = false }]
type format = Onnx | Nnet of nnet [@@deriving show { with_path = false }]
type t = private {
format: format;
filename: string;
}
type record = {
number_of_layers : int;
number_of_inputs: int;
number_of_outputs:int;
maximum_layer_size:int;
}
(** Builds a model out of the given [filename], if possible.
The model is inferred from the [filename] extension.
*)
(** Builds a model out of the given [filename], if possible. The model is
inferred from the [filename] extension for the Onnx case, while Nnet models
are parsed for metadata retrieval and conformity checks. *)
val build: filename:string -> (t, string) Result.t
(** Verifies the content of the given [filename] and checks
if the conditions are satisfied.
*)
val content_file:string -> unit
(** The main function of the program *)
val my_fun:string -> (t, string) Result.t
val my_fun2:string -> unit
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment