diff --git a/lib/nnet/nnet.ml b/lib/nnet/nnet.ml index 1887a979f343453671cee1a417115c9f36928209..d74a65d5844a67c5479be2aa6e8c45a4e7d115f9 100644 --- a/lib/nnet/nnet.ml +++ b/lib/nnet/nnet.ml @@ -32,10 +32,10 @@ type t = { n_outputs : int; max_layer_size : int; layer_sizes : int list; - min_input_values : float list; - max_input_values : float list; - mean_values : float list * float; - range_values : float list * float; + min_input_values : float list option; + max_input_values : float list option; + mean_values : (float list * float) option; + range_values : (float list * float) option; weights_biases : float list list; } @@ -154,18 +154,25 @@ let handle_nnet_weights_and_biases in_channel = (* Retrieves [filename] NNet model metadata and weights wrt NNet format specification (see https://github.com/sisl/NNet for details). *) -let parse_in_channel filename in_channel = +let parse_in_channel ?(permissive = false) filename in_channel = let open Result in + let ok_opt r = + match r with + | Ok x -> Ok (Some x) + | Error _ as error -> if not permissive then error else Ok None + in try skip_nnet_header filename in_channel >>= fun () -> let in_channel = Csv.of_channel in_channel in handle_nnet_basic_info in_channel >>= fun (n_ls, n_is, n_os, max_l_size) -> handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes -> handle_nnet_unused_flag in_channel >>= fun () -> - handle_nnet_min_input_values n_is in_channel >>= fun min_input_values -> - handle_nnet_max_input_values n_is in_channel >>= fun max_input_values -> - handle_nnet_mean_values n_is in_channel >>= fun mean_values -> - handle_nnet_range_values n_is in_channel >>= fun range_values -> + ok_opt (handle_nnet_min_input_values n_is in_channel) + >>= fun min_input_values -> + ok_opt (handle_nnet_max_input_values n_is in_channel) + >>= fun max_input_values -> + ok_opt (handle_nnet_mean_values n_is in_channel) >>= fun mean_values -> + ok_opt (handle_nnet_range_values n_is in_channel) >>= fun range_values -> let weights_biases = handle_nnet_weights_and_biases in_channel in Csv.close_in in_channel; Ok @@ -184,10 +191,10 @@ let parse_in_channel filename in_channel = with | Csv.Failure (_nrecord, _nfield, msg) -> Error msg | Sys_error s -> Error s - | Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg) + | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg) -let parse filename = +let parse ?(permissive = false) filename = let in_channel = Stdlib.open_in filename in Fun.protect ~finally:(fun () -> Stdlib.close_in in_channel) - (fun () -> parse_in_channel filename in_channel) + (fun () -> parse_in_channel ~permissive filename in_channel) diff --git a/lib/nnet/nnet.mli b/lib/nnet/nnet.mli index be73f684f19fc326a64f0ff9a40fafe66e02ba78..4e834f58f3a3624f86fc629ec71ba1cae45dd7f9 100644 --- a/lib/nnet/nnet.mli +++ b/lib/nnet/nnet.mli @@ -26,15 +26,19 @@ type t = private { n_outputs : int; (** Number of outputs. *) max_layer_size : int; (** Maximum layer size. *) layer_sizes : int list; (** Size of each layer. *) - min_input_values : float list; (** Minimum values of inputs. *) - max_input_values : float list; (** Maximum values of inputs. *) - mean_values : float list * float; + min_input_values : float list option; (** Minimum values of inputs. *) + max_input_values : float list option; (** Maximum values of inputs. *) + mean_values : (float list * float) option; (** Mean values of inputs and one value for all outputs. *) - range_values : float list * float; + range_values : (float list * float) option; (** Range values of inputs and one value for all outputs. *) weights_biases : float list list; (** All weights and biases of NNet model. *) } (** NNet model metadata. *) -val parse : string -> (t, string) Result.t -(** Parse an NNet file. *) +val parse : ?permissive:bool -> string -> (t, string) Result.t +(** Parse an NNet file. + + @param permissive + [false] by default. When set, parsing does not fail on non available + information, which are set to [None] instead. *) diff --git a/lib/onnx/onnx.mli b/lib/onnx/onnx.mli index 7eb5500ccd950a59218ed5e7f3c875bbc1f923cf..04c6f96f2275ea0a7076529406993929453cb416 100644 --- a/lib/onnx/onnx.mli +++ b/lib/onnx/onnx.mli @@ -28,7 +28,6 @@ type t = private { } (** ONNX model metadata. *) +val parse : string -> (t * G.t, string) Result.t (** Parse an ONNX file to get metadata for CAISAR as well as its inner intermediate representation for the network. *) - -val parse : string -> (t * G.t, string) Result.t diff --git a/src/language.ml b/src/language.ml index 6cae58d56283c06ad088a1d7f5ea71a416270b44..c9d04dd2c3fae5c2423a286cb8f03bc4e02f5898 100644 --- a/src/language.ml +++ b/src/language.ml @@ -84,7 +84,7 @@ let register_svm_as_array nb_inputs nb_classes filename env = Wstdlib.Mstr.singleton "SVMasArray" (Pmodule.close_module th_uc) let nnet_parser env _ filename _ = - let model = Nnet.parse filename in + let model = Nnet.parse ~permissive:true filename in match model with | Error s -> Loc.errorm "%s" s | Ok model ->