From 3fe5ddfede9df75d20144f1eb9dcb3a73de5a3b6 Mon Sep 17 00:00:00 2001 From: Julien Girard <julien.girard2@cea.fr> Date: Thu, 6 Jun 2024 14:27:49 +0200 Subject: [PATCH] [nnet] Fixed wrong shape being inferred during Nnet nir generation --- lib/nnet/nnet.ml | 5 +++-- tests/nir_to_onnx.t | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/lib/nnet/nnet.ml b/lib/nnet/nnet.ml index 42a3fdc..9971901 100644 --- a/lib/nnet/nnet.ml +++ b/lib/nnet/nnet.ml @@ -101,14 +101,15 @@ let to_nir weights_biases n_inputs layer_sizes = let biases_tensor = Nir.Gentensor.of_float_array (Array.of_list biases) in let biases_node = Node.create (Node.Constant { data = biases_tensor }) in let add_node = - Node.create (Add { input1 = matmul_node; input2 = biases_node }) + Node.create (Add { input1 = biases_node; input2 = matmul_node }) in let relu_node = Node.create (Node.ReLu { input = add_node }) in relu_node in let in_sh = Shape.of_list [ n_inputs ] in let g = - Nir.Ngraph.create (traverse_wb (aggregated_wb weights_biases layer_sizes) in_sh) + Nir.Ngraph.create + (traverse_wb (aggregated_wb weights_biases layer_sizes) in_sh) in g diff --git a/tests/nir_to_onnx.t b/tests/nir_to_onnx.t index 8963344..6eab31a 100644 --- a/tests/nir_to_onnx.t +++ b/tests/nir_to_onnx.t @@ -25,7 +25,7 @@ Test verify Input name should be 0 - $ python3 bin/inspect_onnx.py + $ python3 bin/inspect_onnx.py out/nn_onnx.nir.onnx out/nn_onnx.nir.onnx has 1 input nodes {'name': '0', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '1'}, {'dimValue': '1'}, {'dimValue': '1'}, {'dimValue': '3'}]}}}} 1 files checked @@ -51,7 +51,7 @@ Input name should be 0 Input name should be 0 - $ python3 bin/inspect_onnx.py - out/nn_onnx.nir.onnx has 1 input nodes - {'name': '0', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '1'}, {'dimValue': '1'}, {'dimValue': '1'}, {'dimValue': '3'}]}}}} + $ python3 bin/inspect_onnx.py out_nnet/nn_nnet.nir.onnx + out_nnet/nn_nnet.nir.onnx has 1 input nodes + {'name': '0', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}} 1 files checked -- GitLab