From ea2f7d9c4f657fe8aad1275fb43e53c85fcb82ad Mon Sep 17 00:00:00 2001
From: hra687261 <hichem.ait-el-hara@ocamlpro.com>
Date: Thu, 17 Nov 2022 10:16:29 +0100
Subject: [PATCH] Added map inference rules

---
 colibri2/theories/array/array.ml | 207 ++++++++++++++++++++++++++-----
 1 file changed, 174 insertions(+), 33 deletions(-)

diff --git a/colibri2/theories/array/array.ml b/colibri2/theories/array/array.ml
index a41d097e1..4a753b9b2 100644
--- a/colibri2/theories/array/array.ml
+++ b/colibri2/theories/array/array.ml
@@ -30,20 +30,17 @@ let restrict_ext =
 let restrict_aup =
   Colibri2_stdlib.Flags.Solve.gen_register_flag
     ~rep:(fun flag -> (flag, "Restrict the ⇑ rule"))
-    "res-aup"
+    "Array.res-aup"
 
 let extended_comb =
   Colibri2_stdlib.Flags.Solve.gen_register_flag
     ~rep:(fun flag -> (flag, "Extended combinators"))
-    "ext-comb"
+    "Array.ext-comb"
 
 let default_values =
   Colibri2_stdlib.Flags.Solve.gen_register_flag
     ~rep:(fun flag -> (flag, "Default values"))
-    "def-values"
-
-let db = Datastructure.Push.create Ground.pp "Array.db"
-(* Use one db per array type? *)
+    "Array.def-values"
 
 let debug = Debug.register_info_flag ~desc:"For array theory" "Array"
 let stats = Debug.register_stats_int "Array.rule"
@@ -58,14 +55,18 @@ let ind_ty_var = Expr.Ty.Var.mk "ind_ty"
 let val_ty_var = Expr.Ty.Var.mk "val_ty"
 let ind_ty = Expr.Ty.of_var ind_ty_var
 let val_ty = Expr.Ty.of_var val_ty_var
+let alpha_ty_var = Expr.Ty.Var.mk "alpha"
+let alpha_ty = Expr.Ty.of_var alpha_ty_var
 let array_ty = Expr.Ty.array ind_ty val_ty
 let bind_tys tyvl ty = Expr.Ty.pi tyvl ty
+let replicate n v = List.init n (fun _ -> v)
 
 (* Builtins *)
 module Builtin = struct
   type _ Expr.t +=
     | Array_diff
     | Array_const
+    | Array_map
     | Array_default_index
     | Array_default_value
 
@@ -79,6 +80,26 @@ module Builtin = struct
       (Dolmen_std.Path.global "array_const")
       (Expr.Ty.arrow [ val_ty ] array_ty)
 
+  let array_map : int -> Dolmen_std.Expr.term_cst =
+    let cache = Popop_stdlib.DInt.H.create 13 in
+    let get_ty i =
+      match Popop_stdlib.DInt.H.find cache i with
+      | ty -> ty
+      | exception Not_found ->
+          let ty =
+            Expr.Ty.arrow
+              (Expr.Ty.arrow (replicate i alpha_ty) val_ty
+              :: replicate i (Expr.Ty.array ind_ty alpha_ty))
+              val_ty
+          in
+          Popop_stdlib.DInt.H.add cache i ty;
+          ty
+    in
+    fun i ->
+      Expr.Id.mk ~name:"array_map" ~builtin:Array_map
+        (Dolmen_std.Path.global "array_map")
+        (get_ty i)
+
   let array_default_index : Dolmen_std.Expr.term_cst =
     Expr.Id.mk ~name:"array_default_index" ~builtin:Array_default_index
       (Dolmen_std.Path.global "array_default_index")
@@ -116,6 +137,19 @@ module Builtin = struct
         | Dolmen_loop.Typer.T.Id
             { ns = Term; name = Simple "array_default_value" } ->
             app1 env s array_default_value
+        | Dolmen_loop.Typer.T.Id { ns = Term; name = Simple "array_map" } ->
+            `Term
+              (Dolmen_type.Base.term_app_list
+                 (module Dolmen_loop.Typer.T)
+                 env s
+                 (function
+                   | _ :: t as l ->
+                       (* "t" should probably never be empty... *)
+                       Expr.Term.apply_cst (array_map (List.length t)) [] l
+                   | _ ->
+                       failwith
+                         "array_map needs to be applied to a function and n \
+                          arrays"))
         | _ -> `Not_found)
 end
 
@@ -328,35 +362,139 @@ module Theory = struct
     let l = if !extended_comb then (const_pattern, const_run) :: l else l in
     List.iter (fun (p, r) -> InvertedPath.add_callback env p r) l
 
-  let new_array env s_index_ty s_value_ty f =
-    (* Extensionality rule ext: a, b ⇒ (a = b) ⋁ (a[k] ≠ b[k]) *)
-    if not !restrict_ext then (
-      Datastructure.Push.iter db env ~f:(fun f2 ->
-          let subst =
-            mk_subst
-              [ (va, Ground.node f2); (vb, Ground.node f) ]
-              [ (ind_ty_var, s_index_ty); (val_ty_var, s_value_ty) ]
-          in
-          Debug.dprintf2 debug "Found ext with %a" Ground.Subst.pp subst;
-          let n = distinct_term_node ~subst env ta tb in
-          Egraph.register env n;
-          Boolean.set_true env n);
-      Datastructure.Push.push db env f);
-    (* 𝝐𝛿: a |> a[𝝐] = 𝛿 *)
-    if !extended_comb then (
-      let subst =
-        mk_subst
-          [ (va, Ground.node f) ]
-          [ (ind_ty_var, s_index_ty); (val_ty_var, s_value_ty) ]
-      in
-      let epsilon_app = apply_def_index ta in
-      let delta_app = apply_def_value ta in
-      let epsilon_delta_eq =
-        Expr.Term.eq (Expr.Term.select ta epsilon_app) delta_app
+  let new_array =
+    let db = Datastructure.Push.create Ground.pp "Array.db" in
+    fun env s_index_ty s_value_ty f ->
+      (* Extensionality rule ext: a, b ⇒ (a = b) ⋁ (a[k] ≠ b[k]) *)
+      if not !restrict_ext then (
+        Datastructure.Push.iter db env ~f:(fun f2 ->
+            let subst =
+              mk_subst
+                [ (va, Ground.node f2); (vb, Ground.node f) ]
+                [ (ind_ty_var, s_index_ty); (val_ty_var, s_value_ty) ]
+            in
+            Debug.dprintf2 debug "Found ext with %a" Ground.Subst.pp subst;
+            let n = distinct_term_node ~subst env ta tb in
+            Egraph.register env n;
+            Boolean.set_true env n);
+        Datastructure.Push.push db env f);
+      (* 𝝐𝛿: a |> a[𝝐] = 𝛿 *)
+      if !extended_comb then (
+        Debug.dprintf0 debug "Application of the epsilon_delta rule";
+        let subst =
+          mk_subst
+            [ (va, Ground.node f) ]
+            [ (ind_ty_var, s_index_ty); (val_ty_var, s_value_ty) ]
+        in
+        let epsilon_app = apply_def_index ta in
+        let delta_app = apply_def_value ta in
+        let epsilon_delta_eq =
+          Expr.Term.eq (Expr.Term.select ta epsilon_app) delta_app
+        in
+        let n = convert ~subst env epsilon_delta_eq in
+        Egraph.register env n;
+        Boolean.set_true env n)
+
+  let mk_vt_list pref n ty =
+    let rec aux tl n =
+      if n <= 0 then List.rev tl
+      else
+        let v = Expr.Term.Var.mk (Format.sprintf "%s%n" pref n) ty in
+        let t = Expr.Term.of_var v in
+        aux (t :: tl) (n - 1)
+    in
+    aux [] n
+
+  (* map⇓: a = map(f, b1, ..., bn), a[j] |> a[j] = f(b1[j], ..., bn[j]) *)
+  let map_adowm map_term f_term bitl =
+    let map_read_pattern =
+      Pattern.of_term_exn ~subst:Ground.Subst.empty
+        (Expr.Term.select map_term tj)
+    in
+    let map_read_run env subst =
+      Debug.dprintf2 debug "Found array_map(f,b1, ..., bn)[j] with %a"
+        Ground.Subst.pp subst;
+      let term =
+        Expr.Term.eq
+          (Expr.Term.select map_term tj)
+          (Expr.Term.apply f_term []
+             (List.map (fun bi -> Expr.Term.select bi tj) bitl))
       in
-      let n = convert ~subst env epsilon_delta_eq in
+      let n = convert ~subst env term in
       Egraph.register env n;
-      Boolean.set_true env n)
+      Boolean.set_true env n
+    in
+    (map_read_pattern, map_read_run)
+
+  (* map⇑: a = map(f, b1, ..., bn), bk[j] |> a[j] = f(b1[j], ..., bn[j]) *)
+  (* map𝛿: a = map(f, b1, ..., bn)        |> 𝛿(a) = f(𝛿(b1), ..., 𝛿((bn)) *)
+  let map_aup map_term f_term bitl =
+    let map_pattern = Pattern.of_term_exn ~subst:Ground.Subst.empty map_term in
+    let map_run env subst =
+      Debug.dprintf2 debug "Found array_map(f,b1, ..., bn) with %a"
+        Ground.Subst.pp subst;
+      if !extended_comb then (
+        Debug.dprintf0 debug "Application of the map_aup rule";
+        let bkjl = List.map (fun bi -> Expr.Term.select bi tj) bitl in
+        let bkjl_patterns = List.map (Pattern.of_term_exn ~subst) bkjl in
+        let bkj_run =
+          let seen = ref false in
+          (* TODO: find one k for each j that is encountered, if there are
+             different k's for one j, only one of them needs to be matched. *)
+          fun env subst ->
+            if not !seen then (
+              seen := true;
+              let term =
+                Expr.Term.eq
+                  (Expr.Term.select map_term tj)
+                  (Expr.Term.apply f_term []
+                     (List.map (fun bi -> Expr.Term.select bi tj) bitl))
+              in
+              let n = convert ~subst env term in
+              Egraph.register env n;
+              Boolean.set_true env n)
+        in
+        List.iter
+          (fun pattern -> InvertedPath.add_callback env pattern bkj_run)
+          bkjl_patterns);
+      if !default_values then (
+        Debug.dprintf0 debug "Application of the map_delta rule";
+        let d_a = Expr.Term.apply_cst Builtin.array_default_value [] [ ta ] in
+        let d_bil =
+          List.map
+            (fun bi ->
+              Expr.Term.apply_cst Builtin.array_default_value [] [ bi ])
+            bitl
+        in
+        let term = Expr.Term.eq d_a (Expr.Term.apply f_term [] d_bil) in
+        let n = convert ~subst env term in
+        Egraph.register env n;
+        Boolean.set_true env n)
+    in
+    (map_pattern, map_run)
+
+  let new_map =
+    (* Does the type matter? should it be cached? *)
+    let cache = ref Popop_stdlib.DInt.S.empty in
+    let seen i = Popop_stdlib.DInt.S.mem i !cache in
+    let add i = cache := Popop_stdlib.DInt.S.add i !cache in
+    fun env mapf_t ->
+      let mapf_s = Ground.sem mapf_t in
+      let f_arity = IArray.length mapf_s.args - 1 in
+      if not (seen f_arity) then (
+        add f_arity;
+        let b_ty = Expr.Ty.array ind_ty alpha_ty in
+        let f_ty = Expr.Ty.arrow (replicate f_arity alpha_ty) val_ty in
+        let bitl = mk_vt_list "b" f_arity b_ty in
+        let f_var = Expr.Term.Var.mk "f" f_ty in
+        let f_term = Expr.Term.of_var f_var in
+        let map_term =
+          Expr.Term.apply_cst (Builtin.array_map f_arity) [] (f_term :: bitl)
+        in
+        let map_pattern, map_run = map_adowm map_term f_term bitl in
+        let map_read_pattern, map_read_run = map_adowm map_term f_term bitl in
+        if !extended_comb then InvertedPath.add_callback env map_pattern map_run;
+        InvertedPath.add_callback env map_read_pattern map_read_run)
 end
 
 let converter env (f : Ground.t) =
@@ -486,6 +624,9 @@ let converter env (f : Ground.t) =
         in
         Egraph.register env eq_node;
         Boolean.set_true env eq_node)
+  | { app = { builtin = Builtin.Array_map; _ }; _ }
+    when !extended_comb || !default_values ->
+      Theory.new_map env f
   | _ -> ()
 
 let init env : unit =
-- 
GitLab