From 050e1ceb19b3cb27f39c30d141b1442e86ac17fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr> Date: Mon, 22 Apr 2024 15:11:35 +0200 Subject: [PATCH] [exps] Add tests for split_nn --- examples/mnist/splitted_nn.mlw | 40 ++++++++++++++++++++++++++++++++++ tests/dune | 17 ++++++++++++++- tests/splitted_nn.t | 7 ++++++ 3 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 examples/mnist/splitted_nn.mlw create mode 100644 tests/splitted_nn.t diff --git a/examples/mnist/splitted_nn.mlw b/examples/mnist/splitted_nn.mlw new file mode 100644 index 0000000..6afbfd5 --- /dev/null +++ b/examples/mnist/splitted_nn.mlw @@ -0,0 +1,40 @@ +theory MNIST + + use caisar.types.Vector + use ieee_float.Float64 + use caisar.types.Float64WithBounds as Feature + use caisar.types.IntWithBounds as Label + use caisar.model.Model + use caisar.dataset.CSV + use caisar.robust.ClassRobustCSV + use caisar.robust.ClassRobustVector + + constant pre_model_filename: string + constant post_model_filename: string + constant dataset_filename: string + + constant label_bounds: Label.bounds = + Label.{ lower = 0; upper = 9 } + + constant feature_bounds: Feature.bounds = + Feature.{ lower = (0.0:t); upper = (1.0:t) } + + goal pruned: + let pre_nn = read_model pre_model_filename in + let post_nn = read_model post_model_filename in + let dataset = read_dataset dataset_filename in + let eps = (0.0100000000000000002081668171172168513294309377670288085937500000:t) in + CSV.forall_ dataset (fun l e -> + forall perturbed_e. + has_length perturbed_e (length e) -> + FeatureVector.valid feature_bounds perturbed_e -> + let perturbation = perturbed_e - e in + ClassRobustVector.bounded_by_epsilon perturbation eps -> + let out1 = pre_nn@@perturbed_e in + let out2 = post_nn@@out1 in + forall j. Label.valid label_bounds j -> j <> l -> + out2[l] .>= out2[j] + ) + + +end diff --git a/tests/dune b/tests/dune index 9a2305d..0020300 100644 --- a/tests/dune +++ b/tests/dune @@ -1,6 +1,6 @@ (cram (alias local) - (applies_to * \ nir_to_onnx acasxu_ci arithmetic check_pruning) + (applies_to * \ nir_to_onnx acasxu_ci arithmetic check_pruning splited_nn) (deps (package caisar) setup_env.sh @@ -42,6 +42,21 @@ ) (package caisar)) + (cram + (alias local) + (applies_to splitted_nn) + (deps + (package caisar) + setup_env.sh + (glob_files bin/*) + filter_tmpdir.sh + ../examples/mnist/splitted_nn.mlw + ../examples/mnist/nets/dummy_nn/fnn_pre_s42.onnx + ../examples/mnist/nets/dummy_nn/fnn_post_s42.onnx + ../examples/mnist/csv/single_image.csv + ) + (package caisar)) + (cram (alias ci) diff --git a/tests/splitted_nn.t b/tests/splitted_nn.t new file mode 100644 index 0000000..b03f526 --- /dev/null +++ b/tests/splitted_nn.t @@ -0,0 +1,7 @@ + $ . ./setup_env.sh + + $ ls ../examples/ + acasxu + mnist + +# $ caisar verify --prover PyRAT --ltag=StackTrace --define pre_model_filename:nets/dummy_nn/fnn_pre_s42.onnx --define post_model_filename:nets/dummy_nn/fnn_post_s42.onnx --define dataset_filename:csv/single_image.csv ../examples/mnist/splitted_nn.mlw -v -- GitLab