From dfd1c276e82e65e683b0d869a4ef3b4d16ae1102 Mon Sep 17 00:00:00 2001
From: hra687261 <hichem.ait-el-hara@ocamlpro.com>
Date: Thu, 24 Nov 2022 08:00:34 +0100
Subject: [PATCH] [Array] Changes to the map rules

---
 colibri2/theories/array/array.ml | 216 ++++++++++++++++++-------------
 1 file changed, 124 insertions(+), 92 deletions(-)

diff --git a/colibri2/theories/array/array.ml b/colibri2/theories/array/array.ml
index f318eb324..0aa3d17d6 100644
--- a/colibri2/theories/array/array.ml
+++ b/colibri2/theories/array/array.ml
@@ -196,13 +196,13 @@ let is_nonlinear env n =
   Debug.dprintf3 debug "is_nonlinear %a: %b" Node.pp n res;
   res
 
-let apply_diff a b = apply_cst Builtin.array_diff [] [ a; b ]
-let apply_const v = apply_cst Builtin.array_const [] [ v ]
-let apply_def_index a = apply_cst Builtin.array_default_index [] [ a ]
-let apply_def_value a = apply_cst Builtin.array_default_value [] [ a ]
+let apply_array_diff a b = apply_cst Builtin.array_diff [] [ a; b ]
+let apply_array_const v = apply_cst Builtin.array_const [] [ v ]
+let apply_array_def_index a = apply_cst Builtin.array_default_index [] [ a ]
+let apply_array_def_value a = apply_cst Builtin.array_default_value [] [ a ]
 
 let distinct_arrays a b =
-  let diff = apply_diff a b in
+  let diff = apply_array_diff a b in
   Expr.Term.distinct [ mk_select_term a diff; mk_select_term b diff ]
 
 let apply_ext env a b =
@@ -224,13 +224,13 @@ let apply_ext env a b =
            Expr.Term.eq ta tb;
            Expr.Term.distinct
              [
-               Expr.Term.Array.select ta
-                 (apply_cst Builtin.array_diff [] [ ta; tb ]);
-               Expr.Term.Array.select tb
-                 (apply_cst Builtin.array_diff [] [ ta; tb ]);
+               Expr.Term.Array.select ta (apply_array_diff ta tb);
+               Expr.Term.Array.select tb (apply_array_diff ta tb);
              ];
          ])
   in
+  Debug.dprintf4 debug "Application of the extensionality rule on %a and %a"
+    Node.pp a Node.pp b;
   Egraph.register env n;
   Boolean.set_true env n
 
@@ -256,8 +256,8 @@ module Theory = struct
   let tb = term_of_var vb
 
   let distinct_term_node ~subst env ta tb =
-    let diff_term = apply_diff ta tb in
-    let diff_eq = Expr.Term.eq diff_term (apply_diff tb ta) in
+    let diff_term = apply_array_diff ta tb in
+    let diff_eq = Expr.Term.eq diff_term (apply_array_diff tb ta) in
     let diff_eq_node = convert ~subst env diff_eq in
     Egraph.register env diff_eq_node;
     Boolean.set_true env diff_eq_node;
@@ -351,26 +351,19 @@ module Theory = struct
     in
     (raup_pattern, raup_run)
 
-  (* ⇑ᵣ: a ≡ K(v), a[j] |> a[j] = v *)
-  let const_pattern, const_run =
-    let term = apply_const tv in
-    let const_pattern = Pattern.of_term_exn ~subst:Ground.Subst.empty term in
-    let const_run env subst =
-      Debug.dprintf2 debug "Found const1 with %a" Ground.Subst.pp subst;
-      let n = convert ~subst env term in
-      Egraph.register env n;
-      let term_bis = mk_select_term term tj in
-      let const_pattern_bis = Pattern.of_term_exn ~subst term_bis in
-      let const_run_bis env subst_bis =
-        Debug.dprintf2 debug "Found const2 with %a" Ground.Subst.pp subst;
-        let subst = Ground.Subst.distinct_union subst_bis subst in
-        let v = convert ~subst env (Expr.Term.eq term_bis tv) in
-        Egraph.register env v;
-        Boolean.set_true env v
-      in
-      InvertedPath.add_callback env const_pattern_bis const_run_bis
+  (* K⇓: a = K(v), a[j] |> a[j] = v *)
+  let const_read_pattern, const_read_run =
+    let term = mk_select_term (apply_array_const tv) tj in
+    let const_read_pattern =
+      Pattern.of_term_exn ~subst:Ground.Subst.empty term
     in
-    (const_pattern, const_run)
+    let const_read_run env subst =
+      Debug.dprintf2 debug "Found const_read with %a" Ground.Subst.pp subst;
+      let v = convert ~subst env (Expr.Term.eq term tv) in
+      Egraph.register env v;
+      Boolean.set_true env v
+    in
+    (const_read_pattern, const_read_run)
 
   let apply_res_ext_1 env s =
     let rec aux terms_acc tsubsts_acc vtnodes =
@@ -393,6 +386,7 @@ module Theory = struct
           (* TODO: fix *)
           | { app = { builtin = Expr.Array; _ }; args = [ ind_gty; val_gty ] }
             ->
+              Debug.dprintf2 debug "Found unequal arrays (%a)" Node.S.pp s;
               let tyvt = Expr.Ty.array ind_ty val_ty in
               (* all the arrays are supposed to have the same type, right? *)
               let a0v = Expr.Term.Var.mk "a0" tyvt in
@@ -415,6 +409,8 @@ module Theory = struct
                   [ (ind_ty_var, ind_gty); (val_ty_var, val_gty) ]
               in
               let n = convert ~subst env (Expr.Term._and terms) in
+              Debug.dprintf2 debug "Application of the res-ext-1 rule on (%a)"
+                Node.S.pp s;
               Egraph.register env n;
               Boolean.set_true env n
           | _ -> ())
@@ -424,14 +420,22 @@ module Theory = struct
     let module HT = Datastructure.Hashtbl (Ground.Ty) in
     (* Is this reliable? *)
     let foreign_array_db = HT.create Node.S.pp "foreign_array_db" in
-    fun aty a ->
-      HT.change
-        (function
-          | Some arr_set ->
-              Node.S.iter (apply_ext env a) arr_set;
-              Some (Node.S.add a arr_set)
-          | None -> Some (Node.S.singleton a))
-        foreign_array_db env aty
+    fun (aty : Ground.Ty.t) a ->
+      match aty with
+      | { app = { builtin = Expr.Array; _ }; _ } ->
+          HT.change
+            (fun opt ->
+              match opt with
+              | Some arr_set ->
+                  Debug.dprintf2 debug
+                    "Found new foreign array (%a) on which to apply \
+                     new_foreign_array the hook"
+                    Node.pp a;
+                  Node.S.iter (apply_ext env a) arr_set;
+                  Some (Node.S.add a arr_set)
+              | None -> Some (Node.S.singleton a))
+            foreign_array_db env aty
+      | _ -> ()
 
   let init env =
     (* extáµ£ (restricted extensionality):
@@ -450,7 +454,8 @@ module Theory = struct
       else (aup_pattern, aup_run) :: l
     in
     let l =
-      if Options.get env extended_comb then (const_pattern, const_run) :: l
+      if Options.get env extended_comb then
+        (const_read_pattern, const_read_run) :: l
       else l
     in
     List.iter (fun (p, r) -> InvertedPath.add_callback env p r) l
@@ -479,8 +484,8 @@ module Theory = struct
             [ (va, Ground.node f) ]
             [ (ind_ty_var, ind_gty); (val_ty_var, val_gty) ]
         in
-        let epsilon_app = apply_def_index ta in
-        let delta_app = apply_def_value ta in
+        let epsilon_app = apply_array_def_index ta in
+        let delta_app = apply_array_def_value ta in
         let epsilon_delta_eq =
           Expr.Term.eq (mk_select_term ta epsilon_app) delta_app
         in
@@ -489,12 +494,12 @@ module Theory = struct
         Boolean.set_true env n)
 
   let mk_vt_list pref n ty =
-    let rec aux tl n =
-      if n <= 0 then List.rev tl
+    let rec aux l n =
+      if n <= 0 then List.rev l
       else
         let v = Expr.Term.Var.mk (Format.sprintf "%s%n" pref n) ty in
         let t = Expr.Term.of_var v in
-        aux (t :: tl) (n - 1)
+        aux (t :: l) (n - 1)
     in
     aux [] n
 
@@ -518,6 +523,31 @@ module Theory = struct
     in
     (map_read_pattern, map_read_run)
 
+  (** Returns a set of nodes, which represents the indexes of the array reads
+      that are applied on the array terms present in [bjl]. *)
+  let rec read_index_nodes acc env subst (bjl : Expr.Term.t list) =
+    match bjl with
+    | h :: t -> (
+        let n = convert ~subst env h in
+        match Egraph.get_dom env Info.dom n with
+        | Some v -> (
+            match
+              F_Pos.M.find_opt
+                { f = Expr.Term.Const.Array.select; pos = 0 }
+                v.parents
+            with
+            | Some s ->
+                Ground.S.fold
+                  (fun g acc ->
+                    match Ground.sem g with
+                    | { app = { builtin = Expr.Select; _ }; args } ->
+                        Node.S.add (IArray.get args 1) acc
+                    | _ -> acc)
+                  s acc
+            | None -> read_index_nodes acc env subst t)
+        | None -> read_index_nodes acc env subst t)
+    | [] -> acc
+
   (* map⇑: a = map(f, b1, ..., bn), bk[j] |> a[j] = f(b1[j], ..., bn[j]) *)
   (* map𝛿: a = map(f, b1, ..., bn)        |> 𝛿(a) = f(𝛿(b1), ..., 𝛿((bn)) *)
   let map_aup map_term f_term bitl =
@@ -525,30 +555,28 @@ module Theory = struct
     let map_run env subst =
       Debug.dprintf2 debug "Found array_map(f,b1, ..., bn) with %a"
         Ground.Subst.pp subst;
+      (* map⇑ *)
       if Options.get env extended_comb then (
         Debug.dprintf0 debug "Application of the map_aup rule";
-        let bkjl = List.map (fun bi -> mk_select_term bi tj) bitl in
-        let bkjl_patterns = List.map (Pattern.of_term_exn ~subst) bkjl in
-        let bkj_run =
-          let seen = ref false in
-          (* TODO: find one k for each j that is encountered, if there are
-             different k's for one j, only one of them needs to be matched. *)
-          fun env subst ->
-            if not !seen then (
-              seen := true;
-              let term =
-                Expr.Term.eq
-                  (mk_select_term map_term tj)
-                  (Expr.Term.apply f_term []
-                     (List.map (fun bi -> mk_select_term bi tj) bitl))
-              in
-              let n = convert ~subst env term in
-              Egraph.register env n;
-              Boolean.set_true env n)
-        in
-        List.iter
-          (fun pattern -> InvertedPath.add_callback env pattern bkj_run)
-          bkjl_patterns);
+        Node.S.iter
+          (fun ind ->
+            let term =
+              Expr.Term.eq
+                (mk_select_term map_term tj)
+                (Expr.Term.apply f_term []
+                   (List.map (fun bi -> mk_select_term bi tj) bitl))
+            in
+            let n =
+              convert
+                ~subst:
+                  (Ground.Subst.distinct_union subst
+                     (mk_subst [ (vj, ind) ] []))
+                env term
+            in
+            Egraph.register env n;
+            Boolean.set_true env n)
+          (read_index_nodes Node.S.empty env subst bitl));
+      (* map𝛿 *)
       if Options.get env default_values then (
         Debug.dprintf0 debug "Application of the map_delta rule";
         let d_a = apply_cst Builtin.array_default_value [] [ ta ] in
@@ -564,29 +592,28 @@ module Theory = struct
     in
     (map_pattern, map_run)
 
-  module NM = Datastructure.Memo2 (Popop_stdlib.DInt)
-
-  let new_map_db =
-    NM.create Fmt.nop "new_map_db" (fun _ env f_arity ->
-        let b_ty = Expr.Ty.array ind_ty alpha_ty in
-        let f_ty = Expr.Ty.arrow (replicate f_arity alpha_ty) val_ty in
-        let bitl = mk_vt_list "b" f_arity b_ty in
-        let f_var = Expr.Term.Var.mk "f" f_ty in
-        let f_term = Expr.Term.of_var f_var in
-        let map_term =
-          apply_cst (Builtin.array_map f_arity) [] (f_term :: bitl)
-        in
-        let map_pattern, map_run = map_adowm map_term f_term bitl in
-        let map_read_pattern, map_read_run = map_adowm map_term f_term bitl in
-        if Options.get env extended_comb then
-          InvertedPath.add_callback env map_pattern map_run;
-        InvertedPath.add_callback env map_read_pattern map_read_run)
-
-  let new_map env mapf_t =
-    (* Does the type matter? should it be cached? *)
-    let mapf_s = Ground.sem mapf_t in
-    let f_arity = IArray.length mapf_s.args - 1 in
-    NM.find new_map_db env f_arity
+  let new_map =
+    let module NM = Datastructure.Memo2 (Popop_stdlib.DInt) in
+    let new_map_db =
+      NM.create Fmt.nop "new_map_db" (fun _ env f_arity ->
+          let b_ty = Expr.Ty.array ind_ty alpha_ty in
+          let f_ty = Expr.Ty.arrow (replicate f_arity alpha_ty) val_ty in
+          let bitl = mk_vt_list "b" f_arity b_ty in
+          let f_var = Expr.Term.Var.mk "f" f_ty in
+          let f_term = Expr.Term.of_var f_var in
+          let map_term =
+            apply_cst (Builtin.array_map f_arity) [] (f_term :: bitl)
+          in
+          let map_pattern, map_run = map_aup map_term f_term bitl in
+          let map_read_pattern, map_read_run = map_adowm map_term f_term bitl in
+          if Options.get env extended_comb then
+            InvertedPath.add_callback env map_read_pattern map_read_run;
+          InvertedPath.add_callback env map_pattern map_run)
+    in
+    fun env mapf_t ->
+      let mapf_s = Ground.sem mapf_t in
+      let f_arity = IArray.length mapf_s.args - 1 in
+      NM.find new_map_db env f_arity
 end
 
 let converter env (f : Ground.t) =
@@ -651,10 +678,11 @@ let converter env (f : Ground.t) =
             [ (Theory.va, a); (Theory.vi, i) ]
             [ (ind_ty_var, ind_gty); (val_ty_var, val_gty) ]
         in
-        let eps_term = apply_def_index Theory.ta in
+        let eps_term = apply_array_def_index Theory.ta in
         let eps_node = convert ~subst env eps_term in
         Egraph.register env eps_node;
         if not (Node.equal i eps_node) then (
+          (* is this the right equality test? *)
           Debug.incr array_num;
           let i_eps_neq_node =
             convert ~subst env (Expr.Term.neq eps_term Theory.ti)
@@ -690,7 +718,9 @@ let converter env (f : Ground.t) =
       (* application of the `U𝛿` rule *)
       if Options.get env default_values then (
         let eq_term =
-          Expr.Term.eq (apply_def_value store_term) (apply_def_value Theory.ta)
+          Expr.Term.eq
+            (apply_array_def_value store_term)
+            (apply_array_def_value Theory.ta)
         in
         let eq_node = convert ~subst env eq_term in
         Egraph.register env eq_node;
@@ -712,7 +742,9 @@ let converter env (f : Ground.t) =
         let subst = mk_subst [ (Theory.vv, v) ] [ (val_ty_var, val_gty) ] in
         let eq_node =
           convert ~subst env
-            (Expr.Term.eq (apply_def_value (apply_const Theory.tv)) Theory.tv)
+            (Expr.Term.eq
+               (apply_array_def_value (apply_array_const Theory.tv))
+               Theory.tv)
         in
         Egraph.register env eq_node;
         Boolean.set_true env eq_node)
-- 
GitLab