diff --git a/lib/nnet/nnet.ml b/lib/nnet/nnet.ml index 1887a979f343453671cee1a417115c9f36928209..d74a65d5844a67c5479be2aa6e8c45a4e7d115f9 100644 --- a/lib/nnet/nnet.ml +++ b/lib/nnet/nnet.ml @@ -32,10 +32,10 @@ type t = { n_outputs : int; max_layer_size : int; layer_sizes : int list; - min_input_values : float list; - max_input_values : float list; - mean_values : float list * float; - range_values : float list * float; + min_input_values : float list option; + max_input_values : float list option; + mean_values : (float list * float) option; + range_values : (float list * float) option; weights_biases : float list list; } @@ -154,18 +154,25 @@ let handle_nnet_weights_and_biases in_channel = (* Retrieves [filename] NNet model metadata and weights wrt NNet format specification (see https://github.com/sisl/NNet for details). *) -let parse_in_channel filename in_channel = +let parse_in_channel ?(permissive = false) filename in_channel = let open Result in + let ok_opt r = + match r with + | Ok x -> Ok (Some x) + | Error _ as error -> if not permissive then error else Ok None + in try skip_nnet_header filename in_channel >>= fun () -> let in_channel = Csv.of_channel in_channel in handle_nnet_basic_info in_channel >>= fun (n_ls, n_is, n_os, max_l_size) -> handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes -> handle_nnet_unused_flag in_channel >>= fun () -> - handle_nnet_min_input_values n_is in_channel >>= fun min_input_values -> - handle_nnet_max_input_values n_is in_channel >>= fun max_input_values -> - handle_nnet_mean_values n_is in_channel >>= fun mean_values -> - handle_nnet_range_values n_is in_channel >>= fun range_values -> + ok_opt (handle_nnet_min_input_values n_is in_channel) + >>= fun min_input_values -> + ok_opt (handle_nnet_max_input_values n_is in_channel) + >>= fun max_input_values -> + ok_opt (handle_nnet_mean_values n_is in_channel) >>= fun mean_values -> + ok_opt (handle_nnet_range_values n_is in_channel) >>= fun range_values -> let 
weights_biases = handle_nnet_weights_and_biases in_channel in Csv.close_in in_channel; Ok @@ -184,10 +191,10 @@ let parse_in_channel filename in_channel = with | Csv.Failure (_nrecord, _nfield, msg) -> Error msg | Sys_error s -> Error s - | Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg) + | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg) -let parse filename = +let parse ?(permissive = false) filename = let in_channel = Stdlib.open_in filename in Fun.protect ~finally:(fun () -> Stdlib.close_in in_channel) - (fun () -> parse_in_channel filename in_channel) + (fun () -> parse_in_channel ~permissive filename in_channel) diff --git a/lib/nnet/nnet.mli b/lib/nnet/nnet.mli index be73f684f19fc326a64f0ff9a40fafe66e02ba78..4e834f58f3a3624f86fc629ec71ba1cae45dd7f9 100644 --- a/lib/nnet/nnet.mli +++ b/lib/nnet/nnet.mli @@ -26,15 +26,19 @@ type t = private { n_outputs : int; (** Number of outputs. *) max_layer_size : int; (** Maximum layer size. *) layer_sizes : int list; (** Size of each layer. *) - min_input_values : float list; (** Minimum values of inputs. *) - max_input_values : float list; (** Maximum values of inputs. *) - mean_values : float list * float; + min_input_values : float list option; (** Minimum values of inputs. *) + max_input_values : float list option; (** Maximum values of inputs. *) + mean_values : (float list * float) option; (** Mean values of inputs and one value for all outputs. *) - range_values : float list * float; + range_values : (float list * float) option; (** Range values of inputs and one value for all outputs. *) weights_biases : float list list; (** All weights and biases of NNet model. *) } (** NNet model metadata. *) -val parse : string -> (t, string) Result.t -(** Parse an NNet file. *) +val parse : ?permissive:bool -> string -> (t, string) Result.t +(** Parse an NNet file. + + @param permissive + [false] by default. 
When set, parsing does not fail on unavailable + information, which is set to [None] instead. *) diff --git a/lib/onnx/onnx.ml b/lib/onnx/onnx.ml index 3a8a8794799647e06fa37d4c5da919a52221c023..6a0ea0a963cc78954c74c3141c492c5bee87bf6c 100644 --- a/lib/onnx/onnx.ml +++ b/lib/onnx/onnx.ml @@ -33,6 +33,7 @@ exception ParseError of string type t = { n_inputs : int; (* Number of inputs. *) n_outputs : int; (* Number of outputs. *) + nier : (G.t, string) Result.t; (* Intermediate representation. *) } (* ONNX format handling. *) @@ -76,14 +77,9 @@ let flattened_dim (dim : Oproto.Onnx.TensorShapeProto.Dimension.t list) = | `not_set -> acc) let get_input_output_dim (model : Oprotom.t) = - let ins, outs = - match model.graph with - | Some g -> (Some g.input, Some g.output) - | None -> (None, None) - in let input_shape, output_shape = - match (ins, outs) with - | Some i, Some o -> (get_nested_dims i, get_nested_dims o) + match model.graph with + | Some g -> (get_nested_dims g.input, get_nested_dims g.output) | _ -> ([], []) in (* TODO: here we only get the flattened dimension of inputs and outputs, but @@ -123,8 +119,7 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) = | "MaxPool" -> NCFG.Node.MaxPool | "Conv" -> NCFG.Node.Conv | "Identity" -> NCFG.Node.Identity - | _ -> - raise (ParseError ("Unsupported ONNX Operator in\n Parser: " ^ o))) + | _ -> raise (ParseError ("Unsupported ONNX operator " ^ o))) in List.fold ~f:(fun acc n -> get_node_operator_cfg n :: acc) ~init:[] ns in @@ -218,7 +213,7 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) = let unpack v = match v with | Some v -> v - | None -> failwith "error, unpack found an unexpected None" + | None -> failwith "Unpack found an unexpected None" in let tensor_list = List.init @@ -242,7 +237,7 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) = | `not_set -> failwith "No tensor type in value info" (* TODO: support more tensor types *) - | _ -> raise (ParseError "Unknown tensor type.") + | _ -> raise (ParseError 
"Unknown tensor type") in let tns_s = match tns_t.shape with @@ -290,9 +285,7 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) = (*All other list constructions are folding right, so we need to put a final revert *) | _ -> - raise - (ParseError - "Error, operators and attributes list have not\n the same size") + raise (ParseError "Operator and attribute lists have not the same size") in let op_params_cfg = build_op_param_list attrs ops [] in let cfg = G.init_cfg in @@ -500,7 +493,7 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) = let nier_of_onnx_protoc (model : Oprotom.t) = match model.graph with | Some g -> produce_cfg g - | None -> raise (ParseError "No graph in ONNX input file!") + | None -> raise (ParseError "No graph in ONNX input file found") let parse_in_channel in_channel = let open Result in @@ -510,12 +503,16 @@ let parse_in_channel in_channel = match Oprotom.from_proto reader with | Ok r -> let n_inputs, n_outputs = get_input_output_dim r in - let nier = nier_of_onnx_protoc r in - Ok ({ n_inputs; n_outputs }, nier) - | _ -> Error "Error parsing protobuf" + let nier = + try Ok (nier_of_onnx_protoc r) with + | ParseError s | Sys_error s -> Error s + | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg) + in + Ok { n_inputs; n_outputs; nier } + | _ -> Error "Cannot read protobuf" with | Sys_error s -> Error s - | Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg) + | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg) let parse filename = let in_channel = Stdlib.open_in filename in diff --git a/lib/onnx/onnx.mli b/lib/onnx/onnx.mli index 7eb5500ccd950a59218ed5e7f3c875bbc1f923cf..0e946c847a5fb195879acc9f5880716ca5687d6a 100644 --- a/lib/onnx/onnx.mli +++ b/lib/onnx/onnx.mli @@ -25,10 +25,9 @@ module G = Ir.Nier_cfg.NierCFGFloat type t = private { n_inputs : int; (** Number of inputs. *) n_outputs : int; (** Number of outputs. *) + nier : (G.t, string) Result.t; (** Intermediate representation. 
*) } (** ONNX model metadata. *) -(** Parse an ONNX file to get metadata for CAISAR as well as its inner - intermediate representation for the network. *) - -val parse : string -> (t * G.t, string) Result.t +val parse : string -> (t, string) Result.t +(** Parse an ONNX file. *) diff --git a/src/language.ml b/src/language.ml index 6cae58d56283c06ad088a1d7f5ea71a416270b44..35cde2dbe6e0eb3fc4b2cb4d19bee8e57680333a 100644 --- a/src/language.ml +++ b/src/language.ml @@ -58,7 +58,7 @@ let register_nn_as_tuple nb_inputs nb_outputs filename nier env = (Ty.ty_tuple (List.init nb_outputs ~f)) in Term.Hls.add loaded_nets ls_net_apply - { filename; nb_inputs; nb_outputs; ty_data = input_type; nier }; + { filename; nb_inputs; nb_outputs; ty_data = input_type; nier }; let th_uc = Pmodule.add_pdecl ~vc:false th_uc (Pdecl.create_pure_decl (Decl.create_param_decl ls_net_apply)) @@ -84,25 +84,33 @@ let register_svm_as_array nb_inputs nb_classes filename env = Wstdlib.Mstr.singleton "SVMasArray" (Pmodule.close_module th_uc) let nnet_parser env _ filename _ = - let model = Nnet.parse filename in + let model = Nnet.parse ~permissive:true filename in match model with | Error s -> Loc.errorm "%s" s - | Ok model -> - register_nn_as_tuple model.n_inputs model.n_outputs filename None env + | Ok { n_inputs; n_outputs; _ } -> + register_nn_as_tuple n_inputs n_outputs filename None env let onnx_parser env _ filename _ = let model = Onnx.parse filename in match model with | Error s -> Loc.errorm "%s" s - | Ok (model, nier) -> - register_nn_as_tuple model.n_inputs model.n_outputs filename (Some nier) env + | Ok { n_inputs; n_outputs; nier } -> + let nier = + match nier with + | Error msg -> + Logs.warn (fun m -> + m "Cannot build network intermediate representation:@ %s" msg); + None + | Ok nier -> Some nier + in + register_nn_as_tuple n_inputs n_outputs filename nier env let ovo_parser env _ filename _ = let model = Ovo.parse filename in match model with | Error s -> Loc.errorm "%s" s - | Ok 
model -> - register_svm_as_array model.n_inputs model.n_outputs filename env + | Ok { n_inputs; n_outputs } -> + register_svm_as_array n_inputs n_outputs filename env let register_nnet_support () = Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language