From 6bc40b2f78b091f7fa8da992b82b485112eb0798 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr> Date: Fri, 5 Apr 2024 13:52:12 +0200 Subject: [PATCH] [NN_prover] make the onnx file used deterministic when $CAISAR_ONNX_OUTPUT_DIR is specified --- lib/onnx/tests/print.expected | 2 ++ src/transformations/native_nn_prover.ml | 17 ++++++++++++++--- tests/acasxu.t | 10 +++++++++- tests/bin/inspect_onnx.py | 7 ++++++- tests/nier_to_onnx.t | 2 ++ 5 files changed, 33 insertions(+), 5 deletions(-) diff --git a/lib/onnx/tests/print.expected b/lib/onnx/tests/print.expected index 8088b70..79aac1a 100644 --- a/lib/onnx/tests/print.expected +++ b/lib/onnx/tests/print.expected @@ -1,4 +1,6 @@ true ok ok +test.onnx has 1 input nodes +{'name': '0', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '1'}, {'dimValue': '1'}, {'dimValue': '1'}, {'dimValue': '5'}]}}}} 1 files checked diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 3bac141..9dc696e 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -31,6 +31,16 @@ type new_output = { exception UnknownLogicSymbol +let tempfile = + let c = ref (-1) in + fun () -> + match Sys.getenv "CAISAR_ONNX_OUTPUT_DIR" with + | Some dir -> + (* deterministic *) + Int.incr c; + Stdlib.Filename.concat dir (Fmt.str "caisar_%i.onnx" !c) + | None -> Stdlib.Filename.temp_file "caisar" ".onnx" + let create_new_nn env input_vars outputs : string = let module IR = Ir.Nier_simple in let th_f64 = Symbols.Float64.create env in @@ -179,14 +189,15 @@ let create_new_nn env input_vars outputs : string = let outputs = List.rev_map outputs ~f:(fun { index; term } -> (index, convert_term term)) |> List.sort ~compare:(fun (i, _) (j, _) -> Int.compare i j) - |> List.map ~f:snd + |> List.mapi ~f:(fun i (j, n) -> + assert (i = j); + n) in let output = IR.Node.create (Concat { inputs = outputs; axis = 0 }) in assert ( IR.Shape.equal output.shape (IR.Shape.of_array [| List.length outputs |])); let nn = IR.create output in - let temp_dir = Sys.getenv "CAISAR_ONNX_OUTPUT_DIR" in - let filename = Stdlib.Filename.temp_file ?temp_dir "caisar" ".onnx" in + let filename = tempfile () in Onnx.Simple.write nn filename; filename diff --git a/tests/acasxu.t b/tests/acasxu.t index 2612b32..3f6f158 100644 --- a/tests/acasxu.t +++ b/tests/acasxu.t @@ -1184,4 +1184,12 @@ Test verify on acasxu Goal P3: Unknown () $ python3 bin/inspect_onnx.py - 12 files checked + caisar_0.onnx has 1 input nodes + {'name': '38', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}} + caisar_1.onnx has 1 input nodes + {'name': '135', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}} + caisar_2.onnx has 1 input nodes + {'name': '299', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '7'}]}}}} + caisar_3.onnx has 1 input nodes + {'name': '468', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}} + 4 files checked diff --git a/tests/bin/inspect_onnx.py b/tests/bin/inspect_onnx.py index 800f144..16706ba 100644 --- a/tests/bin/inspect_onnx.py +++ b/tests/bin/inspect_onnx.py @@ -1,10 +1,15 @@ import onnx +from google.protobuf.json_format import MessageToDict import os l = os.listdir("out") +l.sort() for file in l: m = onnx.load(os.path.join("out", file)) onnx.checker.check_model(m) + print (f"{file} has {len(m.graph.input)} input nodes") + for _input in m.graph.input: + print(MessageToDict(_input)) -print(len(l),"files checked") \ No newline at end of file +print(len(l),"files checked") diff --git a/tests/nier_to_onnx.t b/tests/nier_to_onnx.t index d754d03..913fd9a 100644 --- a/tests/nier_to_onnx.t +++ b/tests/nier_to_onnx.t @@ -26,4 +26,6 @@ Test verify Data should be 0.135 $ python3 bin/inspect_onnx.py + nn_onnx.nier.onnx has 1 input nodes + {'name': '0', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '1'}, {'dimValue': '1'}, {'dimValue': '1'}, {'dimValue': '3'}]}}}} 1 files checked -- GitLab