From c3542bcba67e4456507dcee357d34fe0f03ac031 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr>
Date: Mon, 22 Apr 2024 14:40:50 +0200
Subject: [PATCH] [exps] Add tests for check_pruning

---
 examples/mnist/check_pruning.mlw         | 40 ++++++++++++++++++++++++
 src/interpretation/interpreter_theory.ml |  6 +++-
 tests/check_pruning.t                    |  7 +++++
 tests/dune                               | 18 ++++++++++-
 4 files changed, 69 insertions(+), 2 deletions(-)
 create mode 100644 examples/mnist/check_pruning.mlw
 create mode 100644 tests/check_pruning.t

diff --git a/examples/mnist/check_pruning.mlw b/examples/mnist/check_pruning.mlw
new file mode 100644
index 0000000..37f3fc9
--- /dev/null
+++ b/examples/mnist/check_pruning.mlw
@@ -0,0 +1,40 @@
+theory MNIST
+
+  use caisar.types.Vector
+  use ieee_float.Float64
+  use caisar.types.Float64WithBounds as Feature
+  use caisar.types.IntWithBounds as Label
+  use caisar.model.Model
+  use caisar.dataset.CSV
+  use caisar.robust.ClassRobustCSV
+  use caisar.robust.ClassRobustVector
+
+  constant model_filename: string
+  constant pruned_model_filename: string
+  constant dataset_filename: string
+
+  constant label_bounds: Label.bounds =
+    Label.{ lower = 0; upper = 9 }
+  
+  constant feature_bounds: Feature.bounds =
+    Feature.{ lower = (0.0:t); upper = (1.0:t) }
+
+  goal pruned:
+    let nn = read_model model_filename in
+    let pruned_nn = read_model pruned_model_filename in
+    let dataset = read_dataset dataset_filename in
+    let eps = (0.0100000000000000002081668171172168513294309377670288085937500000:t) in
+    let delta = (0.0100000000000000002081668171172168513294309377670288085937500000:t) in
+    CSV.forall_ dataset (fun _ e ->
+        forall perturbed_e.
+          has_length perturbed_e (length e) ->
+          FeatureVector.valid feature_bounds perturbed_e ->
+          let perturbation = perturbed_e - e in
+          ClassRobustVector.bounded_by_epsilon perturbation eps ->
+          let out = nn@@perturbed_e in
+          let pruned_out = pruned_nn@@perturbed_e in
+          .- delta .<= out[0] .- pruned_out[0] .<= delta
+     )
+
+
+end
diff --git a/src/interpretation/interpreter_theory.ml b/src/interpretation/interpreter_theory.ml
index 6f079ff..29f27b7 100644
--- a/src/interpretation/interpreter_theory.ml
+++ b/src/interpretation/interpreter_theory.ml
@@ -128,7 +128,11 @@ module Vector = struct
             interpreter_op)
       | None -> IRE.reconstruct_term ())
     | [ Term _t1; Term _t2 ] -> IRE.reconstruct_term ()
-    | _ -> fail_on_unexpected_argument ls
+    | _ -> Logging.code_error ~src:Logging.src_interpret_goal (fun m ->
+      m "Unexpected argument(s) for '%a': %a" Why3.Pretty.print_ls ls
+        (Fmt.list ~sep:Fmt.comma IRE.pp_value) vl     
+      )
+  
 
   let length : _ IRE.builtin =
    fun engine ls vl _ty ->
diff --git a/tests/check_pruning.t b/tests/check_pruning.t
new file mode 100644
index 0000000..adb8e56
--- /dev/null
+++ b/tests/check_pruning.t
@@ -0,0 +1,7 @@
+  $ . ./setup_env.sh
+
+  $ ls ../examples/
+  acasxu
+  mnist
+
+#  $ caisar verify --prover PyRAT  --ltag=StackTrace --define model_filename:nets/dummy_nn/FNN_s42.onnx --define pruned_model_filename:nets/dummy_nn/pruned_FNN_s42.onnx --define dataset_filename:csv/single_image.csv ../examples/mnist/check_pruning.mlw -v
diff --git a/tests/dune b/tests/dune
index f30bdf7..9a2305d 100644
--- a/tests/dune
+++ b/tests/dune
@@ -1,6 +1,6 @@
 (cram
  (alias local)
- (applies_to * \ nir_to_onnx acasxu_ci)
+ (applies_to * \ nir_to_onnx acasxu_ci arithmetic check_pruning)
  (deps
   (package caisar)
   setup_env.sh
@@ -27,6 +27,22 @@
   )
  (package caisar))
 
+ (cram
+ (alias local)
+ (applies_to check_pruning)
+ (deps
+  (package caisar)
+  setup_env.sh
+  (glob_files bin/*)
+  filter_tmpdir.sh
+  ../examples/mnist/check_pruning.mlw
+  ../examples/mnist/nets/dummy_nn/FNN_s42.onnx
+  ../examples/mnist/nets/dummy_nn/pruned_FNN_s42.onnx
+  ../examples/mnist/csv/single_image.csv
+  )
+ (package caisar))
+
+
 (cram
  (alias ci)
  (deps
-- 
GitLab