diff --git a/lib/ir/nier_cfg.ml b/lib/ir/nier_cfg.ml index a90a16848e545c6b6ca2276ca35121ee5bae5d12..53e32cf741a17ec169cf6f076bc01f776aa9dd20 100644 --- a/lib/ir/nier_cfg.ml +++ b/lib/ir/nier_cfg.ml @@ -125,30 +125,42 @@ module Node = struct type operator = | Add + | Sub | Mul + | Div | Matmul + | Gemm | LogSoftmax | ReLu | Transpose | Squeeze | MaxPool | Conv + | Reshape + | Flatten | Identity + | Constant | NO_OP | RW_Linearized_ReLu let str_op o = match o with | Add -> "Add" + | Sub -> "Sub" | Mul -> "Mul" + | Div -> "Div" | Matmul -> "Matmul" + | Gemm -> "Gemm" | LogSoftmax -> "LogSoftmax" | ReLu -> "ReLu" | Transpose -> "Transpose" | Squeeze -> "Squeeze" | MaxPool -> "MaxPool" | Conv -> "Conv" + | Reshape -> "Reshape" + | Flatten -> "Flatten" | Identity -> "Identity" + | Constant -> "Constant" | NO_OP -> "NO_OP" | RW_Linearized_ReLu -> "RW_Linearized_ReLu" diff --git a/lib/ir/nier_cfg.mli b/lib/ir/nier_cfg.mli index fd8394a50dd1d05ca39faa9ff99b4689b267bdd5..071e7ad14a90ad94d5956fe755e413f7098bd224 100644 --- a/lib/ir/nier_cfg.mli +++ b/lib/ir/nier_cfg.mli @@ -80,15 +80,21 @@ module Node : sig type operator = | Add + | Sub | Mul + | Div | Matmul + | Gemm | LogSoftmax | ReLu | Transpose | Squeeze | MaxPool | Conv + | Reshape + | Flatten | Identity + | Constant | NO_OP | RW_Linearized_ReLu diff --git a/lib/onnx/onnx.ml b/lib/onnx/onnx.ml index 6a0ea0a963cc78954c74c3141c492c5bee87bf6c..6a7659a8d15edcdc7f7ded2dad16e413d4a657bd 100644 --- a/lib/onnx/onnx.ml +++ b/lib/onnx/onnx.ml @@ -110,14 +110,20 @@ let produce_cfg (g : Oproto.Onnx.GraphProto.t) = | Some o -> ( match o with | "Add" -> NCFG.Node.Add + | "Sub" -> NCFG.Node.Sub | "Mul" -> NCFG.Node.Mul + | "Div" -> NCFG.Node.Div | "Relu" -> NCFG.Node.ReLu | "MatMul" -> NCFG.Node.Matmul + | "Gemm" -> NCFG.Node.Gemm | "LogSoftmax" -> NCFG.Node.LogSoftmax | "Transpose" -> NCFG.Node.Transpose | "Squeeze" -> NCFG.Node.Squeeze | "MaxPool" -> NCFG.Node.MaxPool + | "Constant" -> NCFG.Node.Constant | "Conv" -> NCFG.Node.Conv + | "Reshape" -> NCFG.Node.Reshape + | "Flatten" -> NCFG.Node.Flatten | "Identity" -> NCFG.Node.Identity | _ -> raise (ParseError ("Unsupported ONNX operator " ^ o))) in