From c3542bcba67e4456507dcee357d34fe0f03ac031 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr> Date: Mon, 22 Apr 2024 14:40:50 +0200 Subject: [PATCH] [exps] Add tests for check_pruning --- examples/mnist/check_pruning.mlw | 40 ++++++++++++++++++++++++ src/interpretation/interpreter_theory.ml | 6 +++- tests/check_pruning.t | 7 +++++ tests/dune | 18 ++++++++++- 4 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 examples/mnist/check_pruning.mlw create mode 100644 tests/check_pruning.t diff --git a/examples/mnist/check_pruning.mlw b/examples/mnist/check_pruning.mlw new file mode 100644 index 0000000..37f3fc9 --- /dev/null +++ b/examples/mnist/check_pruning.mlw @@ -0,0 +1,40 @@ +theory MNIST + + use caisar.types.Vector + use ieee_float.Float64 + use caisar.types.Float64WithBounds as Feature + use caisar.types.IntWithBounds as Label + use caisar.model.Model + use caisar.dataset.CSV + use caisar.robust.ClassRobustCSV + use caisar.robust.ClassRobustVector + + constant model_filename: string + constant pruned_model_filename: string + constant dataset_filename: string + + constant label_bounds: Label.bounds = + Label.{ lower = 0; upper = 9 } + + constant feature_bounds: Feature.bounds = + Feature.{ lower = (0.0:t); upper = (1.0:t) } + + goal pruned: + let nn = read_model model_filename in + let pruned_nn = read_model pruned_model_filename in + let dataset = read_dataset dataset_filename in + let eps = (0.0100000000000000002081668171172168513294309377670288085937500000:t) in + let delta = (0.0100000000000000002081668171172168513294309377670288085937500000:t) in + CSV.forall_ dataset (fun _ e -> + forall perturbed_e. + has_length perturbed_e (length e) -> + FeatureVector.valid feature_bounds perturbed_e -> + let perturbation = perturbed_e - e in + ClassRobustVector.bounded_by_epsilon perturbation eps -> + let out = nn@@perturbed_e in + let pruned_out = pruned_nn@@perturbed_e in + .- delta .<= out[0] .- pruned_out[0] .<= delta + ) + + +end diff --git a/src/interpretation/interpreter_theory.ml b/src/interpretation/interpreter_theory.ml index 6f079ff..29f27b7 100644 --- a/src/interpretation/interpreter_theory.ml +++ b/src/interpretation/interpreter_theory.ml @@ -128,7 +128,11 @@ module Vector = struct interpreter_op) | None -> IRE.reconstruct_term ()) | [ Term _t1; Term _t2 ] -> IRE.reconstruct_term () - | _ -> fail_on_unexpected_argument ls + | _ -> Logging.code_error ~src:Logging.src_interpret_goal (fun m -> + m "Unexpected argument(s) for '%a': %a" Why3.Pretty.print_ls ls + (Fmt.list ~sep:Fmt.comma IRE.pp_value) vl + ) + let length : _ IRE.builtin = fun engine ls vl _ty -> diff --git a/tests/check_pruning.t b/tests/check_pruning.t new file mode 100644 index 0000000..adb8e56 --- /dev/null +++ b/tests/check_pruning.t @@ -0,0 +1,7 @@ + $ . ./setup_env.sh + + $ ls ../examples/ + acasxu + mnist + +# $ caisar verify --prover PyRAT --ltag=StackTrace --define model_filename:nets/dummy_nn/FNN_s42.onnx --define pruned_model_filename:nets/dummy_nn/pruned_FNN_s42.onnx --define dataset_filename:csv/single_image.csv ../examples/mnist/check_pruning.mlw -v diff --git a/tests/dune b/tests/dune index f30bdf7..9a2305d 100644 --- a/tests/dune +++ b/tests/dune @@ -1,6 +1,6 @@ (cram (alias local) - (applies_to * \ nir_to_onnx acasxu_ci) + (applies_to * \ nir_to_onnx acasxu_ci arithmetic check_pruning) (deps (package caisar) setup_env.sh @@ -27,6 +27,22 @@ ) (package caisar)) + (cram + (alias local) + (applies_to check_pruning) + (deps + (package caisar) + setup_env.sh + (glob_files bin/*) + filter_tmpdir.sh + ../examples/mnist/check_pruning.mlw + ../examples/mnist/nets/dummy_nn/FNN_s42.onnx + ../examples/mnist/nets/dummy_nn/pruned_FNN_s42.onnx + ../examples/mnist/csv/single_image.csv + ) + (package caisar)) + + (cram (alias ci) (deps -- GitLab