diff --git a/lib/onnx/onnx.ml b/lib/onnx/onnx.ml index a3b631ee497baa0b74be16cf02c76cea25846f3e..170d74de67e2ac5d45cbf46e65167ee8b36e9b43 100644 --- a/lib/onnx/onnx.ml +++ b/lib/onnx/onnx.ml @@ -505,42 +505,109 @@ let nier_of_onnx_protoc (model : Oprotom.t) = | None -> raise (ParseError "No graph in ONNX input file found") let nier_to_onnx_protoc nier = - (* TODO: write a simple ONNX model from a dummy NIER *) + (* TODO: get tensor data, and operator params *) let vertices = G.vertex_list nier in + let open NCFG.Node in let protocs = + (* match on names of NO_OP nodes and add their outputs to corresponding + * C_NODEs inputs *) let vertex_to_protoc v = - let op_type = NCFG.Node.str_op (NCFG.Node.get_op v) in - let name = NCFG.Node.get_name v in - let domain = "" in - let input, output = - (NCFG.Node.get_pred_list v, NCFG.Node.get_succ_list v) + let name = get_name v in + let input, output = (get_pred_list v, get_succ_list v) in + let node, initi = + match get_op v with + | NO_OP | RW_Linearized_ReLu -> + (* ONNX initializers are named ONNX Tensor. + * If an initializer's name matches an existing + * ONNX node input name, the initializer will be assigned as + * the input of the node. *) + let initi = + match get_tensor v with + | None -> None + | Some t -> + Some + (Oproto.Onnx.TensorProto.make ~data_type:1 + ~dims:(Array.to_list @@ NCFG.Tensor.get_shape t) + ~float_data:(NCFG.Tensor.flatten t) ~name ()) + in + let node = None in + (node, initi) + | _ -> + let op_type = str_op (get_op v) in + let attribute = + match v.operator_parameters with + | None | Some (RW_Linearized_ReLu_params _) -> [] + | Some + (Pool_params + (Ksize k, Some (Stride s), Some (Pads p), Some (Dilations d))) + | Some + (Conv_params + (Ksize k, Some (Stride s), Some (Pads p), Some (Dilations d))) + -> + let ksize = + Oproto.Onnx.AttributeProto.make ~name:"ksize" + ~ints:(Array.to_list k) () + in + let stride = + Oproto.Onnx.AttributeProto.make ~name:"stride" + ~ints:(Array.to_list s) () + in + let pads = + Oproto.Onnx.AttributeProto.make ~name:"pads" + ~ints:(Array.to_list p) () + in + let dilations = + Oproto.Onnx.AttributeProto.make ~name:"dilations" + ~ints:(Array.to_list d) () + in + [ ksize; stride; pads; dilations ] + | Some (Transpose_params s) -> + [ + Oproto.Onnx.AttributeProto.make ~name:"perms" + ~ints:(Array.to_list s) (); + ] + | _ -> [] + in + let node = + Some + (Oproto.Onnx.NodeProto.make ~input ~output ~name ~op_type + ~attribute ~doc_string:"" ()) + in + let initi = None in + (node, initi) in - Oproto.Onnx.NodeProto.make ~input ~output ~name ~op_type ~domain - ~attribute:[] ~doc_string:"" () + (node, initi) in - List.fold ~init:[] ~f:(fun acc v -> vertex_to_protoc v :: acc) vertices + List.fold ~init:([], []) + ~f:(fun (accn, acci) v -> + let node, initi = vertex_to_protoc v in + match (node, initi) with + | Some n, Some t -> (n :: accn, t :: acci) + | Some n, None -> (n :: accn, acci) + | None, Some t -> (accn, t :: acci) + | None, None -> (accn, acci)) + vertices + in + let docstr = + "This ONNX model was generated from the Neural Intermediate Representation \ + of CAISAR" in let protog = - Oproto.Onnx.GraphProto.make ~name:"" ~node:protocs ~initializer':[] - ~sparse_initializer:[] ~doc_string:"" ~input:[] ~output:[] ~value_info:[] - ~quantization_annotation:[] () + Oproto.Onnx.GraphProto.make ~name:"ONNX CAISAR Export" ~node:(fst protocs) + ~initializer':(snd protocs) ~sparse_initializer:[] + ~doc_string:"ONNX graph generated from CAISAR NIER" ~input:[] ~output:[] + ~value_info:[] ~quantization_annotation:[] () in let protom = Oproto.Onnx.ModelProto.make ~ir_version:13 ~opset_import:[] ~producer_name:"CAISAR" ~producer_version:"1.0" ~domain:"" - ~model_version:(-1) ~doc_string:"" ~graph:protog ~metadata_props:[] + ~model_version:(-1) ~doc_string:docstr ~graph:protog ~metadata_props:[] ~training_info:[] ~functions:[] () in let writer = Oprotom.to_proto protom in Ocaml_protoc_plugin.Writer.contents writer -let write_nier_to_onnx _nier out_channel = - let nier = G.init_cfg in - let n = - Ir.Nier_cfg.Node.create ~id:0 ~name:None ~sh:[||] ~op:Ir.Nier_cfg.Node.NO_OP - ~op_p:None ~pred:[] ~succ:[] ~tensor:None - in - G.add_vertex nier n; +let write_nier_to_onnx nier out_channel = let onnx = nier_to_onnx_protoc nier in Stdio.Out_channel.output_string out_channel onnx