Skip to content
Snippets Groups Projects
model.ml 6.54 KiB
(**************************************************************************)
(*                                                                        *)
(*  This file is part of Caisar.                                          *)
(*                                                                        *)
(**************************************************************************)

open Base

module Format = Caml.Format
module Sys = Caml.Sys
module Filename = Caml.Filename

(* Metadata of a model in the Nnet format (see https://github.com/sisl/NNet),
   as read from the lines that follow the comment header of the file. Weights
   and biases are not part of this record. *)
type nnet = {
  (* Number of layers; [layer_sizes] holds [n_layers + 1] entries. *)
  n_layers : int;
  (* Number of inputs of the model. *)
  n_inputs: int;
  (* Number of outputs of the model. *)
  n_outputs: int;
  (* Maximum layer size, as declared in the file. *)
  max_layer_size: int;
  (* Size of each layer, input and output layers included. *)
  layer_sizes: int list;
  (* Minimum value of each input ([n_inputs] values). *)
  min_input_values: float list;
  (* Maximum value of each input ([n_inputs] values). *)
  max_input_values: float list;
  (* Mean value of each input, and a single mean value for all outputs. *)
  mean_values: float list * float;
  (* Range of each input, and a single range value for all outputs. *)
  range_values: float list * float;
} [@@deriving show { with_path = false }]

(* Supported model formats. No metadata is extracted for ONNX models: they are
   recognized from their file name only. *)
type format = Onnx | Nnet of nnet [@@deriving show { with_path = false}]

(* A model file together with its detected format. *)
type t = {
  (* Detected format, carrying the Nnet metadata when relevant. *)
  format: format;
  (* Path of the model file. *)
  filename: string;
}


(* Nnet format handling. *)

(* [nnet_format_error what] is the [Error] reported when the Nnet header line
   designated by [what] (an ordinal such as second or third) is missing or
   malformed. *)
let nnet_format_error what =
  let message = "Nnet format error: " ^ what ^ " condition not satisfied." in
  Error message

(* Field separator of Nnet data lines. *)
let nnet_delimiter = Str.regexp ","

(* Parse a single Nnet format line: split [line] on [nnet_delimiter], strip the
   whitespace surrounding each field and convert it with [f].

   Fields that [f] rejects by raising [Failure] (as [Stdlib.int_of_string] and
   [Stdlib.float_of_string] do) are dropped rather than reported. This skips
   empty fields, such as those produced by consecutive delimiters or by a
   trailing carriage return. Callers detect malformed lines by checking how
   many values are returned. Any other exception raised by [f], and
   asynchronous ones such as [Sys.Break] or [Out_of_memory], propagates
   instead of being silently swallowed. *)
let handle_nnet_line ~f line =
  List.filter_map (Str.split nnet_delimiter line) ~f:(fun field ->
      match f (String.strip field) with
      | value -> Some value
      | exception Failure _ -> None)

(* Skip the header part, ie comments, of the Nnet format. *)
(* Skip the header part of the Nnet format, i.e. the leading lines starting
   with "//". On success, [in_channel] is left positioned at the beginning of
   the first line that is not such a comment. If the end of file is reached
   first, an error mentioning [filename] is returned. *)
let handle_nnet_header filename in_channel =
  let rec skip_comment_lines () =
    (* Remember where the line starts, to rewind once a non-comment line has
       been consumed. *)
    let line_start = Stdlib.pos_in in_channel in
    match Stdlib.input_line in_channel with
    | exception End_of_file ->
      Error (Format.sprintf "Nnet model not found in file `%s'." filename)
    | line when String.is_prefix line ~prefix:"//" -> skip_comment_lines ()
    | _ ->
      Stdlib.seek_in in_channel line_start;
      Ok ()
  in
  skip_comment_lines ()

(* Retrieve number of layers, inputs, outputs and maximum layer size. *)
(* Read the line following the header and return, in this order, the number
   of layers, of inputs, of outputs and the maximum layer size. A missing line,
   or one that does not yield exactly four integers, is reported as a failure
   of the second condition. *)
let handle_nnet_basic_info in_channel =
  match Stdlib.input_line in_channel with
  | exception End_of_file -> nnet_format_error "second"
  | line ->
    (match handle_nnet_line ~f:Stdlib.int_of_string line with
     | [ layers; inputs; outputs; max_size ] ->
       Ok (layers, inputs, outputs, max_size)
     | _ -> nnet_format_error "second")

(* Retrieve size of each layer, including inputs and outputs. *)
(* Read the size of each layer, input and output layers included: the line
   must yield exactly [n_layers + 1] integers, otherwise (or if the line is
   missing) the third condition is reported as failed. *)
let handle_nnet_layer_sizes n_layers in_channel =
  match Stdlib.input_line in_channel with
  | exception End_of_file -> nnet_format_error "third"
  | line ->
    let sizes = handle_nnet_line ~f:Stdlib.int_of_string line in
    if List.length sizes <> n_layers + 1
    then nnet_format_error "third"
    else Ok sizes

(* Skip unused flag. *)
(* Skip the line holding the unused flag. Its contents are not inspected;
   only its presence is required, a missing line being reported as a failure
   of the fourth condition. *)
let handle_nnet_unused_flag in_channel =
  match Stdlib.input_line in_channel with
  | (_ : string) -> Ok ()
  | exception End_of_file -> nnet_format_error "fourth"

(* Retrive minimum values of inputs. *)
(* Retrieve the minimum value of each input: the line must yield exactly
   [n_inputs] floats, otherwise (or if the line is missing) the fifth
   condition is reported as failed. *)
let handle_nnet_min_input_values n_inputs in_channel =
  match Stdlib.input_line in_channel with
  | exception End_of_file -> nnet_format_error "fifth"
  | line ->
    (match handle_nnet_line ~f:Stdlib.float_of_string line with
     | values when List.length values = n_inputs -> Ok values
     | _ -> nnet_format_error "fifth")

(* Retrive maximum values of inputs. *)
(* Retrieve the maximum value of each input: the line must yield exactly
   [n_inputs] floats, otherwise (or if the line is missing) the sixth
   condition is reported as failed. *)
let handle_nnet_max_input_values n_inputs in_channel =
  match Stdlib.input_line in_channel with
  | exception End_of_file -> nnet_format_error "sixth"
  | line ->
    let upper_bounds = handle_nnet_line ~f:Stdlib.float_of_string line in
    if List.length upper_bounds <> n_inputs
    then nnet_format_error "sixth"
    else Ok upper_bounds

(* Retrieve mean values of inputs and one value for all outputs. *)
(* Retrieve the mean value of each input, followed by a single mean value for
   all outputs: the line must yield exactly [n_inputs + 1] floats. Returns the
   input means and the output mean separately. A missing or malformed line is
   reported as a failure of the seventh condition. *)
let handle_nnet_mean_values n_inputs in_channel =
  match Stdlib.input_line in_channel with
  | exception End_of_file -> nnet_format_error "seventh"
  | line ->
    let mean_values = handle_nnet_line ~f:Stdlib.float_of_string line in
    (* Matching on the split, rather than checking the length and then calling
       [List.hd_exn], keeps this total: a negative [n_inputs] yields an error
       instead of an exception. *)
    (match List.split_n mean_values n_inputs with
     | mean_input_values, [ mean_output_value ]
       when List.length mean_input_values = n_inputs ->
       Ok (mean_input_values, mean_output_value)
     | _ -> nnet_format_error "seventh")

(* Retrieve range values of inputs and one value for all outputs. *)
(* Retrieve the range value of each input, followed by a single range value
   for all outputs: the line must yield exactly [n_inputs + 1] floats. Returns
   the input ranges and the output range separately. A missing or malformed
   line is reported as a failure of the eighth condition. *)
let handle_nnet_range_values n_inputs in_channel =
  match Stdlib.input_line in_channel with
  | exception End_of_file -> nnet_format_error "eighth"
  | line ->
    let range_values = handle_nnet_line ~f:Stdlib.float_of_string line in
    (* Matching on the split, rather than checking the length and then calling
       [List.hd_exn], keeps this total: a negative [n_inputs] yields an error
       instead of an exception. *)
    (match List.split_n range_values n_inputs with
     | range_input_values, [ range_output_value ]
       when List.length range_input_values = n_inputs ->
       Ok (range_input_values, range_output_value)
     | _ -> nnet_format_error "eighth")

(* Retrieves [filename] Nnet model metadata wrt Nnet format specification (see
   https://github.com/sisl/NNet for details.) *)
(* Retrieves [filename] Nnet model metadata wrt Nnet format specification (see
   https://github.com/sisl/NNet for details.) Only the header lines are read.

   Returns [Error] describing the first line that does not match the format,
   or the message of any [Failure] raised while parsing. The input channel is
   closed on every path: success, [Error], and exceptions. [Sys_error] raised
   while opening or reading [filename] still propagates to the caller. *)
let retrieve_nnet_metadata filename =
  (* Parse the header lines in order, stopping at the first error. *)
  let read_metadata in_channel =
    let open Result in
    handle_nnet_header filename in_channel >>= fun () ->
    handle_nnet_basic_info in_channel >>= fun (n_ls, n_is, n_os, max_l_size) ->
    handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes ->
    handle_nnet_unused_flag in_channel >>= fun () ->
    handle_nnet_min_input_values n_is in_channel >>= fun min_input_values ->
    handle_nnet_max_input_values n_is in_channel >>= fun max_input_values ->
    handle_nnet_mean_values n_is in_channel >>= fun mean_values ->
    handle_nnet_range_values n_is in_channel >>= fun range_values ->
    Ok
      { n_layers = n_ls;
        n_inputs = n_is;
        n_outputs = n_os;
        max_layer_size = max_l_size;
        layer_sizes;
        min_input_values;
        max_input_values;
        mean_values;
        range_values; }
  in
  let in_channel = Stdlib.open_in filename in
  (* Previously the channel was closed only on full success, leaking it on
     every error path; [close_in_noerr] keeps cleanup from masking the
     result. *)
  Exn.protect
    ~finally:(fun () -> Stdlib.close_in_noerr in_channel)
    ~f:(fun () ->
        try read_metadata in_channel
        with Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg))

(* Generic model. *)

(* Build the model description of [filename], detecting its format.

   A file whose name ends in ".onnx" is taken to be an ONNX model without
   inspecting its contents. Any other file is parsed as an Nnet model with
   [retrieve_nnet_metadata]. Returns [Error] if [filename] does not exist, is
   a directory, or is not a valid Nnet model. *)
let build ~filename =
  let open Result in
  Logs.info (fun m -> m "Checking format of model file `%s'." filename);
  if not (Sys.file_exists filename)
  then Error (Format.sprintf "No such file `%s'." filename)
  else if Sys.is_directory filename
  (* A directory passes [Sys.file_exists]; reading it as an Nnet model would
     raise [Sys_error] instead of returning an error. *)
  then Error (Format.sprintf "`%s' is a directory, not a model file." filename)
  else
    begin
      (* Check the whole extension: the bare suffix "onnx" would also
         classify names such as "modelonnx" as ONNX models. *)
      if Filename.check_suffix filename ".onnx"
      then Ok Onnx
      else
        retrieve_nnet_metadata filename >>= fun nnet ->
        Ok (Nnet nnet)
    end >>= fun format ->
    Logs.info (fun m ->
        m "Model format recognized as `%s'."
          (match format with Onnx -> "ONNX" | Nnet _ -> "NNet"));
    Ok { format; filename }