Skip to content
Snippets Groups Projects
Commit 3d4ccfa1 authored by Michele Alberti's avatar Michele Alberti
Browse files

[ONNX] Make nier a field of the parsing result. Keep error messages for future use.

parent 84025c4b
No related branches found
No related tags found
No related merge requests found
......@@ -33,6 +33,7 @@ exception ParseError of string
type t = {
n_inputs : int; (* Number of inputs. *)
n_outputs : int; (* Number of outputs. *)
nier : (G.t, string) Result.t; (* Intermediate representation. *)
}
(* ONNX format handling. *)
......@@ -76,14 +77,9 @@ let flattened_dim (dim : Oproto.Onnx.TensorShapeProto.Dimension.t list) =
| `not_set -> acc)
let get_input_output_dim (model : Oprotom.t) =
let ins, outs =
match model.graph with
| Some g -> (Some g.input, Some g.output)
| None -> (None, None)
in
let input_shape, output_shape =
match (ins, outs) with
| Some i, Some o -> (get_nested_dims i, get_nested_dims o)
match model.graph with
| Some g -> (get_nested_dims g.input, get_nested_dims g.output)
| _ -> ([], [])
in
(* TODO: here we only get the flattened dimension of inputs and outputs, but
......@@ -123,8 +119,7 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) =
| "MaxPool" -> NCFG.Node.MaxPool
| "Conv" -> NCFG.Node.Conv
| "Identity" -> NCFG.Node.Identity
| _ ->
raise (ParseError ("Unsupported ONNX Operator in\n Parser: " ^ o)))
| _ -> raise (ParseError ("Unsupported ONNX operator: " ^ o)))
in
List.fold ~f:(fun acc n -> get_node_operator_cfg n :: acc) ~init:[] ns
in
......@@ -218,7 +213,7 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) =
let unpack v =
match v with
| Some v -> v
| None -> failwith "error, unpack found an unexpected None"
| None -> failwith "Unpack found an unexpected None"
in
let tensor_list =
List.init
......@@ -242,7 +237,7 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) =
| `not_set ->
failwith "No tensor type in value info"
(* TODO: support more tensor types *)
| _ -> raise (ParseError "Unknown tensor type.")
| _ -> raise (ParseError "Unknown tensor type")
in
let tns_s =
match tns_t.shape with
......@@ -290,9 +285,7 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) =
(*All other list constructions are folding right, so we need to put a
final revert *)
| _ ->
raise
(ParseError
"Error, operators and attributes list have not\n the same size")
raise (ParseError "Operator and attribute lists have not the same size")
in
let op_params_cfg = build_op_param_list attrs ops [] in
let cfg = G.init_cfg in
......@@ -500,7 +493,7 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) =
let nier_of_onnx_protoc (model : Oprotom.t) =
match model.graph with
| Some g -> produce_cfg g
| None -> raise (ParseError "No graph in ONNX input file!")
| None -> raise (ParseError "No graph in ONNX input file found")
let parse_in_channel in_channel =
let open Result in
......@@ -510,12 +503,16 @@ let parse_in_channel in_channel =
match Oprotom.from_proto reader with
| Ok r ->
let n_inputs, n_outputs = get_input_output_dim r in
let nier = nier_of_onnx_protoc r in
Ok ({ n_inputs; n_outputs }, nier)
| _ -> Error "Error parsing protobuf"
let nier =
try Ok (nier_of_onnx_protoc r) with
| ParseError s | Sys_error s -> Error s
| Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg)
in
Ok { n_inputs; n_outputs; nier }
| _ -> Error "Cannot read protobuf"
with
| Sys_error s -> Error s
| Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg)
| Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg)
let parse filename =
let in_channel = Stdlib.open_in filename in
......
......@@ -25,9 +25,9 @@ module G = Ir.Nier_cfg.NierCFGFloat
type t = private {
n_inputs : int; (** Number of inputs. *)
n_outputs : int; (** Number of outputs. *)
nier : (G.t, string) Result.t; (** Intermediate representation. *)
}
(** ONNX model metadata. *)
val parse : string -> (t * G.t, string) Result.t
(** Parse an ONNX file to get metadata for CAISAR as well as its inner
intermediate representation for the network. *)
val parse : string -> (t, string) Result.t
(** Parse an ONNX file. *)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment