Skip to content
Snippets Groups Projects
Commit 84025c4b authored by Michele Alberti's avatar Michele Alberti
Browse files

[nnet] Admit parsing failure on non-mandatory information.

parent 3c965d10
No related branches found
No related tags found
No related merge requests found
...@@ -32,10 +32,10 @@ type t = { ...@@ -32,10 +32,10 @@ type t = {
n_outputs : int; n_outputs : int;
max_layer_size : int; max_layer_size : int;
layer_sizes : int list; layer_sizes : int list;
min_input_values : float list; min_input_values : float list option;
max_input_values : float list; max_input_values : float list option;
mean_values : float list * float; mean_values : (float list * float) option;
range_values : float list * float; range_values : (float list * float) option;
weights_biases : float list list; weights_biases : float list list;
} }
...@@ -154,18 +154,25 @@ let handle_nnet_weights_and_biases in_channel = ...@@ -154,18 +154,25 @@ let handle_nnet_weights_and_biases in_channel =
(* Retrieves [filename] NNet model metadata and weights wrt NNet format (* Retrieves [filename] NNet model metadata and weights wrt NNet format
specification (see https://github.com/sisl/NNet for details). *) specification (see https://github.com/sisl/NNet for details). *)
let parse_in_channel filename in_channel = let parse_in_channel ?(permissive = false) filename in_channel =
let open Result in let open Result in
let ok_opt r =
match r with
| Ok x -> Ok (Some x)
| Error _ as error -> if not permissive then error else Ok None
in
try try
skip_nnet_header filename in_channel >>= fun () -> skip_nnet_header filename in_channel >>= fun () ->
let in_channel = Csv.of_channel in_channel in let in_channel = Csv.of_channel in_channel in
handle_nnet_basic_info in_channel >>= fun (n_ls, n_is, n_os, max_l_size) -> handle_nnet_basic_info in_channel >>= fun (n_ls, n_is, n_os, max_l_size) ->
handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes -> handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes ->
handle_nnet_unused_flag in_channel >>= fun () -> handle_nnet_unused_flag in_channel >>= fun () ->
handle_nnet_min_input_values n_is in_channel >>= fun min_input_values -> ok_opt (handle_nnet_min_input_values n_is in_channel)
handle_nnet_max_input_values n_is in_channel >>= fun max_input_values -> >>= fun min_input_values ->
handle_nnet_mean_values n_is in_channel >>= fun mean_values -> ok_opt (handle_nnet_max_input_values n_is in_channel)
handle_nnet_range_values n_is in_channel >>= fun range_values -> >>= fun max_input_values ->
ok_opt (handle_nnet_mean_values n_is in_channel) >>= fun mean_values ->
ok_opt (handle_nnet_range_values n_is in_channel) >>= fun range_values ->
let weights_biases = handle_nnet_weights_and_biases in_channel in let weights_biases = handle_nnet_weights_and_biases in_channel in
Csv.close_in in_channel; Csv.close_in in_channel;
Ok Ok
...@@ -184,10 +191,10 @@ let parse_in_channel filename in_channel = ...@@ -184,10 +191,10 @@ let parse_in_channel filename in_channel =
with with
| Csv.Failure (_nrecord, _nfield, msg) -> Error msg | Csv.Failure (_nrecord, _nfield, msg) -> Error msg
| Sys_error s -> Error s | Sys_error s -> Error s
| Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg) | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg)
let parse filename = let parse ?(permissive = false) filename =
let in_channel = Stdlib.open_in filename in let in_channel = Stdlib.open_in filename in
Fun.protect Fun.protect
~finally:(fun () -> Stdlib.close_in in_channel) ~finally:(fun () -> Stdlib.close_in in_channel)
(fun () -> parse_in_channel filename in_channel) (fun () -> parse_in_channel ~permissive filename in_channel)
...@@ -26,15 +26,19 @@ type t = private { ...@@ -26,15 +26,19 @@ type t = private {
n_outputs : int; (** Number of outputs. *) n_outputs : int; (** Number of outputs. *)
max_layer_size : int; (** Maximum layer size. *) max_layer_size : int; (** Maximum layer size. *)
layer_sizes : int list; (** Size of each layer. *) layer_sizes : int list; (** Size of each layer. *)
min_input_values : float list; (** Minimum values of inputs. *) min_input_values : float list option; (** Minimum values of inputs. *)
max_input_values : float list; (** Maximum values of inputs. *) max_input_values : float list option; (** Maximum values of inputs. *)
mean_values : float list * float; mean_values : (float list * float) option;
(** Mean values of inputs and one value for all outputs. *) (** Mean values of inputs and one value for all outputs. *)
range_values : float list * float; range_values : (float list * float) option;
(** Range values of inputs and one value for all outputs. *) (** Range values of inputs and one value for all outputs. *)
weights_biases : float list list; (** All weights and biases of NNet model. *) weights_biases : float list list; (** All weights and biases of NNet model. *)
} }
(** NNet model metadata. *) (** NNet model metadata. *)
val parse : string -> (t, string) Result.t val parse : ?permissive:bool -> string -> (t, string) Result.t
(** Parse an NNet file. *) (** Parse an NNet file.
@param permissive
[false] by default. When set, parsing does not fail on non available
information, which are set to [None] instead. *)
...@@ -28,7 +28,6 @@ type t = private { ...@@ -28,7 +28,6 @@ type t = private {
} }
(** ONNX model metadata. *) (** ONNX model metadata. *)
val parse : string -> (t * G.t, string) Result.t
(** Parse an ONNX file to get metadata for CAISAR as well as its inner (** Parse an ONNX file to get metadata for CAISAR as well as its inner
intermediate representation for the network. *) intermediate representation for the network. *)
val parse : string -> (t * G.t, string) Result.t
...@@ -84,7 +84,7 @@ let register_svm_as_array nb_inputs nb_classes filename env = ...@@ -84,7 +84,7 @@ let register_svm_as_array nb_inputs nb_classes filename env =
Wstdlib.Mstr.singleton "SVMasArray" (Pmodule.close_module th_uc) Wstdlib.Mstr.singleton "SVMasArray" (Pmodule.close_module th_uc)
let nnet_parser env _ filename _ = let nnet_parser env _ filename _ =
let model = Nnet.parse filename in let model = Nnet.parse ~permissive:true filename in
match model with match model with
| Error s -> Loc.errorm "%s" s | Error s -> Loc.errorm "%s" s
| Ok model -> | Ok model ->
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment