From 31aa8b4d51dec02064464a785977cf9823f26d3e Mon Sep 17 00:00:00 2001
From: Julien Girard <julien.girard2@cea.fr>
Date: Fri, 21 Oct 2022 16:43:42 +0200
Subject: [PATCH] [NIER] Add more ONNX operators to the NIER.

---
 lib/ir/nier_cfg.ml  | 12 ++++++++++++
 lib/ir/nier_cfg.mli |  6 ++++++
 lib/onnx/onnx.ml    |  6 ++++++
 3 files changed, 24 insertions(+)

diff --git a/lib/ir/nier_cfg.ml b/lib/ir/nier_cfg.ml
index a90a168..53e32cf 100644
--- a/lib/ir/nier_cfg.ml
+++ b/lib/ir/nier_cfg.ml
@@ -125,30 +125,42 @@ module Node = struct
 
   type operator =
     | Add
+    | Sub
     | Mul
+    | Div
     | Matmul
+    | Gemm
     | LogSoftmax
     | ReLu
     | Transpose
     | Squeeze
     | MaxPool
     | Conv
+    | Reshape
+    | Flatten
     | Identity
+    | Constant
     | NO_OP
     | RW_Linearized_ReLu
 
   let str_op o =
     match o with
     | Add -> "Add"
+    | Sub -> "Sub"
     | Mul -> "Mul"
+    | Div -> "Div"
     | Matmul -> "Matmul"
+    | Gemm -> "Gemm"
     | LogSoftmax -> "LogSoftmax"
     | ReLu -> "ReLu"
     | Transpose -> "Transpose"
     | Squeeze -> "Squeeze"
     | MaxPool -> "MaxPool"
     | Conv -> "Conv"
+    | Reshape -> "Reshape"
+    | Flatten -> "Flatten"
     | Identity -> "Identity"
+    | Constant -> "Constant"
     | NO_OP -> "NO_OP"
     | RW_Linearized_ReLu -> "RW_Linearized_ReLu"
 
diff --git a/lib/ir/nier_cfg.mli b/lib/ir/nier_cfg.mli
index fd8394a..071e7ad 100644
--- a/lib/ir/nier_cfg.mli
+++ b/lib/ir/nier_cfg.mli
@@ -80,15 +80,21 @@ module Node : sig
 
   type operator =
     | Add
+    | Sub
     | Mul
+    | Div
     | Matmul
+    | Gemm
     | LogSoftmax
     | ReLu
     | Transpose
     | Squeeze
     | MaxPool
     | Conv
+    | Reshape
+    | Flatten
     | Identity
+    | Constant
     | NO_OP
     | RW_Linearized_ReLu
 
diff --git a/lib/onnx/onnx.ml b/lib/onnx/onnx.ml
index 6a0ea0a..6a7659a 100644
--- a/lib/onnx/onnx.ml
+++ b/lib/onnx/onnx.ml
@@ -110,14 +110,20 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) =
       | Some o -> (
         match o with
         | "Add" -> NCFG.Node.Add
+        | "Sub" -> NCFG.Node.Sub
         | "Mul" -> NCFG.Node.Mul
+        | "Div" -> NCFG.Node.Div
         | "Relu" -> NCFG.Node.ReLu
         | "MatMul" -> NCFG.Node.Matmul
+        | "Gemm" -> NCFG.Node.Gemm
         | "LogSoftmax" -> NCFG.Node.LogSoftmax
         | "Transpose" -> NCFG.Node.Transpose
         | "Squeeze" -> NCFG.Node.Squeeze
         | "MaxPool" -> NCFG.Node.MaxPool
+        | "Constant" -> NCFG.Node.Constant
         | "Conv" -> NCFG.Node.Conv
+        | "Reshape" -> NCFG.Node.Reshape
+        | "Flatten" -> NCFG.Node.Flatten
         | "Identity" -> NCFG.Node.Identity
         | _ -> raise (ParseError ("Unsupported ONNX operator " ^ o)))
     in
-- 
GitLab