From 31aa8b4d51dec02064464a785977cf9823f26d3e Mon Sep 17 00:00:00 2001 From: Julien Girard <julien.girard2@cea.fr> Date: Fri, 21 Oct 2022 16:43:42 +0200 Subject: [PATCH] [NIER] Add more ONNX operators to the NIER. --- lib/ir/nier_cfg.ml | 12 ++++++++++++ lib/ir/nier_cfg.mli | 6 ++++++ lib/onnx/onnx.ml | 6 ++++++ 3 files changed, 24 insertions(+) diff --git a/lib/ir/nier_cfg.ml b/lib/ir/nier_cfg.ml index a90a168..53e32cf 100644 --- a/lib/ir/nier_cfg.ml +++ b/lib/ir/nier_cfg.ml @@ -125,30 +125,42 @@ module Node = struct type operator = | Add + | Sub | Mul + | Div | Matmul + | Gemm | LogSoftmax | ReLu | Transpose | Squeeze | MaxPool | Conv + | Reshape + | Flatten | Identity + | Constant | NO_OP | RW_Linearized_ReLu let str_op o = match o with | Add -> "Add" + | Sub -> "Sub" | Mul -> "Mul" + | Div -> "Div" | Matmul -> "Matmul" + | Gemm -> "Gemm" | LogSoftmax -> "LogSoftmax" | ReLu -> "ReLu" | Transpose -> "Transpose" | Squeeze -> "Squeeze" | MaxPool -> "MaxPool" | Conv -> "Conv" + | Reshape -> "Reshape" + | Flatten -> "Flatten" | Identity -> "Identity" + | Constant -> "Constant" | NO_OP -> "NO_OP" | RW_Linearized_ReLu -> "RW_Linearized_ReLu" diff --git a/lib/ir/nier_cfg.mli b/lib/ir/nier_cfg.mli index fd8394a..071e7ad 100644 --- a/lib/ir/nier_cfg.mli +++ b/lib/ir/nier_cfg.mli @@ -80,15 +80,21 @@ module Node : sig type operator = | Add + | Sub | Mul + | Div | Matmul + | Gemm | LogSoftmax | ReLu | Transpose | Squeeze | MaxPool | Conv + | Reshape + | Flatten | Identity + | Constant | NO_OP | RW_Linearized_ReLu diff --git a/lib/onnx/onnx.ml b/lib/onnx/onnx.ml index 6a0ea0a..6a7659a 100644 --- a/lib/onnx/onnx.ml +++ b/lib/onnx/onnx.ml @@ -110,14 +110,20 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) = | Some o -> ( match o with | "Add" -> NCFG.Node.Add + | "Sub" -> NCFG.Node.Sub | "Mul" -> NCFG.Node.Mul + | "Div" -> NCFG.Node.Div | "Relu" -> NCFG.Node.ReLu | "MatMul" -> NCFG.Node.Matmul + | "Gemm" -> NCFG.Node.Gemm | "LogSoftmax" -> NCFG.Node.LogSoftmax | "Transpose" -> NCFG.Node.Transpose | "Squeeze" -> NCFG.Node.Squeeze | "MaxPool" -> NCFG.Node.MaxPool + | "Constant" -> NCFG.Node.Constant | "Conv" -> NCFG.Node.Conv + | "Reshape" -> NCFG.Node.Reshape + | "Flatten" -> NCFG.Node.Flatten | "Identity" -> NCFG.Node.Identity | _ -> raise (ParseError ("Unsupported ONNX operator " ^ o))) in -- GitLab