From 3fe5ddfede9df75d20144f1eb9dcb3a73de5a3b6 Mon Sep 17 00:00:00 2001
From: Julien Girard <julien.girard2@cea.fr>
Date: Thu, 6 Jun 2024 14:27:49 +0200
Subject: [PATCH] [nnet] Fixed wrong shape being inferred during Nnet nir
 generation

---
 lib/nnet/nnet.ml    | 5 +++--
 tests/nir_to_onnx.t | 8 ++++----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/lib/nnet/nnet.ml b/lib/nnet/nnet.ml
index 42a3fdc..9971901 100644
--- a/lib/nnet/nnet.ml
+++ b/lib/nnet/nnet.ml
@@ -101,14 +101,15 @@ let to_nir weights_biases n_inputs layer_sizes =
       let biases_tensor = Nir.Gentensor.of_float_array (Array.of_list biases) in
       let biases_node = Node.create (Node.Constant { data = biases_tensor }) in
       let add_node =
-        Node.create (Add { input1 = matmul_node; input2 = biases_node })
+        Node.create (Add { input1 = biases_node; input2 = matmul_node })
       in
       let relu_node = Node.create (Node.ReLu { input = add_node }) in
       relu_node
   in
   let in_sh = Shape.of_list [ n_inputs ] in
   let g =
-    Nir.Ngraph.create (traverse_wb (aggregated_wb weights_biases layer_sizes) in_sh)
+    Nir.Ngraph.create
+      (traverse_wb (aggregated_wb weights_biases layer_sizes) in_sh)
   in
   g
 
diff --git a/tests/nir_to_onnx.t b/tests/nir_to_onnx.t
index 8963344..6eab31a 100644
--- a/tests/nir_to_onnx.t
+++ b/tests/nir_to_onnx.t
@@ -25,7 +25,7 @@ Test verify
 
 Input name should be 0
 
-  $ python3 bin/inspect_onnx.py
+  $ python3 bin/inspect_onnx.py out/nn_onnx.nir.onnx
   out/nn_onnx.nir.onnx has 1 input nodes
   {'name': '0', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '1'}, {'dimValue': '1'}, {'dimValue': '1'}, {'dimValue': '3'}]}}}}
   1 files checked
@@ -51,7 +51,7 @@ Input name should be 0
 
 Input name should be 0
 
-  $ python3 bin/inspect_onnx.py
-  out/nn_onnx.nir.onnx has 1 input nodes
-  {'name': '0', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '1'}, {'dimValue': '1'}, {'dimValue': '1'}, {'dimValue': '3'}]}}}}
+  $ python3 bin/inspect_onnx.py out_nnet/nn_nnet.nir.onnx
+  out_nnet/nn_nnet.nir.onnx has 1 input nodes
+  {'name': '0', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
   1 files checked
-- 
GitLab