Skip to content
Snippets Groups Projects
Commit 2521a026 authored by Michele Alberti's avatar Michele Alberti
Browse files

Merge branch 'projet_nnet' into 'master'

Modification of model.ml and model.mli files (actually recognize nnet format)

Closes #3

See merge request malberti/caisar!1
parents 102562e3 d4a38131
No related branches found
No related tags found
No related merge requests found
...@@ -4,21 +4,191 @@ ...@@ -4,21 +4,191 @@
(* *) (* *)
(**************************************************************************) (**************************************************************************)
type format = Onnx | Nnet [@@deriving show { with_path = false}] open Base
module Format = Caml.Format
module Sys = Caml.Sys
module Filename = Caml.Filename
type nnet = {
n_layers : int;
n_inputs: int;
n_outputs: int;
max_layer_size: int;
layer_sizes: int list;
min_input_values: float list;
max_input_values: float list;
mean_values: float list * float;
range_values: float list * float;
} [@@deriving show { with_path = false }]
type format = Onnx | Nnet of nnet [@@deriving show { with_path = false}]
type t = { type t = {
format: format; format: format;
filename: string; filename: string;
} }
(* Nnet format handling. *)
let nnet_format_error s =
Error (Format.sprintf "Nnet format error: %s condition not satisfied." s)
let nnet_delimiter = Str.regexp ","
(* Parse a single Nnet format line: split line using [nnet_delimiter] as
delimiter, and convert each string into a number by means of converter [f]. *)
let handle_nnet_line ~f line =
List.map
~f:(fun s -> f (String.strip s))
(Str.split nnet_delimiter line)
(* Skip the header part, ie comments, of the Nnet format. *)
let handle_nnet_header filename in_channel =
let exception End_of_header in
let pos_in = ref (Stdlib.pos_in in_channel) in
try
while true do
let line = Stdlib.input_line in_channel in
if not (Str.string_match (Str.regexp "//") line 0)
then raise End_of_header
else pos_in := Stdlib.pos_in in_channel
done;
assert false
with
| End_of_header ->
(* At this point we have read one line past the header part: seek back. *)
Stdlib.seek_in in_channel !pos_in;
Ok ()
| End_of_file ->
Error (Format.sprintf "Nnet model not found in file `%s'." filename)
(* Retrieve number of layers, inputs, outputs and maximum layer size. *)
let handle_nnet_basic_info in_channel =
try
let line = Stdlib.input_line in_channel in
match handle_nnet_line ~f:Stdlib.int_of_string line with
| [ n_layers; n_inputs; n_outputs; max_layer_size ] ->
Ok (n_layers, n_inputs, n_outputs, max_layer_size)
| _ ->
nnet_format_error "second"
with End_of_file ->
nnet_format_error "second"
(* Retrieve size of each layer, including inputs and outputs. *)
let handle_nnet_layer_sizes n_layers in_channel =
try
let line = Stdlib.input_line in_channel in
let layer_sizes = handle_nnet_line ~f:Stdlib.int_of_string line in
if List.length layer_sizes = (n_layers + 1)
then Ok layer_sizes
else nnet_format_error "third"
with End_of_file ->
nnet_format_error "third"
(* Skip unused flag. *)
let handle_nnet_unused_flag in_channel =
try
let _ = Stdlib.input_line in_channel in
Ok ()
with End_of_file ->
nnet_format_error "forth"
(* Retrive minimum values of inputs. *)
let handle_nnet_min_input_values n_inputs in_channel =
try
let line = Stdlib.input_line in_channel in
let min_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
if List.length min_input_values = n_inputs
then Ok min_input_values
else nnet_format_error "fifth"
with End_of_file ->
nnet_format_error "fifth"
(* Retrive maximum values of inputs. *)
let handle_nnet_max_input_values n_inputs in_channel =
try
let line = Stdlib.input_line in_channel in
let max_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
if List.length max_input_values = n_inputs
then Ok max_input_values
else nnet_format_error "sixth"
with End_of_file ->
nnet_format_error "sixth"
(* Retrieve mean values of inputs and one value for all outputs. *)
let handle_nnet_mean_values n_inputs in_channel =
try
let line = Stdlib.input_line in_channel in
let mean_values = handle_nnet_line ~f:Stdlib.float_of_string line in
if List.length mean_values = (n_inputs + 1)
then
let mean_input_values, mean_output_value =
List.split_n mean_values n_inputs
in
Ok (mean_input_values, List.hd_exn mean_output_value)
else
nnet_format_error "seventh"
with End_of_file ->
nnet_format_error "seventh"
(* Retrieve range values of inputs and one value for all outputs. *)
let handle_nnet_range_values n_inputs in_channel =
try
let line = Stdlib.input_line in_channel in
let range_values = handle_nnet_line ~f:Stdlib.float_of_string line in
if List.length range_values = (n_inputs + 1)
then
let range_input_values, range_output_value =
List.split_n range_values n_inputs
in
Ok (range_input_values, List.hd_exn range_output_value)
else
nnet_format_error "eighth"
with End_of_file ->
nnet_format_error "eighth"
(* Retrieves [filename] Nnet model metadata wrt Nnet format specification (see
https://github.com/sisl/NNet for details.) *)
let retrieve_nnet_metadata filename =
let open Result in
let in_channel = Stdlib.open_in filename in
try
handle_nnet_header filename in_channel >>= fun () ->
handle_nnet_basic_info in_channel >>= fun (n_ls, n_ins, n_outs, max_l_size) ->
handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes ->
handle_nnet_unused_flag in_channel >>= fun () ->
handle_nnet_min_input_values n_ins in_channel >>= fun min_input_values ->
handle_nnet_max_input_values n_ins in_channel >>= fun max_input_values ->
handle_nnet_mean_values n_ins in_channel >>= fun mean_values ->
handle_nnet_range_values n_ins in_channel >>= fun range_values ->
Stdlib.close_in in_channel;
Ok
{ n_layers = n_ls;
n_inputs = n_ins;
n_outputs = n_outs;
max_layer_size = max_l_size;
layer_sizes;
min_input_values;
max_input_values;
mean_values;
range_values; }
with Failure msg ->
Error (Format.sprintf "Unexpected error: %s." msg)
(* Generic model. *)
let build ~filename = let build ~filename =
let open Result in
if Sys.file_exists filename if Sys.file_exists filename
then then
let format = begin
if Filename.check_suffix filename "nnet" if Filename.check_suffix filename "onnx"
then Nnet then Ok Onnx
else Onnx else
in retrieve_nnet_metadata filename >>= fun nnet ->
Ok (Nnet nnet)
end >>= fun format ->
Ok { format; filename } Ok { format; filename }
else else
Error (Format.sprintf "No such file `%s'." filename) Error (Format.sprintf "No such file `%s'." filename)
...@@ -4,15 +4,36 @@ ...@@ -4,15 +4,36 @@
(* *) (* *)
(**************************************************************************) (**************************************************************************)
type format = Onnx | Nnet [@@deriving show { with_path = false }] (** Nnet model metadata. *)
type nnet = private {
(** Number of layers. *)
n_layers : int;
(** Number of inputs. *)
n_inputs: int;
(** Number of outputs. *)
n_outputs: int;
(** Maximum layer size. *)
max_layer_size: int;
(** Size of each layer. *)
layer_sizes: int list;
(** Minimum values of inputs. *)
min_input_values: float list;
(** Maximum values of inputs. *)
max_input_values: float list;
(** Mean values of inputs and one value for all outputs. *)
mean_values: float list * float;
(** Range values of inputs and one value for all outputs. *)
range_values: float list * float;
} [@@deriving show { with_path = false }]
type format = Onnx | Nnet of nnet [@@deriving show { with_path = false }]
type t = private { type t = private {
format: format; format: format;
filename: string; filename: string;
} }
(** Builds a model out of the given [filename], if possible. (** Builds a model out of the given [filename], if possible. The model is
inferred from the [filename] extension for the Onnx case, while Nnet models
The model is inferred from the [filename] extension. are parsed for metadata retrieval and conformity checks. *)
*)
val build: filename:string -> (t, string) Result.t val build: filename:string -> (t, string) Result.t
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment