From 845ecadbfc550a2aa33fb54c825d2fffbdf98f68 Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Wed, 5 Apr 2023 11:03:23 +0200
Subject: [PATCH] [interpretation] Use equal_shape only for tensors, not
 vectors.

---
 src/interpretation.ml          | 22 +++++++---------------
 stdlib/interpretation.mlw      |  1 -
 tests/interpretation_dataset.t | 13 ++++++-------
 3 files changed, 13 insertions(+), 23 deletions(-)

diff --git a/src/interpretation.ml b/src/interpretation.ml
index 69d488b..238359d 100644
--- a/src/interpretation.ml
+++ b/src/interpretation.ml
@@ -157,7 +157,9 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
     | [ Term { t_node = Tapp (ls, []); _ } ] -> (
       match caisar_op_of_ls engine ls with
       | Dataset (DS_csv csv) -> int (BigInt.of_int (Csv.lines csv))
-      | Data _ | Classifier _ | Tensor _ | Vector _ | Index _ -> assert false)
+      | Vector v -> int (BigInt.of_int (Language.lookup_vector v))
+      | Data (D_csv data) -> int (BigInt.of_int (List.length data))
+      | Classifier _ | Tensor _ | Index _ -> assert false)
     | [ Term { t_node = Tapp (ls, tl); _ } ] -> (
       match caisar_op_of_ls engine ls with
       | Vector v ->
@@ -410,29 +412,19 @@ let bounded_quant engine vs ~cond : CRE.bounded_quant_result option =
         | Data (D_csv d) -> List.length d
         | _ -> assert false
       in
-      let ty, caisar_op, id =
+      let ty =
         match vs.vs_ty with
-        | { ty_node = Tyapp ({ ts_name; _ }, ty :: _); _ } ->
-          let caisar_op, id =
-            if String.equal ts_name.id_string "vector"
-            then
-              let { env; _ } = CRE.user_env engine in
-              (Vector (Language.create_vector env n), "caisar_v")
-            else if String.equal ts_name.id_string "tensor"
-            then (Tensor n, "caisar_t")
-            else assert false
-          in
-          (ty, caisar_op, id)
+        | { ty_node = Tyapp (_, ty :: _); _ } -> ty
         | _ -> assert false
       in
       let new_quant =
         List.init n ~f:(fun _ ->
-          let preid = Ident.id_fresh id in
+          let preid = Ident.id_fresh "caisar_t" in
           Term.create_vsymbol preid ty)
       in
       let args = List.map new_quant ~f:(fun vs -> (Term.t_var vs, ty)) in
       let substitutions =
-        [ term_of_caisar_op ~args engine caisar_op (Some vs.vs_ty) ]
+        [ term_of_caisar_op ~args engine (Tensor n) (Some vs.vs_ty) ]
       in
       Some { new_quant; substitutions }
   | Tapp
diff --git a/stdlib/interpretation.mlw b/stdlib/interpretation.mlw
index a6b53ef..ab3b277 100644
--- a/stdlib/interpretation.mlw
+++ b/stdlib/interpretation.mlw
@@ -31,7 +31,6 @@ theory Vector
   function (-) (v1: vector 'a) (v2: vector 'a) : vector 'a
 
   predicate has_length (v: vector 'a) (i: int)
-  predicate equal_shape (v1: vector 'a) (v2: vector 'b)
   predicate valid_index (v: vector 'a) (i: index) = 0 <= i < length v
 
   scope L
diff --git a/tests/interpretation_dataset.t b/tests/interpretation_dataset.t
index 4336ff2..2fd6c14 100644
--- a/tests/interpretation_dataset.t
+++ b/tests/interpretation_dataset.t
@@ -30,7 +30,7 @@ Test interpret on dataset
   > 
   >   predicate robust_around (c: classifier) (eps: t) (i: image) (l: label_) =
   >     forall perturbed_image: image.
-  >       equal_shape i perturbed_image ->
+  >       has_length perturbed_image (length i) ->
   >       valid_image perturbed_image ->
   >       let perturbation = perturbed_image - i in
   >       bounded_by_epsilon perturbation eps ->
@@ -129,16 +129,15 @@ Test interpret on dataset
         (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [0])
   caisar_op1,
-  (Interpretation.Data
-     (Interpretation.D_csv
-        ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"]))
-  vector, (Interpretation.Vector 5)
-  caisar_op2,
   (Interpretation.Data
      (Interpretation.D_csv
         ["1.0"; "0.0"; "0.019607843"; "0.776470588"; "0.784313725"]))
+  vector, (Interpretation.Vector 5)
   caisar_op,
   (Interpretation.Classifier
      "$TESTCASE_ROOT/TestNetwork.nnet")
+  caisar_op2, (Interpretation.Dataset <csv>)
   caisar_op3,
-  (Interpretation.Dataset <csv>)
+  (Interpretation.Data
+     (Interpretation.D_csv
+        ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"]))
-- 
GitLab