From 51f295dea195bf1e7dcbf536691e4d04ee92c8e3 Mon Sep 17 00:00:00 2001
From: Julien Girard <julien.girard2@cea.fr>
Date: Fri, 12 Jan 2024 16:03:46 +0100
Subject: [PATCH] Basic output of NIER structure, data and attributes.

This commit adds support for outputting most of the NIER information in the ONNX format.
Some metadata may be missing, such as the input and output nodes of the ONNX graph.
---
 lib/onnx/onnx.ml | 107 ++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 87 insertions(+), 20 deletions(-)

diff --git a/lib/onnx/onnx.ml b/lib/onnx/onnx.ml
index a3b631e..170d74d 100644
--- a/lib/onnx/onnx.ml
+++ b/lib/onnx/onnx.ml
@@ -505,42 +505,109 @@ let nier_of_onnx_protoc (model : Oprotom.t) =
   | None -> raise (ParseError "No graph in ONNX input file found")
 
 let nier_to_onnx_protoc nier =
-  (* TODO: write a simple ONNX model from a dummy NIER *)
+  (* TODO: get tensor data, and operator params *)
   let vertices = G.vertex_list nier in
+  let open NCFG.Node in
   let protocs =
+    (* match on names of NO_OP nodes and add their outputs to corresponding
+     * C_NODEs inputs *)
     let vertex_to_protoc v =
-      let op_type = NCFG.Node.str_op (NCFG.Node.get_op v) in
-      let name = NCFG.Node.get_name v in
-      let domain = "" in
-      let input, output =
-        (NCFG.Node.get_pred_list v, NCFG.Node.get_succ_list v)
+      let name = get_name v in
+      let input, output = (get_pred_list v, get_succ_list v) in
+      let node, initi =
+        match get_op v with
+        | NO_OP | RW_Linearized_ReLu ->
+          (* ONNX initializers are named ONNX tensors.
+           * If an initializer's name matches an existing
+           * ONNX node input name, the initializer will be assigned as
+           * the input of the node. *)
+          let initi =
+            match get_tensor v with
+            | None -> None
+            | Some t ->
+              Some
+                (Oproto.Onnx.TensorProto.make ~data_type:1
+                   ~dims:(Array.to_list @@ NCFG.Tensor.get_shape t)
+                   ~float_data:(NCFG.Tensor.flatten t) ~name ())
+          in
+          let node = None in
+          (node, initi)
+        | _ ->
+          let op_type = str_op (get_op v) in
+          (* Operator parameters serialised as ONNX integer-list attributes.
+           * The attribute names follow the ONNX operator schemas, since
+           * ONNX consumers look attributes up by name: kernel_shape,
+           * strides, pads and dilations for convolution and pooling, perm
+           * for Transpose.
+           * NOTE(review): the [type] field of AttributeProto is left unset;
+           * the ONNX IR spec expects INTS here -- confirm and set it. *)
+          let attribute =
+            (* [ints_attr name values] is the attribute [name] carrying the
+             * elements of the array [values] in its [ints] field. *)
+            let ints_attr name values =
+              Oproto.Onnx.AttributeProto.make ~name
+                ~ints:(Array.to_list values) ()
+            in
+            match v.operator_parameters with
+            | None | Some (RW_Linearized_ReLu_params _) -> []
+            | Some
+                (Pool_params
+                  (Ksize k, Some (Stride s), Some (Pads p), Some (Dilations d)))
+            | Some
+                (Conv_params
+                  (Ksize k, Some (Stride s), Some (Pads p), Some (Dilations d)))
+              ->
+              [
+                ints_attr "kernel_shape" k;
+                ints_attr "strides" s;
+                ints_attr "pads" p;
+                ints_attr "dilations" d;
+              ]
+            | Some (Transpose_params s) -> [ ints_attr "perm" s ]
+            (* Other parameters, including Conv/Pool parameters lacking a
+             * stride, pads or dilations, yield no attribute. *)
+            | _ -> []
+          in
+          let node =
+            Some
+              (Oproto.Onnx.NodeProto.make ~input ~output ~name ~op_type
+                 ~attribute ~doc_string:"" ())
+          in
+          let initi = None in
+          (node, initi)
       in
-      Oproto.Onnx.NodeProto.make ~input ~output ~name ~op_type ~domain
-        ~attribute:[] ~doc_string:"" ()
+      (node, initi)
     in
-    List.fold ~init:[] ~f:(fun acc v -> vertex_to_protoc v :: acc) vertices
+    List.fold ~init:([], [])
+      ~f:(fun (accn, acci) v ->
+        let node, initi = vertex_to_protoc v in
+        match (node, initi) with
+        | Some n, Some t -> (n :: accn, t :: acci)
+        | Some n, None -> (n :: accn, acci)
+        | None, Some t -> (accn, t :: acci)
+        | None, None -> (accn, acci))
+      vertices
+  in
+  let docstr =
+    "This ONNX model was generated from the Neural Intermediate Representation \
+     of CAISAR"
   in
   let protog =
-    Oproto.Onnx.GraphProto.make ~name:"" ~node:protocs ~initializer':[]
-      ~sparse_initializer:[] ~doc_string:"" ~input:[] ~output:[] ~value_info:[]
-      ~quantization_annotation:[] ()
+    Oproto.Onnx.GraphProto.make ~name:"ONNX CAISAR Export" ~node:(fst protocs)
+      ~initializer':(snd protocs) ~sparse_initializer:[]
+      ~doc_string:"ONNX graph generated from CAISAR NIER" ~input:[] ~output:[]
+      ~value_info:[] ~quantization_annotation:[] ()
   in
   let protom =
     Oproto.Onnx.ModelProto.make ~ir_version:13 ~opset_import:[]
       ~producer_name:"CAISAR" ~producer_version:"1.0" ~domain:""
-      ~model_version:(-1) ~doc_string:"" ~graph:protog ~metadata_props:[]
+      ~model_version:(-1) ~doc_string:docstr ~graph:protog ~metadata_props:[]
       ~training_info:[] ~functions:[] ()
   in
   let writer = Oprotom.to_proto protom in
   Ocaml_protoc_plugin.Writer.contents writer
 
-let write_nier_to_onnx _nier out_channel =
-  let nier = G.init_cfg in
-  let n =
-    Ir.Nier_cfg.Node.create ~id:0 ~name:None ~sh:[||] ~op:Ir.Nier_cfg.Node.NO_OP
-      ~op_p:None ~pred:[] ~succ:[] ~tensor:None
-  in
-  G.add_vertex nier n;
+let write_nier_to_onnx nier out_channel =
   let onnx = nier_to_onnx_protoc nier in
   Stdio.Out_channel.output_string out_channel onnx
 
-- 
GitLab