Skip to content
Snippets Groups Projects
Commit 51f295de authored by Julien Girard-Satabin's avatar Julien Girard-Satabin Committed by Michele Alberti
Browse files

Basic output of NIER structure, data and attributes.

This commit adds support for outputting most of the NIER informations into ONNX format.
Some metadata may be missing, such as input and output nodes for the ONNX graph.
parent b0002177
No related branches found
No related tags found
No related merge requests found
...@@ -505,42 +505,109 @@ let nier_of_onnx_protoc (model : Oprotom.t) = ...@@ -505,42 +505,109 @@ let nier_of_onnx_protoc (model : Oprotom.t) =
| None -> raise (ParseError "No graph in ONNX input file found") | None -> raise (ParseError "No graph in ONNX input file found")
let nier_to_onnx_protoc nier = let nier_to_onnx_protoc nier =
(* TODO: write a simple ONNX model from a dummy NIER *) (* TODO: get tensor data, and operator params *)
let vertices = G.vertex_list nier in let vertices = G.vertex_list nier in
let open NCFG.Node in
let protocs = let protocs =
(* match on names of NO_OP nodes and add their outputs to corresponding
* C_NODEs inputs *)
let vertex_to_protoc v = let vertex_to_protoc v =
let op_type = NCFG.Node.str_op (NCFG.Node.get_op v) in let name = get_name v in
let name = NCFG.Node.get_name v in let input, output = (get_pred_list v, get_succ_list v) in
let domain = "" in let node, initi =
let input, output = match get_op v with
(NCFG.Node.get_pred_list v, NCFG.Node.get_succ_list v) | NO_OP | RW_Linearized_ReLu ->
(* ONNX initializers are named ONNX Tensor.
* If an initializer's name matches an existing
* ONNX node input name, the initializer will be assigned as
* the input of the node. *)
let initi =
match get_tensor v with
| None -> None
| Some t ->
Some
(Oproto.Onnx.TensorProto.make ~data_type:1
~dims:(Array.to_list @@ NCFG.Tensor.get_shape t)
~float_data:(NCFG.Tensor.flatten t) ~name ())
in
let node = None in
(node, initi)
| _ ->
let op_type = str_op (get_op v) in
let attribute =
match v.operator_parameters with
| None | Some (RW_Linearized_ReLu_params _) -> []
| Some
(Pool_params
(Ksize k, Some (Stride s), Some (Pads p), Some (Dilations d)))
| Some
(Conv_params
(Ksize k, Some (Stride s), Some (Pads p), Some (Dilations d)))
->
let ksize =
Oproto.Onnx.AttributeProto.make ~name:"ksize"
~ints:(Array.to_list k) ()
in
let stride =
Oproto.Onnx.AttributeProto.make ~name:"stride"
~ints:(Array.to_list s) ()
in
let pads =
Oproto.Onnx.AttributeProto.make ~name:"pads"
~ints:(Array.to_list p) ()
in
let dilations =
Oproto.Onnx.AttributeProto.make ~name:"dilations"
~ints:(Array.to_list d) ()
in
[ ksize; stride; pads; dilations ]
| Some (Transpose_params s) ->
[
Oproto.Onnx.AttributeProto.make ~name:"perms"
~ints:(Array.to_list s) ();
]
| _ -> []
in
let node =
Some
(Oproto.Onnx.NodeProto.make ~input ~output ~name ~op_type
~attribute ~doc_string:"" ())
in
let initi = None in
(node, initi)
in in
Oproto.Onnx.NodeProto.make ~input ~output ~name ~op_type ~domain (node, initi)
~attribute:[] ~doc_string:"" ()
in in
List.fold ~init:[] ~f:(fun acc v -> vertex_to_protoc v :: acc) vertices List.fold ~init:([], [])
~f:(fun (accn, acci) v ->
let node, initi = vertex_to_protoc v in
match (node, initi) with
| Some n, Some t -> (n :: accn, t :: acci)
| Some n, None -> (n :: accn, acci)
| None, Some t -> (accn, t :: acci)
| None, None -> (accn, acci))
vertices
in
let docstr =
"This ONNX model was generated from the Neural Intermediate Representation \
of CAISAR"
in in
let protog = let protog =
Oproto.Onnx.GraphProto.make ~name:"" ~node:protocs ~initializer':[] Oproto.Onnx.GraphProto.make ~name:"ONNX CAISAR Export" ~node:(fst protocs)
~sparse_initializer:[] ~doc_string:"" ~input:[] ~output:[] ~value_info:[] ~initializer':(snd protocs) ~sparse_initializer:[]
~quantization_annotation:[] () ~doc_string:"ONNX graph generated from CAISAR NIER" ~input:[] ~output:[]
~value_info:[] ~quantization_annotation:[] ()
in in
let protom = let protom =
Oproto.Onnx.ModelProto.make ~ir_version:13 ~opset_import:[] Oproto.Onnx.ModelProto.make ~ir_version:13 ~opset_import:[]
~producer_name:"CAISAR" ~producer_version:"1.0" ~domain:"" ~producer_name:"CAISAR" ~producer_version:"1.0" ~domain:""
~model_version:(-1) ~doc_string:"" ~graph:protog ~metadata_props:[] ~model_version:(-1) ~doc_string:docstr ~graph:protog ~metadata_props:[]
~training_info:[] ~functions:[] () ~training_info:[] ~functions:[] ()
in in
let writer = Oprotom.to_proto protom in let writer = Oprotom.to_proto protom in
Ocaml_protoc_plugin.Writer.contents writer Ocaml_protoc_plugin.Writer.contents writer
let write_nier_to_onnx _nier out_channel = let write_nier_to_onnx nier out_channel =
let nier = G.init_cfg in
let n =
Ir.Nier_cfg.Node.create ~id:0 ~name:None ~sh:[||] ~op:Ir.Nier_cfg.Node.NO_OP
~op_p:None ~pred:[] ~succ:[] ~tensor:None
in
G.add_vertex nier n;
let onnx = nier_to_onnx_protoc nier in let onnx = nier_to_onnx_protoc nier in
Stdio.Out_channel.output_string out_channel onnx Stdio.Out_channel.output_string out_channel onnx
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment