Skip to content
Snippets Groups Projects
Commit 377c194e authored by Michele Alberti's avatar Michele Alberti
Browse files

Use csv library to parse nnet model wrt CSV format.

parent 9aceca1a
No related branches found
No related tags found
No related merge requests found
(library (library
(name nnet) (name nnet)
(public_name nnet) (public_name nnet)
(libraries base) (libraries base csv)
(synopsis "NNet parser")) (synopsis "NNet parser"))
...@@ -19,25 +19,23 @@ type t = { ...@@ -19,25 +19,23 @@ type t = {
max_input_values : float list; max_input_values : float list;
mean_values : float list * float; mean_values : float list * float;
range_values : float list * float; range_values : float list * float;
weights_biases : float list list;
} }
[@@deriving show { with_path = false }]
(* NNet format handling. *) (* NNet format handling. *)
let nnet_format_error s = let nnet_format_error s =
Error (Format.sprintf "NNet format error: %s condition not satisfied." s) Error (Format.sprintf "NNet format error: %s condition not satisfied." s)
let nnet_delimiter = Str.regexp "," (* Parse a single NNet format line: split line wrt CSV format, and convert each
string into a number by means of converter [f]. *)
(* Parse a single NNet format line: split line using [nnet_delimiter] as let handle_nnet_line ~f in_channel =
delimiter, and convert each string into a number by means of converter [f]. *)
let handle_nnet_line ~f line =
List.filter_map List.filter_map
~f:(fun s -> try Some (f (String.strip s)) with _ -> None) ~f:(fun s -> try Some (f (String.strip s)) with _ -> None)
(Str.split nnet_delimiter line) (Csv.next in_channel)
(* Skip the header part, ie comments, of the NNet format. *) (* Skip the header part, ie comments, of the NNet format. *)
let handle_nnet_header filename in_channel = let skip_nnet_header filename in_channel =
let exception End_of_header in let exception End_of_header in
let pos_in = ref (Stdlib.pos_in in_channel) in let pos_in = ref (Stdlib.pos_in in_channel) in
try try
...@@ -58,19 +56,16 @@ let handle_nnet_header filename in_channel = ...@@ -58,19 +56,16 @@ let handle_nnet_header filename in_channel =
(* Retrieve number of layers, inputs, outputs and maximum layer size. *) (* Retrieve number of layers, inputs, outputs and maximum layer size. *)
let handle_nnet_basic_info in_channel = let handle_nnet_basic_info in_channel =
try match handle_nnet_line ~f:Int.of_string in_channel with
let line = Stdlib.input_line in_channel in | [ n_layers; n_inputs; n_outputs; max_layer_size ] ->
match handle_nnet_line ~f:Stdlib.int_of_string line with Ok (n_layers, n_inputs, n_outputs, max_layer_size)
| [ n_layers; n_inputs; n_outputs; max_layer_size ] -> | _ -> nnet_format_error "second"
Ok (n_layers, n_inputs, n_outputs, max_layer_size) | exception End_of_file -> nnet_format_error "second"
| _ -> nnet_format_error "second"
with End_of_file -> nnet_format_error "second"
(* Retrieve size of each layer, including inputs and outputs. *) (* Retrieve size of each layer, including inputs and outputs. *)
let handle_nnet_layer_sizes n_layers in_channel = let handle_nnet_layer_sizes n_layers in_channel =
try try
let line = Stdlib.input_line in_channel in let layer_sizes = handle_nnet_line ~f:Int.of_string in_channel in
let layer_sizes = handle_nnet_line ~f:Stdlib.int_of_string line in
if List.length layer_sizes = n_layers + 1 then Ok layer_sizes if List.length layer_sizes = n_layers + 1 then Ok layer_sizes
else nnet_format_error "third" else nnet_format_error "third"
with End_of_file -> nnet_format_error "third" with End_of_file -> nnet_format_error "third"
...@@ -78,15 +73,14 @@ let handle_nnet_layer_sizes n_layers in_channel = ...@@ -78,15 +73,14 @@ let handle_nnet_layer_sizes n_layers in_channel =
(* Skip unused flag. *) (* Skip unused flag. *)
let handle_nnet_unused_flag in_channel = let handle_nnet_unused_flag in_channel =
try try
let _ = Stdlib.input_line in_channel in let _ = Csv.next in_channel in
Ok () Ok ()
with End_of_file -> nnet_format_error "forth" with End_of_file -> nnet_format_error "forth"
(* Retrive minimum values of inputs. *) (* Retrive minimum values of inputs. *)
let handle_nnet_min_input_values n_inputs in_channel = let handle_nnet_min_input_values n_inputs in_channel =
try try
let line = Stdlib.input_line in_channel in let min_input_values = handle_nnet_line ~f:Float.of_string in_channel in
let min_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
if List.length min_input_values = n_inputs then Ok min_input_values if List.length min_input_values = n_inputs then Ok min_input_values
else nnet_format_error "fifth" else nnet_format_error "fifth"
with End_of_file -> nnet_format_error "fifth" with End_of_file -> nnet_format_error "fifth"
...@@ -94,8 +88,7 @@ let handle_nnet_min_input_values n_inputs in_channel = ...@@ -94,8 +88,7 @@ let handle_nnet_min_input_values n_inputs in_channel =
(* Retrive maximum values of inputs. *) (* Retrive maximum values of inputs. *)
let handle_nnet_max_input_values n_inputs in_channel = let handle_nnet_max_input_values n_inputs in_channel =
try try
let line = Stdlib.input_line in_channel in let max_input_values = handle_nnet_line ~f:Float.of_string in_channel in
let max_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
if List.length max_input_values = n_inputs then Ok max_input_values if List.length max_input_values = n_inputs then Ok max_input_values
else nnet_format_error "sixth" else nnet_format_error "sixth"
with End_of_file -> nnet_format_error "sixth" with End_of_file -> nnet_format_error "sixth"
...@@ -103,8 +96,7 @@ let handle_nnet_max_input_values n_inputs in_channel = ...@@ -103,8 +96,7 @@ let handle_nnet_max_input_values n_inputs in_channel =
(* Retrieve mean values of inputs and one value for all outputs. *) (* Retrieve mean values of inputs and one value for all outputs. *)
let handle_nnet_mean_values n_inputs in_channel = let handle_nnet_mean_values n_inputs in_channel =
try try
let line = Stdlib.input_line in_channel in let mean_values = handle_nnet_line ~f:Float.of_string in_channel in
let mean_values = handle_nnet_line ~f:Stdlib.float_of_string line in
if List.length mean_values = n_inputs + 1 then if List.length mean_values = n_inputs + 1 then
let mean_input_values, mean_output_value = let mean_input_values, mean_output_value =
List.split_n mean_values n_inputs List.split_n mean_values n_inputs
...@@ -116,8 +108,7 @@ let handle_nnet_mean_values n_inputs in_channel = ...@@ -116,8 +108,7 @@ let handle_nnet_mean_values n_inputs in_channel =
(* Retrieve range values of inputs and one value for all outputs. *) (* Retrieve range values of inputs and one value for all outputs. *)
let handle_nnet_range_values n_inputs in_channel = let handle_nnet_range_values n_inputs in_channel =
try try
let line = Stdlib.input_line in_channel in let range_values = handle_nnet_line ~f:Float.of_string in_channel in
let range_values = handle_nnet_line ~f:Stdlib.float_of_string line in
if List.length range_values = n_inputs + 1 then if List.length range_values = n_inputs + 1 then
let range_input_values, range_output_value = let range_input_values, range_output_value =
List.split_n range_values n_inputs List.split_n range_values n_inputs
...@@ -126,13 +117,27 @@ let handle_nnet_range_values n_inputs in_channel = ...@@ -126,13 +117,27 @@ let handle_nnet_range_values n_inputs in_channel =
else nnet_format_error "eighth" else nnet_format_error "eighth"
with End_of_file -> nnet_format_error "eighth" with End_of_file -> nnet_format_error "eighth"
(* Retrieves [filename] NNet model metadata wrt NNet format specification (see (* Retrieve all layer weights and biases as appearing in the model. No special
https://github.com/sisl/NNet for details.) *) treatment is performed. *)
let parse_metadata filename = let handle_nnet_weights_and_biases in_channel =
List.rev
(Csv.fold_left ~init:[]
~f:(fun fll sl ->
List.filter_map
~f:(fun s ->
try Some (Float.of_string (String.strip s)) with _ -> None)
sl
:: fll)
in_channel)
(* Retrieves [filename] NNet model metadata and weights wrt NNet format
specification (see https://github.com/sisl/NNet for details). *)
let parse filename =
let open Result in let open Result in
let in_channel = Stdlib.open_in filename in
try try
handle_nnet_header filename in_channel >>= fun () -> let in_channel = Stdlib.open_in filename in
skip_nnet_header filename in_channel >>= fun () ->
let in_channel = Csv.of_channel in_channel in
handle_nnet_basic_info in_channel >>= fun (n_ls, n_is, n_os, max_l_size) -> handle_nnet_basic_info in_channel >>= fun (n_ls, n_is, n_os, max_l_size) ->
handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes -> handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes ->
handle_nnet_unused_flag in_channel >>= fun () -> handle_nnet_unused_flag in_channel >>= fun () ->
...@@ -140,7 +145,8 @@ let parse_metadata filename = ...@@ -140,7 +145,8 @@ let parse_metadata filename =
handle_nnet_max_input_values n_is in_channel >>= fun max_input_values -> handle_nnet_max_input_values n_is in_channel >>= fun max_input_values ->
handle_nnet_mean_values n_is in_channel >>= fun mean_values -> handle_nnet_mean_values n_is in_channel >>= fun mean_values ->
handle_nnet_range_values n_is in_channel >>= fun range_values -> handle_nnet_range_values n_is in_channel >>= fun range_values ->
Stdlib.close_in in_channel; let weights_biases = handle_nnet_weights_and_biases in_channel in
Csv.close_in in_channel;
Ok Ok
{ {
n_layers = n_ls; n_layers = n_ls;
...@@ -152,5 +158,9 @@ let parse_metadata filename = ...@@ -152,5 +158,9 @@ let parse_metadata filename =
max_input_values; max_input_values;
mean_values; mean_values;
range_values; range_values;
weights_biases;
} }
with Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg) with
| Csv.Failure (_nrecord, _nfield, msg) -> Error msg
| Sys_error s -> Error s
| Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg)
...@@ -16,9 +16,9 @@ type t = private { ...@@ -16,9 +16,9 @@ type t = private {
(** Mean values of inputs and one value for all outputs. *) (** Mean values of inputs and one value for all outputs. *)
range_values : float list * float; range_values : float list * float;
(** Range values of inputs and one value for all outputs. *) (** Range values of inputs and one value for all outputs. *)
weights_biases : float list list; (** All weights and biases of NNet model. *)
} }
[@@deriving show { with_path = false }]
(** NNet model metadata. *) (** NNet model metadata. *)
val parse_metadata : string -> (t, string) Result.t val parse : string -> (t, string) Result.t
(** Parse an NNet file for metadata. *) (** Parse an NNet file. *)
...@@ -2,6 +2,6 @@ ...@@ -2,6 +2,6 @@
(name main) (name main)
(public_name caisar) (public_name caisar)
(libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet why3) (libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet why3)
(preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq)) (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq))
(package caisar) (package caisar)
) )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment