From eb94f8c81c273a6e7c81c45f1afe9dcadda4cbe9 Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Fri, 17 Mar 2023 11:36:00 +0100
Subject: [PATCH] [interpretation] wip2.

---
 src/interpretation.ml          | 117 ++++++++++++++++++++++-----------
 tests/interpretation_dataset.t |   4 +-
 2 files changed, 82 insertions(+), 39 deletions(-)

diff --git a/src/interpretation.ml b/src/interpretation.ml
index b5e51e6..fbd18c2 100644
--- a/src/interpretation.ml
+++ b/src/interpretation.ml
@@ -41,6 +41,14 @@ type caisar_op =
       [@printer
         fun fmt (t1, t2) ->
           Fmt.pf fmt "%a[%a]" Pretty.print_term t1 Pretty.print_term t2]
+  | EqualShape of Term.term * Term.term
+      [@printer
+        fun fmt (t1, t2) ->
+          Fmt.pf fmt "EqShape %a %a" Pretty.print_term t1 Pretty.print_term t2]
+  | ValidIndex of Term.term * Term.term
+      [@printer
+        fun fmt (t1, t2) ->
+          Fmt.pf fmt "ValidIdx %a %a" Pretty.print_term t1 Pretty.print_term t2]
 [@@deriving show]
 
 type caisar_env = {
@@ -90,21 +98,14 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
       Fmt.(option ~none:nop Pretty.print_ty)
       ty;
     match vl with
-    | [ Term t1; Term ({ t_node = Tvar _; _ } as t2) ] ->
-      Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
-      Term (term_of_caisar_op engine (VGet (t1, t2)) ty)
     | [
-     Term ({ t_node = Tapp (ls, _); _ } as t1);
+     Term ({ t_node = Tapp (lsapp, _); _ } as t1);
      Term ({ t_node = Tconst (ConstInt i); _ } as t2);
-    ] ->
+    ] -> (
       Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
-      let t_features, t_label =
-        let row =
-          match caisar_op_of_ls engine ls with
-          | Dataset (CSV csv) -> List.nth_exn csv (Number.to_small_integer i)
-          | Data _ | Classifier _ | ClassifierApp (_, _) | VGet (_, _) ->
-            assert false
-        in
+      match caisar_op_of_ls engine lsapp with
+      | Dataset (CSV csv) ->
+        let row = List.nth_exn csv (Number.to_small_integer i) in
         let label, features =
           match row with
           | [] | [ _ ] -> assert false
@@ -115,10 +116,24 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
           | Some { ty_node = Tyapp (_, [ a; _ ]); _ } -> Some a
           | _ -> assert false
         in
-        ( term_of_caisar_op engine (Data features) ty_features,
-          Term.t_int_const (BigInt.of_int (Int.of_string label)) )
-      in
-      Term (Term.t_tuple [ t_features; t_label ])
+        let t_features, t_label =
+          ( term_of_caisar_op engine (Data features) ty_features,
+            Term.t_int_const (BigInt.of_int (Int.of_string label)) )
+        in
+        Term (Term.t_tuple [ t_features; t_label ])
+      | ClassifierApp (_, _) -> Term (Term.t_app_infer ls [ t1; t2 ])
+      | Data _ | Classifier _ | VGet (_, _) | EqualShape _ | ValidIndex _ ->
+        assert false)
+    | [
+     Term ({ t_node = Tapp (lsapp, _); _ } as t1);
+     Term ({ t_node = Tvar _; _ } as t2);
+    ] -> (
+      Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
+      match caisar_op_of_ls engine lsapp with
+      | Dataset _ -> assert false
+      | ClassifierApp (_, _) -> Term (Term.t_app_infer ls [ t1; t2 ])
+      | Data _ | Classifier _ | VGet (_, _) | EqualShape _ | ValidIndex _ ->
+        assert false)
     | _ -> invalid_arg (error_message ls)
   in
   let length : _ CRE.builtin =
@@ -130,25 +145,48 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
     | [ Term { t_node = Tapp (ls, []); _ } ] -> (
       match caisar_op_of_ls engine ls with
       | Dataset (CSV csv) -> Int (BigInt.of_int (Csv.lines csv))
-      | Data _ | Classifier _ | ClassifierApp _ | VGet _ -> assert false)
+      | Data _ | Classifier _ | ClassifierApp _ | VGet _ | EqualShape _
+      | ValidIndex _ ->
+        assert false)
+    | [ Term t1; Term ({ t_node = Tvar _; _ } as t2) ] ->
+      Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
+      Term (term_of_caisar_op engine (VGet (t1, t2)) ty)
     | _ -> invalid_arg (error_message ls)
   in
 
   (* Tensor *)
-  (* let valid_index : _ CRE.builtin = *)
-  (*  fun _engine ls _vl ty -> *)
-  (*   Fmt.pr "--@.valid_index: ls:%a , ty:%a@." Pretty.print_ls ls *)
-  (*     Fmt.(option ~none:nop Pretty.print_ty) *)
-  (*     ty; *)
-  (*   Term Term.t_true *)
-  (* in *)
-  (* let equal_shape : _ CRE.builtin = *)
-  (*  fun _engine ls _vl ty -> *)
-  (*   Fmt.pr "--@.equal_shape: ls:%a , ty:%a@." Pretty.print_ls ls *)
-  (*     Fmt.(option ~none:nop Pretty.print_ty) *)
-  (*     ty; *)
-  (*   Term Term.t_true *)
-  (* in *)
+  let _valid_index : _ CRE.builtin =
+   fun engine ls vl ty ->
+    Fmt.pr "--@.valid_index: ls:%a , ty:%a@." Pretty.print_ls ls
+      Fmt.(option ~none:nop Pretty.print_ty)
+      ty;
+    match vl with
+    | [
+        Term ({ t_node = Tvar _; _ } as t1); Term ({ t_node = Tvar _; _ } as t2);
+      ]
+    | [ Term t1; Term ({ t_node = Tvar _; _ } as t2) ] ->
+      Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
+      Term (term_of_caisar_op engine (ValidIndex (t1, t2)) ty)
+      (* Term Term.t_true *)
+    | _ -> invalid_arg (error_message ls)
+  in
+  let _equal_shape : _ CRE.builtin =
+   fun engine ls vl ty ->
+    Fmt.pr "--@.equal_shape: ls:%a , ty:%a@." Pretty.print_ls ls
+      Fmt.(option ~none:nop Pretty.print_ty)
+      ty;
+    match vl with
+    | [
+     Term ({ t_node = Tvar _; _ } as t1); Term ({ t_node = Tvar _; _ } as t2);
+    ] ->
+      Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
+      Term (Term.t_app_infer ls [ t1; t2 ])
+    | [ Term t1; Term t2 ] ->
+      Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
+      Term (term_of_caisar_op engine (EqualShape (t1, t2)) ty)
+      (* Term Term.t_true *)
+    | _ -> invalid_arg (error_message ls)
+  in
 
   (* Classifier *)
   let read_classifier : _ CRE.builtin =
@@ -176,6 +214,12 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
       ty;
     match vl with
     | [ Term ({ t_node = Tvar _; _ } as t1); Term t2 ] ->
+      Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
+      Term (Term.t_app_infer ls [ t1; t2 ])
+    | [ Term ({ t_node = Tapp (_lsapp, _); _ } as t1); Term t2 ] ->
+      Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
+      Term (Term.t_app_infer ls [ t1; t2 ])
+    | [ Term t1; Term t2 ] ->
       Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
       Term (term_of_caisar_op engine (ClassifierApp (t1, t2)) ty)
     | _ -> invalid_arg (error_message ls)
@@ -206,12 +250,11 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
       "Vector",
       [],
       [ (Ident.op_get "" (* ([]) *), None, vget); ("length", None, length) ] );
-    (* ( [ "interpretation" ], *)
-    (*   "Tensor", *)
-    (*   [], *)
-    (* [ ("valid_index", None, valid_index); ("equal_shape", None, equal_shape)
-       ] *)
-    (* ); *)
+    ( [ "interpretation" ],
+      "Tensor",
+      [],
+      [ (* ("valid_index", None, valid_index); *)
+        (* ("equal_shape", None, equal_shape); *) ] );
     ( [ "interpretation" ],
       "Classifier",
       [],
diff --git a/tests/interpretation_dataset.t b/tests/interpretation_dataset.t
index a933f78..ca4ded6 100644
--- a/tests/interpretation_dataset.t
+++ b/tests/interpretation_dataset.t
@@ -20,11 +20,11 @@ Test interpret on dataset
   >   predicate valid_image (i: image) =
   >     forall v: index. valid_index i v -> (0.0: t) .<= i#v .<= (1.0: t)
   > 
-  >   predicate valid_label (l: label_) = 0 <= l <= 1
+  >   predicate valid_label (l: label_) = 0 <= l <= 2
   > 
   >   predicate advises (c: classifier) (i: image) (l: label_) =
   >     valid_label l ->
-  >     forall j: int. valid_label j /\  j <> l -> (c@@i)[l] .> (c@@i)[j]
+  >     forall j: int. valid_label j ->  j <> l -> (c@@i)[l] .> (c@@i)[j]
   > 
   >   predicate bounded_by_epsilon (i: image) (eps: t) =
   >     forall v: index. valid_index i v -> .- eps .<= i#v .<= eps
-- 
GitLab