diff --git a/lib/onnx/onnx.ml b/lib/onnx/onnx.ml index 3a8a8794799647e06fa37d4c5da919a52221c023..5dd4c592289f0cb86235ad7ea7d64bdfb14db9d9 100644 --- a/lib/onnx/onnx.ml +++ b/lib/onnx/onnx.ml @@ -33,6 +33,7 @@ exception ParseError of string type t = { n_inputs : int; (* Number of inputs. *) n_outputs : int; (* Number of outputs. *) + nier : (G.t, string) Result.t; (* Intermediate representation. *) } (* ONNX format handling. *) @@ -76,14 +77,9 @@ let flattened_dim (dim : Oproto.Onnx.TensorShapeProto.Dimension.t list) = | `not_set -> acc) let get_input_output_dim (model : Oprotom.t) = - let ins, outs = - match model.graph with - | Some g -> (Some g.input, Some g.output) - | None -> (None, None) - in let input_shape, output_shape = - match (ins, outs) with - | Some i, Some o -> (get_nested_dims i, get_nested_dims o) + match model.graph with + | Some g -> (get_nested_dims g.input, get_nested_dims g.output) | _ -> ([], []) in (* TODO: here we only get the flattened dimension of inputs and outputs, but @@ -123,8 +119,7 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) = | "MaxPool" -> NCFG.Node.MaxPool | "Conv" -> NCFG.Node.Conv | "Identity" -> NCFG.Node.Identity - | _ -> - raise (ParseError ("Unsupported ONNX Operator in\n Parser: " ^ o))) + | _ -> raise (ParseError ("Unsupported ONNX operator: " ^ o))) in List.fold ~f:(fun acc n -> get_node_operator_cfg n :: acc) ~init:[] ns in @@ -218,7 +213,7 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) = let unpack v = match v with | Some v -> v - | None -> failwith "error, unpack found an unexpected None" + | None -> failwith "Unpack found an unexpected None" in let tensor_list = List.init @@ -242,7 +237,7 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) = | `not_set -> failwith "No tensor type in value info" (* TODO: support more tensor types *) - | _ -> raise (ParseError "Unknown tensor type.") + | _ -> raise (ParseError "Unknown tensor type") in let tns_s = match tns_t.shape with @@ -290,9 +285,7 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) = (*All other list constructions are folding right, so we need to put a final revert *) | _ -> - raise - (ParseError - "Error, operators and attributes list have not\n the same size") + raise (ParseError "Operator and attribute lists have not the same size") in let op_params_cfg = build_op_param_list attrs ops [] in let cfg = G.init_cfg in @@ -500,7 +493,7 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) = let nier_of_onnx_protoc (model : Oprotom.t) = match model.graph with | Some g -> produce_cfg g - | None -> raise (ParseError "No graph in ONNX input file!") + | None -> raise (ParseError "No graph in ONNX input file found") let parse_in_channel in_channel = let open Result in @@ -510,12 +503,16 @@ let parse_in_channel in_channel = match Oprotom.from_proto reader with | Ok r -> let n_inputs, n_outputs = get_input_output_dim r in - let nier = nier_of_onnx_protoc r in - Ok ({ n_inputs; n_outputs }, nier) - | _ -> Error "Error parsing protobuf" + let nier = + try Ok (nier_of_onnx_protoc r) with + | ParseError s | Sys_error s -> Error s + | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg) + in + Ok { n_inputs; n_outputs; nier } + | _ -> Error "Cannot read protobuf" with | Sys_error s -> Error s - | Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg) + | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg) let parse filename = let in_channel = Stdlib.open_in filename in diff --git a/lib/onnx/onnx.mli b/lib/onnx/onnx.mli index 04c6f96f2275ea0a7076529406993929453cb416..0e946c847a5fb195879acc9f5880716ca5687d6a 100644 --- a/lib/onnx/onnx.mli +++ b/lib/onnx/onnx.mli @@ -25,9 +25,9 @@ module G = Ir.Nier_cfg.NierCFGFloat type t = private { n_inputs : int; (** Number of inputs. *) n_outputs : int; (** Number of outputs. *) + nier : (G.t, string) Result.t; (** Intermediate representation. *) } (** ONNX model metadata. *) -val parse : string -> (t * G.t, string) Result.t -(** Parse an ONNX file to get metadata for CAISAR as well as its inner - intermediate representation for the network. *) +val parse : string -> (t, string) Result.t +(** Parse an ONNX file. *)