diff --git a/model.ml b/model.ml
index e13dbcc0257adafeedac656e599fffb068209f78..8fc7eb5e6e5fa2e3e858c2a4198d5974f56a2920 100644
--- a/model.ml
+++ b/model.ml
@@ -4,21 +4,205 @@
 (*                                                                        *)
 (**************************************************************************)
 
-type format = Onnx | Nnet [@@deriving show { with_path = false}]
+open Base
+module Format = Caml.Format
+module Sys = Caml.Sys
+module Filename = Caml.Filename
+
+type nnet = {
+  n_layers : int;
+  n_inputs: int;
+  n_outputs: int;
+  max_layer_size: int;
+  layer_sizes: int list;
+  min_input_values: float list;
+  max_input_values: float list;
+  mean_values: float list * float;
+  range_values: float list * float;
+} [@@deriving show { with_path = false }]
+
+type format = Onnx | Nnet of nnet [@@deriving show { with_path = false}]
 
 type t = {
   format: format;
   filename: string;
 }
 
+(* Nnet format handling. *)
+
+let nnet_format_error s =
+  Error (Format.sprintf "Nnet format error: %s condition not satisfied." s)
+
+let nnet_delimiter = Str.regexp ","
+
+(* Parse a single Nnet format line: split line using [nnet_delimiter] as
+   delimiter, and convert each string into a number by means of converter [f].
+   Any exception raised by [f] (eg [Failure] from [Stdlib.int_of_string]) is
+   propagated to the caller. *)
+let handle_nnet_line ~f line =
+  List.map
+    ~f:(fun s -> f (String.strip s))
+    (Str.split nnet_delimiter line)
+
+(* Skip the header part, ie comment lines starting with [//], of the Nnet
+   format, leaving [in_channel] positioned on the first non-comment line. *)
+let handle_nnet_header filename in_channel =
+  let exception End_of_header in
+  let pos_in = ref (Stdlib.pos_in in_channel) in
+  try
+    while true do
+      let line = Stdlib.input_line in_channel in
+      if String.is_prefix line ~prefix:"//"
+      then pos_in := Stdlib.pos_in in_channel
+      else raise End_of_header
+    done;
+    assert false
+  with
+  | End_of_header ->
+    (* At this point we have read one line past the header part: seek back. *)
+    Stdlib.seek_in in_channel !pos_in;
+    Ok ()
+  | End_of_file ->
+    (* Only comment lines (or nothing at all) were found. *)
+    Error (Format.sprintf "Nnet model not found in file `%s'." filename)
+
+(* Retrieve number of layers, inputs, outputs and maximum layer size. *)
+let handle_nnet_basic_info in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    match handle_nnet_line ~f:Stdlib.int_of_string line with
+    | [ n_layers; n_inputs; n_outputs; max_layer_size ] ->
+      Ok (n_layers, n_inputs, n_outputs, max_layer_size)
+    | _ ->
+      nnet_format_error "second"
+  with End_of_file ->
+    nnet_format_error "second"
+
+(* Retrieve size of each layer, including inputs and outputs. *)
+let handle_nnet_layer_sizes n_layers in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let layer_sizes = handle_nnet_line ~f:Stdlib.int_of_string line in
+    if List.length layer_sizes = (n_layers + 1)
+    then Ok layer_sizes
+    else nnet_format_error "third"
+  with End_of_file ->
+    nnet_format_error "third"
+
+(* Skip unused flag. *)
+let handle_nnet_unused_flag in_channel =
+  try
+    let _unused_flag : string = Stdlib.input_line in_channel in
+    Ok ()
+  with End_of_file ->
+    nnet_format_error "fourth"
+
+(* Retrieve minimum values of inputs. *)
+let handle_nnet_min_input_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let min_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length min_input_values = n_inputs
+    then Ok min_input_values
+    else nnet_format_error "fifth"
+  with End_of_file ->
+    nnet_format_error "fifth"
+
+(* Retrieve maximum values of inputs. *)
+let handle_nnet_max_input_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let max_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length max_input_values = n_inputs
+    then Ok max_input_values
+    else nnet_format_error "sixth"
+  with End_of_file ->
+    nnet_format_error "sixth"
+
+(* Retrieve mean values of inputs and one value for all outputs. *)
+let handle_nnet_mean_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let mean_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length mean_values = (n_inputs + 1)
+    then
+      let mean_input_values, mean_output_value =
+        List.split_n mean_values n_inputs
+      in
+      Ok (mean_input_values, List.hd_exn mean_output_value)
+    else
+      nnet_format_error "seventh"
+  with End_of_file ->
+    nnet_format_error "seventh"
+
+(* Retrieve range values of inputs and one value for all outputs. *)
+let handle_nnet_range_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let range_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length range_values = (n_inputs + 1)
+    then
+      let range_input_values, range_output_value =
+        List.split_n range_values n_inputs
+      in
+      Ok (range_input_values, List.hd_exn range_output_value)
+    else
+      nnet_format_error "eighth"
+  with End_of_file ->
+    nnet_format_error "eighth"
+
+(* Parse the metadata section of an Nnet model from [in_channel], which is
+   expected to be positioned at the very beginning of [filename]. Number
+   conversion failures raised while parsing are turned into errors. *)
+let parse_nnet_metadata filename in_channel =
+  let open Result in
+  try
+    handle_nnet_header filename in_channel >>= fun () ->
+    handle_nnet_basic_info in_channel >>= fun (n_ls, n_ins, n_outs, max_l_size) ->
+    handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes ->
+    handle_nnet_unused_flag in_channel >>= fun () ->
+    handle_nnet_min_input_values n_ins in_channel >>= fun min_input_values ->
+    handle_nnet_max_input_values n_ins in_channel >>= fun max_input_values ->
+    handle_nnet_mean_values n_ins in_channel >>= fun mean_values ->
+    handle_nnet_range_values n_ins in_channel >>= fun range_values ->
+    Ok
+      { n_layers = n_ls;
+        n_inputs = n_ins;
+        n_outputs = n_outs;
+        max_layer_size = max_l_size;
+        layer_sizes;
+        min_input_values;
+        max_input_values;
+        mean_values;
+        range_values; }
+  with Failure msg ->
+    Error (Format.sprintf "Unexpected error: %s." msg)
+
+(* Retrieves [filename] Nnet model metadata wrt Nnet format specification (see
+   https://github.com/sisl/NNet for details). The file is closed on every exit
+   path, and a failure to open it is reported as an error. *)
+let retrieve_nnet_metadata filename =
+  match Stdlib.open_in filename with
+  | exception Sys_error msg ->
+    Error (Format.sprintf "Unexpected error: %s." msg)
+  | in_channel ->
+    Exn.protect
+      ~f:(fun () -> parse_nnet_metadata filename in_channel)
+      ~finally:(fun () -> Stdlib.close_in_noerr in_channel)
+
+(* Generic model. *)
+
 let build ~filename =
+  let open Result in
   if Sys.file_exists filename
   then
-    let format =
-      if Filename.check_suffix filename "nnet"
-      then Nnet
-      else Onnx
-    in
+    begin
+      if Filename.check_suffix filename "onnx"
+      then Ok Onnx
+      else
+        retrieve_nnet_metadata filename >>= fun nnet ->
+        Ok (Nnet nnet)
+    end >>= fun format ->
     Ok { format; filename }
   else
     Error (Format.sprintf "No such file `%s'." filename)
diff --git a/model.mli b/model.mli
index 45687d863e14af274565c307a74ac073145de2b1..fdc61c7d7a6c0c7ea8194760e1aad4ea5987f16a 100644
--- a/model.mli
+++ b/model.mli
@@ -4,15 +4,36 @@
 (*                                                                        *)
 (**************************************************************************)
 
-type format = Onnx | Nnet [@@deriving show { with_path = false }]
+(** Nnet model metadata. *)
+type nnet = private {
+  n_layers : int;
+  (** Number of layers. *)
+  n_inputs: int;
+  (** Number of inputs. *)
+  n_outputs: int;
+  (** Number of outputs. *)
+  max_layer_size: int;
+  (** Maximum layer size. *)
+  layer_sizes: int list;
+  (** Size of each layer. *)
+  min_input_values: float list;
+  (** Minimum values of inputs. *)
+  max_input_values: float list;
+  (** Maximum values of inputs. *)
+  mean_values: float list * float;
+  (** Mean values of inputs and one value for all outputs. *)
+  range_values: float list * float;
+  (** Range values of inputs and one value for all outputs. *)
+} [@@deriving show { with_path = false }]
+
+type format = Onnx | Nnet of nnet [@@deriving show { with_path = false }]
 
 type t = private {
   format: format;
   filename: string;
 }
 
-(** Builds a model out of the given [filename], if possible.
-
-    The model is inferred from the [filename] extension.
-*)
+(** Builds a model out of the given [filename], if possible. The model is
+    inferred from the [filename] extension for the Onnx case, while Nnet models
+    are parsed for metadata retrieval and conformity checks. *)
 val build: filename:string -> (t, string) Result.t