From 050e1ceb19b3cb27f39c30d141b1442e86ac17fe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr>
Date: Mon, 22 Apr 2024 15:11:35 +0200
Subject: [PATCH] [exps] Add tests for split_nn

---
 examples/mnist/splitted_nn.mlw | 40 ++++++++++++++++++++++++++++++++++
 tests/dune                     | 17 ++++++++++++++-
 tests/splitted_nn.t            |  7 ++++++
 3 files changed, 63 insertions(+), 1 deletion(-)
 create mode 100644 examples/mnist/splitted_nn.mlw
 create mode 100644 tests/splitted_nn.t

diff --git a/examples/mnist/splitted_nn.mlw b/examples/mnist/splitted_nn.mlw
new file mode 100644
index 0000000..6afbfd5
--- /dev/null
+++ b/examples/mnist/splitted_nn.mlw
@@ -0,0 +1,40 @@
+theory MNIST
+
+  use caisar.types.Vector
+  use ieee_float.Float64
+  use caisar.types.Float64WithBounds as Feature
+  use caisar.types.IntWithBounds as Label
+  use caisar.model.Model
+  use caisar.dataset.CSV
+  use caisar.robust.ClassRobustCSV
+  use caisar.robust.ClassRobustVector
+
+  constant pre_model_filename: string
+  constant post_model_filename: string
+  constant dataset_filename: string
+
+  constant label_bounds: Label.bounds =
+    Label.{ lower = 0; upper = 9 }
+  
+  constant feature_bounds: Feature.bounds =
+    Feature.{ lower = (0.0:t); upper = (1.0:t) }
+
+  goal pruned:
+    let pre_nn = read_model pre_model_filename in
+    let post_nn = read_model post_model_filename in
+    let dataset = read_dataset dataset_filename in
+    let eps = (0.0100000000000000002081668171172168513294309377670288085937500000:t) in
+    CSV.forall_ dataset (fun l e ->
+        forall perturbed_e.
+          has_length perturbed_e (length e) ->
+          FeatureVector.valid feature_bounds perturbed_e ->
+          let perturbation = perturbed_e - e in
+          ClassRobustVector.bounded_by_epsilon perturbation eps ->
+          let out1 = pre_nn@@perturbed_e in
+          let out2 = post_nn@@out1 in
+          forall j. Label.valid label_bounds j -> j <> l ->
+          out2[l] .>= out2[j]
+     )
+
+
+end
diff --git a/tests/dune b/tests/dune
index 9a2305d..0020300 100644
--- a/tests/dune
+++ b/tests/dune
@@ -1,6 +1,6 @@
 (cram
  (alias local)
- (applies_to * \ nir_to_onnx acasxu_ci arithmetic check_pruning)
+ (applies_to * \ nir_to_onnx acasxu_ci arithmetic check_pruning splited_nn)
  (deps
   (package caisar)
   setup_env.sh
@@ -42,6 +42,21 @@
   )
  (package caisar))
 
+ (cram
+ (alias local)
+ (applies_to splitted_nn)
+ (deps
+  (package caisar)
+  setup_env.sh
+  (glob_files bin/*)
+  filter_tmpdir.sh
+  ../examples/mnist/splitted_nn.mlw
+  ../examples/mnist/nets/dummy_nn/fnn_pre_s42.onnx
+  ../examples/mnist/nets/dummy_nn/fnn_post_s42.onnx
+  ../examples/mnist/csv/single_image.csv
+  )
+ (package caisar))
+
 
 (cram
  (alias ci)
diff --git a/tests/splitted_nn.t b/tests/splitted_nn.t
new file mode 100644
index 0000000..b03f526
--- /dev/null
+++ b/tests/splitted_nn.t
@@ -0,0 +1,7 @@
+  $ . ./setup_env.sh
+
+  $ ls ../examples/
+  acasxu
+  mnist
+
+#  $ caisar verify --prover PyRAT  --ltag=StackTrace --define pre_model_filename:nets/dummy_nn/fnn_pre_s42.onnx --define post_model_filename:nets/dummy_nn/fnn_post_s42.onnx --define dataset_filename:csv/single_image.csv ../examples/mnist/splitted_nn.mlw -v
-- 
GitLab