-
François Bobot authored
model.ml 6.54 KiB
(**************************************************************************)
(* *)
(* This file is part of Caisar. *)
(* *)
(**************************************************************************)
open Base
module Format = Caml.Format
module Sys = Caml.Sys
module Filename = Caml.Filename
(* Metadata of a model in the Nnet format, as filled in by
   [retrieve_nnet_metadata]. *)
type nnet = {
n_layers : int; (* Number of layers. *)
n_inputs: int; (* Number of inputs of the network. *)
n_outputs: int; (* Number of outputs of the network. *)
max_layer_size: int; (* Maximum layer size, as declared in the file. *)
layer_sizes: int list;
(* Size of each layer, including inputs and outputs ([n_layers + 1]
   values). *)
min_input_values: float list; (* Minimum value of each input. *)
max_input_values: float list; (* Maximum value of each input. *)
mean_values: float list * float;
(* Mean value of each input, and a single mean value for all outputs. *)
range_values: float list * float;
(* Range value of each input, and a single range value for all outputs. *)
} [@@deriving show { with_path = false }]
(* Supported model formats. An Nnet model carries the metadata read from its
   file; an ONNX model carries no metadata here. *)
type format = Onnx | Nnet of nnet [@@deriving show { with_path = false}]
(* A model: its detected format together with the file it comes from. *)
type t = {
format: format;
filename: string;
}
(* Nnet format handling. *)
(* Error result reporting that the Nnet format condition named [s]
   (e.g. "second") does not hold. *)
let nnet_format_error s =
  let message =
    Format.sprintf "Nnet format error: %s condition not satisfied." s
  in
  Error message
(* Separator between the values of an Nnet line: a literal comma. *)
let nnet_delimiter = Str.regexp_string ","
(* Parse a single Nnet format line: split [line] on [nnet_delimiter], strip
   each field, and convert it with [f].
   Fields that [f] rejects by raising [Failure] (e.g. an empty field or a
   non-numeric token) are dropped; callers detect malformed lines by checking
   the length of the result. Any other exception raised by [f] propagates
   instead of being silently swallowed. *)
let handle_nnet_line ~f line =
  List.filter_map
    ~f:(fun field ->
      match f (String.strip field) with
      | value -> Some value
      | exception Failure _ -> None)
    (Str.split nnet_delimiter line)
(* Skip the header part of the Nnet format, i.e. the leading lines starting
   with "//". On success, [in_channel] is positioned at the start of the first
   non-comment line. If the end of file is reached before any such line, an
   error naming [filename] is returned. *)
let handle_nnet_header filename in_channel =
  let rec skip_comment_lines () =
    (* Remember where the line about to be read starts. *)
    let line_start = Stdlib.pos_in in_channel in
    match Stdlib.input_line in_channel with
    | exception End_of_file ->
      Error (Format.sprintf "Nnet model not found in file `%s'." filename)
    | line ->
      if String.is_prefix line ~prefix:"//"
      then skip_comment_lines ()
      else begin
        (* [line] is the first non-comment line and has already been
           consumed: seek back so that the next reader starts on it. *)
        Stdlib.seek_in in_channel line_start;
        Ok ()
      end
  in
  skip_comment_lines ()
(* Read the number of layers, of inputs, of outputs and the maximum layer
   size from the next line of [in_channel], which must hold exactly these four
   integers. Any other shape, or end of file, is reported as a violation of
   the "second" condition. *)
let handle_nnet_basic_info in_channel =
  let violation () = nnet_format_error "second" in
  match Stdlib.input_line in_channel with
  | exception End_of_file -> violation ()
  | line ->
    (match handle_nnet_line ~f:Stdlib.int_of_string line with
     | [ n_layers; n_inputs; n_outputs; max_layer_size ] ->
       Ok (n_layers, n_inputs, n_outputs, max_layer_size)
     | _ -> violation ())
(* Read the size of each layer, inputs and outputs included, from the next
   line of [in_channel]. Exactly [n_layers + 1] integers are expected;
   otherwise, or at end of file, the "third" condition is reported as
   violated. *)
let handle_nnet_layer_sizes n_layers in_channel =
  match Stdlib.input_line in_channel with
  | exception End_of_file -> nnet_format_error "third"
  | line ->
    let sizes = handle_nnet_line ~f:Stdlib.int_of_string line in
    if List.length sizes <> n_layers + 1
    then nnet_format_error "third"
    else Ok sizes
(* Skip the next line of [in_channel], which holds a flag that is not used
   (in [retrieve_nnet_metadata] it follows the layer sizes). End of file is
   reported as a violation of the "fourth" condition. *)
let handle_nnet_unused_flag in_channel =
  match Stdlib.input_line in_channel with
  | (_ : string) -> Ok ()
  | exception End_of_file -> nnet_format_error "fourth"
(* Retrieve the minimum value of each input from the next line of
   [in_channel]. The line must yield exactly [n_inputs] parseable numbers;
   otherwise, or at end of file, the "fifth" condition is reported as
   violated. *)
let handle_nnet_min_input_values n_inputs in_channel =
  match Stdlib.input_line in_channel with
  | exception End_of_file -> nnet_format_error "fifth"
  | line ->
    let minima = handle_nnet_line ~f:Stdlib.float_of_string line in
    if List.length minima <> n_inputs
    then nnet_format_error "fifth"
    else Ok minima
(* Retrieve the maximum value of each input from the next line of
   [in_channel]. The line must yield exactly [n_inputs] parseable numbers;
   otherwise, or at end of file, the "sixth" condition is reported as
   violated. *)
let handle_nnet_max_input_values n_inputs in_channel =
  let fail () = nnet_format_error "sixth" in
  match Stdlib.input_line in_channel with
  | exception End_of_file -> fail ()
  | line ->
    let maxima = handle_nnet_line ~f:Stdlib.float_of_string line in
    if List.length maxima = n_inputs then Ok maxima else fail ()
(* Retrieve the mean values from the next line of [in_channel]: one per
   input, followed by a single value for all outputs. The line must yield
   exactly [n_inputs + 1] parseable numbers, returned as
   [(input_means, output_mean)]. Any mismatch, or end of file, is reported as
   a violation of the "seventh" condition.
   The output value is extracted by pattern matching rather than with
   [List.hd_exn], so a malformed line (e.g. with a negative [n_inputs]) gives
   an error instead of an exception. *)
let handle_nnet_mean_values n_inputs in_channel =
  match Stdlib.input_line in_channel with
  | exception End_of_file -> nnet_format_error "seventh"
  | line ->
    let mean_values = handle_nnet_line ~f:Stdlib.float_of_string line in
    if List.length mean_values <> n_inputs + 1
    then nnet_format_error "seventh"
    else
      (match List.split_n mean_values n_inputs with
       | mean_input_values, [ mean_output_value ] ->
         Ok (mean_input_values, mean_output_value)
       | _ -> nnet_format_error "seventh")
(* Retrieve the range values from the next line of [in_channel]: one per
   input, followed by a single value for all outputs. The line must yield
   exactly [n_inputs + 1] parseable numbers, returned as
   [(input_ranges, output_range)]. Any mismatch, or end of file, is reported
   as a violation of the "eighth" condition.
   The output value is extracted by pattern matching rather than with
   [List.hd_exn], so a malformed line (e.g. with a negative [n_inputs]) gives
   an error instead of an exception. *)
let handle_nnet_range_values n_inputs in_channel =
  match Stdlib.input_line in_channel with
  | exception End_of_file -> nnet_format_error "eighth"
  | line ->
    let range_values = handle_nnet_line ~f:Stdlib.float_of_string line in
    if List.length range_values <> n_inputs + 1
    then nnet_format_error "eighth"
    else
      (match List.split_n range_values n_inputs with
       | range_input_values, [ range_output_value ] ->
         Ok (range_input_values, range_output_value)
       | _ -> nnet_format_error "eighth")
(* Retrieve the metadata of the Nnet model stored in [filename], following
   the Nnet format specification (see https://github.com/sisl/NNet for
   details).
   The file is read in order: the "//" comment header, then the lines holding
   the basic counts, the layer sizes, the unused flag, the minimum and maximum
   input values, and the mean and range values. The first failing step
   determines the returned error.
   The input channel is closed on every path (success, format error or
   exception). [Sys_error] raised while opening or reading the file, and
   [Failure] raised while parsing, are turned into an "Unexpected error"
   result instead of escaping. *)
let retrieve_nnet_metadata filename =
  let open Result in
  let unexpected msg = Error (Format.sprintf "Unexpected error: %s." msg) in
  match Stdlib.open_in filename with
  | exception Sys_error msg -> unexpected msg
  | in_channel ->
    Exn.protect
      ~finally:(fun () -> Stdlib.close_in_noerr in_channel)
      ~f:(fun () ->
        try
          handle_nnet_header filename in_channel >>= fun () ->
          handle_nnet_basic_info in_channel
          >>= fun (n_ls, n_is, n_os, max_l_size) ->
          handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes ->
          handle_nnet_unused_flag in_channel >>= fun () ->
          handle_nnet_min_input_values n_is in_channel
          >>= fun min_input_values ->
          handle_nnet_max_input_values n_is in_channel
          >>= fun max_input_values ->
          handle_nnet_mean_values n_is in_channel >>= fun mean_values ->
          handle_nnet_range_values n_is in_channel >>= fun range_values ->
          Ok
            { n_layers = n_ls;
              n_inputs = n_is;
              n_outputs = n_os;
              max_layer_size = max_l_size;
              layer_sizes;
              min_input_values;
              max_input_values;
              mean_values;
              range_values; }
        with Failure msg | Sys_error msg -> unexpected msg)
(* Generic model. *)
(* Build the model descriptor for [filename]. A file with the ".onnx"
   extension is taken as an ONNX model (its content is not inspected here).
   Any other file is parsed as an Nnet model with [retrieve_nnet_metadata].
   Returns an error if [filename] does not exist, or the Nnet metadata
   retrieval error. *)
let build ~filename =
  let open Result in
  Logs.info (fun m -> m "Checking format of model file `%s'." filename);
  if not (Sys.file_exists filename)
  then Error (Format.sprintf "No such file `%s'." filename)
  else
    let detected =
      (* Compare against the dotted extension: a bare "onnx" suffix would
         also match names such as "modelonnx". *)
      if Filename.check_suffix filename ".onnx"
      then Ok Onnx
      else retrieve_nnet_metadata filename >>= fun nnet -> Ok (Nnet nnet)
    in
    detected >>= fun format ->
    Logs.info (fun m ->
        m "Model format recognized as `%s'."
          (match format with Onnx -> "ONNX" | Nnet _ -> "NNet"));
    Ok { format; filename }