diff --git a/examples/mnist/check_pruning.mlw b/examples/mnist/check_pruning.mlw new file mode 100644 index 0000000000000000000000000000000000000000..37f3fc9ceb8ff9858de9b1ef3bb8c69f5b60162a --- /dev/null +++ b/examples/mnist/check_pruning.mlw @@ -0,0 +1,40 @@ +theory MNIST + + use caisar.types.Vector + use ieee_float.Float64 + use caisar.types.Float64WithBounds as Feature + use caisar.types.IntWithBounds as Label + use caisar.model.Model + use caisar.dataset.CSV + use caisar.robust.ClassRobustCSV + use caisar.robust.ClassRobustVector + + constant model_filename: string + constant pruned_model_filename: string + constant dataset_filename: string + + constant label_bounds: Label.bounds = + Label.{ lower = 0; upper = 9 } + + constant feature_bounds: Feature.bounds = + Feature.{ lower = (0.0:t); upper = (1.0:t) } + + goal pruned: + let nn = read_model model_filename in + let pruned_nn = read_model pruned_model_filename in + let dataset = read_dataset dataset_filename in + let eps = (0.0100000000000000002081668171172168513294309377670288085937500000:t) in + let delta = (0.0100000000000000002081668171172168513294309377670288085937500000:t) in + CSV.forall_ dataset (fun _ e -> + forall perturbed_e. + has_length perturbed_e (length e) -> + FeatureVector.valid feature_bounds perturbed_e -> + let perturbation = perturbed_e - e in + ClassRobustVector.bounded_by_epsilon perturbation eps -> + let out = nn@@perturbed_e in + let pruned_out = pruned_nn@@perturbed_e in + .- delta .<= out[0] .- pruned_out[0] .<= delta + ) + + +end diff --git a/src/interpretation/interpreter_theory.ml b/src/interpretation/interpreter_theory.ml index 6f079ff1c9579000613adab0951dad4d0e780e98..29f27b7bde4b21810e41fb9b40def245457083e9 100644 --- a/src/interpretation/interpreter_theory.ml +++ b/src/interpretation/interpreter_theory.ml @@ -128,7 +128,11 @@ module Vector = struct interpreter_op) | None -> IRE.reconstruct_term ()) | [ Term _t1; Term _t2 ] -> IRE.reconstruct_term () - | _ -> fail_on_unexpected_argument ls + | _ -> Logging.code_error ~src:Logging.src_interpret_goal (fun m -> + m "Unexpected argument(s) for '%a': %a" Why3.Pretty.print_ls ls + (Fmt.list ~sep:Fmt.comma IRE.pp_value) vl + ) + let length : _ IRE.builtin = fun engine ls vl _ty -> diff --git a/tests/check_pruning.t b/tests/check_pruning.t new file mode 100644 index 0000000000000000000000000000000000000000..adb8e56e50b424a7d5bc69a8444be7faf8545180 --- /dev/null +++ b/tests/check_pruning.t @@ -0,0 +1,7 @@ + $ . ./setup_env.sh + + $ ls ../examples/ + acasxu + mnist + +# $ caisar verify --prover PyRAT --ltag=StackTrace --define model_filename:nets/dummy_nn/FNN_s42.onnx --define pruned_model_filename:nets/dummy_nn/pruned_FNN_s42.onnx --define dataset_filename:csv/single_image.csv ../examples/mnist/check_pruning.mlw -v diff --git a/tests/dune b/tests/dune index f30bdf7039193dd194d1ac051a93f33eff5d591d..9a2305d9a90c5140956a00dc652db2b3dd557730 100644 --- a/tests/dune +++ b/tests/dune @@ -1,6 +1,6 @@ (cram (alias local) - (applies_to * \ nir_to_onnx acasxu_ci) + (applies_to * \ nir_to_onnx acasxu_ci arithmetic check_pruning) (deps (package caisar) setup_env.sh @@ -27,6 +27,22 @@ ) (package caisar)) + (cram + (alias local) + (applies_to check_pruning) + (deps + (package caisar) + setup_env.sh + (glob_files bin/*) + filter_tmpdir.sh + ../examples/mnist/check_pruning.mlw + ../examples/mnist/nets/dummy_nn/FNN_s42.onnx + ../examples/mnist/nets/dummy_nn/pruned_FNN_s42.onnx + ../examples/mnist/csv/single_image.csv + ) + (package caisar)) + + (cram (alias ci) (deps