From 6fd94e489ba1e0645d128848a830ab44afdf6c91 Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Wed, 5 Apr 2023 10:49:55 +0200
Subject: [PATCH] [interpretation] Add notion of vector in language and use it
 in interpretation.

---
 src/interpretation.ml          | 54 ++++++++++++++++++++++++----------
 src/language.ml                | 25 ++++++++++++++++
 src/language.mli               |  5 ++++
 tests/interpretation_acasxu.t  | 41 +++++++++++++-------------
 tests/interpretation_dataset.t | 32 ++++++++------------
 5 files changed, 100 insertions(+), 57 deletions(-)

diff --git a/src/interpretation.ml b/src/interpretation.ml
index 21e6736..69d488b 100644
--- a/src/interpretation.ml
+++ b/src/interpretation.ml
@@ -31,12 +31,17 @@ type classifier = string [@@deriving show]
 type data = D_csv of string list [@@deriving show]
 type index = I_csv of int [@@deriving show]
 
+type vector = (* a vector symbol; [show] prints only its length *)
+  (Language.vector
+  [@printer fun fmt v -> Fmt.pf fmt "%d" (Language.lookup_vector v)])
+[@@deriving show]
+
 type caisar_op =
   | Classifier of classifier
   | Dataset of dataset
   | Data of data
   | Index of index
-  | Vector of int
+  | Vector of vector
   | Tensor of int
 [@@deriving show]
 
@@ -53,7 +58,9 @@ let ls_of_caisar_op engine op ty_args ty =
   (* Option.iter ty ~f:(Fmt.pr "ty: %a@." Pretty.print_ty); *)
   Hashtbl.find_or_add caisar_env.ls_of_caisar_op op ~default:(fun () ->
     let id = Ident.id_fresh "caisar_op" in
-    let ls = Term.create_lsymbol id ty_args ty in
+    let ls =
+      match op with Vector v -> v | _ -> Term.create_lsymbol id ty_args ty
+    in
     (* Fmt.pr "ls: %a@." Pretty.print_ls ls; *)
     Hashtbl.Poly.add_exn caisar_env.ls_of_caisar_op ~key:op ~data:ls;
     Term.Hls.add caisar_env.caisar_op_of_ls ls op;
@@ -131,7 +138,8 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
             Term.t_int_const (BigInt.of_int (Int.of_string label)) )
         in
         term (Term.t_tuple [ t_features; t_label ])
-      | Vector n ->
+      | Vector v ->
+        let n = Language.lookup_vector v in
         assert (List.length tl1 = n && i <= n);
         term (List.nth_exn tl1 i)
       | Data _ | Classifier _ | Tensor _ | Index _ -> assert false)
@@ -152,7 +160,8 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
       | Data _ | Classifier _ | Tensor _ | Vector _ | Index _ -> assert false)
     | [ Term { t_node = Tapp (ls, tl); _ } ] -> (
       match caisar_op_of_ls engine ls with
-      | Vector n ->
+      | Vector v ->
+        let n = Language.lookup_vector v in
         assert (List.length tl = n);
         int (BigInt.of_int n)
       | Dataset _ | Data _ | Classifier _ | Tensor _ | Index _ -> assert false)
@@ -173,7 +182,8 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
     ] -> (
       (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
       match (caisar_op_of_ls engine ls1, caisar_op_of_ls engine ls2) with
-      | Vector n, Data (D_csv data) ->
+      | Vector v, Data (D_csv data) ->
+        let n = Language.lookup_vector v in
         assert (n = List.length data);
         let ty_cst =
           match ty with
@@ -185,17 +195,18 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
             let cst = const_real_of_float (Float.of_string d) in
             Term.t_const cst ty_cst)
         in
-        let minus =
-          (* TODO: generalize wrt the type of constants [csts]. *)
-          let { env; _ } = CRE.user_env engine in
-          let th = Env.read_theory env [ "ieee_float" ] "Float64" in
-          Theory.(ns_find_ls th.th_export [ Ident.op_infix ".-" ])
-        in
+        let { env; _ } = CRE.user_env engine in
         let args =
+          let minus =
+            (* TODO: generalize wrt the type of constants [csts]. *)
+            let th = Env.read_theory env [ "ieee_float" ] "Float64" in
+            Theory.(ns_find_ls th.th_export [ Ident.op_infix ".-" ])
+          in
           List.map2_exn tl1 csts ~f:(fun tl c ->
             (Term.t_app_infer minus [ tl; c ], ty_cst))
         in
-        term (term_of_caisar_op ~args engine (Vector n) ty)
+        let caisar_op = Vector (Language.create_vector env n) in
+        term (term_of_caisar_op ~args engine caisar_op ty)
       | _ -> assert false)
     | [ Term t1; Term t2 ] ->
       (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
@@ -215,14 +226,19 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
       assert (Term.t_is_lambda t2);
       (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
       match caisar_op_of_ls engine ls1 with
-      | Vector n ->
+      | Vector v ->
+        let n = Language.lookup_vector v in
         assert (List.length tl1 = n);
         let args =
           List.mapi tl1 ~f:(fun idx t ->
             let idx = Term.t_int_const (BigInt.of_int idx) in
             (Term.t_func_app_beta_l t2 [ idx; t ], Option.value_exn t.t_ty))
         in
-        Eval (term_of_caisar_op ~args engine (Vector n) ty)
+        let caisar_op =
+          let { env; _ } = CRE.user_env engine in
+          Vector (Language.create_vector env n)
+        in
+        Eval (term_of_caisar_op ~args engine caisar_op ty)
       | Dataset (DS_csv csv) -> int (BigInt.of_int (Csv.lines csv))
       | Data _ | Classifier _ | Tensor _ | Index _ -> assert false)
     | [ Term t1; Term t2 ] ->
@@ -399,7 +415,9 @@ let bounded_quant engine vs ~cond : CRE.bounded_quant_result option =
         | { ty_node = Tyapp ({ ts_name; _ }, ty :: _); _ } ->
           let caisar_op, id =
             if String.equal ts_name.id_string "vector"
-            then (Vector n, "caisar_v")
+            then
+              let { env; _ } = CRE.user_env engine in
+              (Vector (Language.create_vector env n), "caisar_v")
             else if String.equal ts_name.id_string "tensor"
             then (Tensor n, "caisar_t")
             else assert false
@@ -440,8 +458,12 @@ let bounded_quant engine vs ~cond : CRE.bounded_quant_result option =
           Term.create_vsymbol preid ty)
       in
       let args = List.map new_quant ~f:(fun vs -> (Term.t_var vs, ty)) in
+      let caisar_op =
+        let { env; _ } = CRE.user_env engine in
+        Vector (Language.create_vector env n)
+      in
       let substitutions =
-        [ term_of_caisar_op ~args engine (Vector n) (Some vs.vs_ty) ]
+        [ term_of_caisar_op ~args engine caisar_op (Some vs.vs_ty) ]
       in
       Some { new_quant; substitutions }
   | Tapp
diff --git a/src/language.ml b/src/language.ml
index 88d81e5..8899097 100644
--- a/src/language.ml
+++ b/src/language.ml
@@ -155,3 +155,28 @@ let register_onnx_support () =
 let register_ovo_support () =
   Env.register_format ~desc:"OVO format" Pmodule.mlw_language "OVO" [ "ovo" ]
     (fun env _ filename _ -> ovo_parser env filename)
+
+type vector = Term.lsymbol (* logic symbol standing for a vector *)
+(* [vectors] maps each symbol built by [create_vector] to its length. *)
+let vectors = Term.Hls.create 10
+(* [create_vector env n] is the cached symbol taking [n] Float64 arguments. *)
+let create_vector =
+  Env.Wenv.memoize 13 (fun env ->
+    let h = Hashtbl.create (module Int) in (* per-[env] cache: length -> symbol *)
+    let ty_elt =
+      let th = Env.read_theory env [ "ieee_float" ] "Float64" in
+      Ty.ty_app (Theory.ns_find_ts th.th_export [ "t" ]) []
+    in
+    let ty =
+      let th = Env.read_theory env [ "interpretation" ] "Vector" in
+      Ty.ty_app (Theory.ns_find_ts th.th_export [ "vector" ]) [ ty_elt ]
+    in
+    Hashtbl.findi_or_add h ~default:(fun length ->
+      let id = Ident.id_fresh "vector" in
+      let ls =
+        Term.create_fsymbol id (List.init length ~f:(fun _ -> ty_elt)) ty
+      in
+      Term.Hls.add vectors ls length; (* lets [lookup_vector] recover the length *)
+      ls))
+(* [lookup_vector v] is the length registered for [v] in [vectors]. *)
+let lookup_vector = Term.Hls.find vectors (* NOTE(review): presumably raises [Not_found] on other symbols; confirm *)
diff --git a/src/language.mli b/src/language.mli
index eb025a8..a361fd4 100644
--- a/src/language.mli
+++ b/src/language.mli
@@ -62,3 +62,8 @@ val onnx_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t
 val ovo_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t
 (* [ovo_parser env filename] parses and creates the theories corresponding to
    the given ovo [filename]. The result is memoized. *)
+
+type vector = Term.lsymbol
+(* A [vector] is the logic symbol that stands for a vector of a given length. *)
+val create_vector : Env.env -> int -> vector (* cached per env and length *)
+val lookup_vector : vector -> int (* length the vector was created with *)
diff --git a/tests/interpretation_acasxu.t b/tests/interpretation_acasxu.t
index 4439abb..253cde0 100644
--- a/tests/interpretation_acasxu.t
+++ b/tests/interpretation_acasxu.t
@@ -126,15 +126,14 @@ Test interpret on acasxu
         le
         (add RNE
          (mul RNE
-          (caisar_op
-           %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+          (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
           [0] (373.9499200000000200816430151462554931640625:t))
          (7.51888402010059753166615337249822914600372314453125:t))
         (1500.0:t)
   caisar_op,
   (Interpretation.Classifier
      "$TESTCASE_ROOT/TestNetwork.nnet")
-  caisar_op1,
+  vector,
   (Interpretation.Vector 5)
   P2 : forall caisar_v:t, caisar_v1:t, caisar_v2:t, caisar_v3:t, caisar_v4:t.
         (le (0.0:t) (add RNE (mul RNE caisar_v (60261.0:t)) (19791.0:t)) /\
@@ -186,35 +185,35 @@ Test interpret on acasxu
          le (900.0:t) (add RNE (mul RNE caisar_v3 (1100.0:t)) (650.0:t)) ->
          le (960.0:t) (add RNE (mul RNE caisar_v4 (1200.0:t)) (600.0:t))) ->
         not (((lt
-               (caisar_op2
-                %% caisar_op3 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+               (caisar_op1
+                %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
                [0]
-               (caisar_op2
-                %% caisar_op3 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+               (caisar_op1
+                %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
                [1] /\
                lt
-               (caisar_op2
-                %% caisar_op3 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+               (caisar_op1
+                %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
                [0]
-               (caisar_op2
-                %% caisar_op3 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+               (caisar_op1
+                %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
                [2]) /\
               lt
-              (caisar_op2
-               %% caisar_op3 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+              (caisar_op1
+               %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
               [0]
-              (caisar_op2
-               %% caisar_op3 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+              (caisar_op1
+               %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
               [3]) /\
              lt
-             (caisar_op2
-              %% caisar_op3 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+             (caisar_op1
+              %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
              [0]
-             (caisar_op2
-              %% caisar_op3 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+             (caisar_op1
+              %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
              [4])
-  caisar_op2,
+  caisar_op1,
   (Interpretation.Classifier
      "$TESTCASE_ROOT/TestNetwork.nnet")
-  caisar_op3,
+  vector,
   (Interpretation.Vector 5)
diff --git a/tests/interpretation_dataset.t b/tests/interpretation_dataset.t
index 75b13c7..4336ff2 100644
--- a/tests/interpretation_dataset.t
+++ b/tests/interpretation_dataset.t
@@ -78,18 +78,14 @@ Test interpret on dataset
          (0.776470588000000017103729987866245210170745849609375:t))
         (0.375:t) ->
         lt
-        (caisar_op
-         %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+        (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [0]
-        (caisar_op
-         %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+        (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [1] /\
         lt
-        (caisar_op
-         %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+        (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [2]
-        (caisar_op
-         %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+        (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [1]) /\
       (forall caisar_v:t, caisar_v1:t, caisar_v2:t, caisar_v3:t, caisar_v4:t.
         ((((le (0.0:t) caisar_v /\ le caisar_v (1.0:t)) /\
@@ -123,30 +119,26 @@ Test interpret on dataset
          (0.78431372499999996161790249971090815961360931396484375:t))
         (0.375:t) ->
         lt
-        (caisar_op
-         %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+        (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [1]
-        (caisar_op
-         %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+        (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [0] /\
         lt
-        (caisar_op
-         %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+        (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [2]
-        (caisar_op
-         %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+        (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [0])
-  caisar_op2,
+  caisar_op1,
   (Interpretation.Data
      (Interpretation.D_csv
         ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"]))
-  caisar_op3,
+  vector, (Interpretation.Vector 5)
+  caisar_op2,
   (Interpretation.Data
      (Interpretation.D_csv
         ["1.0"; "0.0"; "0.019607843"; "0.776470588"; "0.784313725"]))
-  caisar_op1, (Interpretation.Vector 5)
   caisar_op,
   (Interpretation.Classifier
      "$TESTCASE_ROOT/TestNetwork.nnet")
-  caisar_op4,
+  caisar_op3,
   (Interpretation.Dataset <csv>)
-- 
GitLab