From 056214330c30c89619f21f407094788a1fbebc7d Mon Sep 17 00:00:00 2001
From: hra687261 <hichem.ait-el-hara@ocamlpro.com>
Date: Tue, 29 Nov 2022 13:25:44 +0100
Subject: [PATCH] [Array] Added Array_dom and used it for map rules

---
 colibri2/core/colibri2_core.mli         |   5 +-
 colibri2/core/datastructure.ml          |   8 +-
 colibri2/core/datastructure.mli         |   2 +-
 colibri2/theories/array/array.ml        | 463 ++++++++++++++----------
 colibri2/theories/array/array_dom.ml    | 104 ++++++
 colibri2/theories/array/array_dom.mli   |  53 +++
 colibri2/theories/array/foreign_dom.ml  |  10 +-
 colibri2/theories/array/foreign_dom.mli |   6 +-
 8 files changed, 438 insertions(+), 213 deletions(-)
 create mode 100644 colibri2/theories/array/array_dom.ml
 create mode 100644 colibri2/theories/array/array_dom.mli

diff --git a/colibri2/core/colibri2_core.mli b/colibri2/core/colibri2_core.mli
index 2889a8ac0..c59e17d19 100644
--- a/colibri2/core/colibri2_core.mli
+++ b/colibri2/core/colibri2_core.mli
@@ -756,10 +756,7 @@ module Datastructure : sig
     type key
 
     val create :
-      'a Format.printer ->
-      string ->
-      (Context.creator -> 'b Egraph.t -> key -> 'a) ->
-      ('a, 'b) t
+      'a Format.printer -> string -> ('b Egraph.t -> key -> 'a) -> ('a, 'b) t
 
     val find : ('a, 'b) t -> 'b Egraph.t -> key -> 'a
     val iter : (key -> 'a -> unit) -> ('a, 'b) t -> 'b Egraph.t -> unit
diff --git a/colibri2/core/datastructure.ml b/colibri2/core/datastructure.ml
index ac623d67b..d25ebd89c 100644
--- a/colibri2/core/datastructure.ml
+++ b/colibri2/core/datastructure.ml
@@ -210,7 +210,7 @@ module type Memo2 =  sig
   type ('a, 'b) t
   type key
 
-  val create: 'a Format.printer -> string -> (Context.creator -> 'b Egraph.t -> key -> 'a) -> ('a, 'b) t
+  val create: 'a Format.printer -> string -> ('b Egraph.t -> key -> 'a) -> ('a, 'b) t
   val find : ('a, 'b) t -> 'b Egraph.t -> key -> 'a
   val iter : (key -> 'a -> unit) -> ('a, 'b) t -> 'b Egraph.t -> unit
   val fold : (key -> 'a -> 'acc -> 'acc) -> ('a, 'b) t -> 'b Egraph.t -> 'acc -> 'acc
@@ -220,12 +220,12 @@ module Memo2(S:Colibri2_popop_lib.Popop_stdlib.Datatype) : Memo2 with type key :
 
   type ('a, 'b) t' = {
     h: 'a S.H.t;
-    def: (Context.creator -> 'b Egraph.t -> S.t -> 'a);
+    def: ('b Egraph.t -> S.t -> 'a);
   }
 
   type ('a, 'b) t = ('a, 'b) t' Env.Unsaved.t
 
-  let create : type a b. a Colibri2_popop_lib.Pp.pp -> _ -> (Context.creator  -> b Egraph.t -> S.t -> a) -> (a, b) t = fun pp name def ->
+  let create : type a b. a Colibri2_popop_lib.Pp.pp -> _ -> (b Egraph.t -> S.t -> a) -> (a, b) t = fun pp name def ->
     let module M = struct
       type t = (a, b) t'
       let name = name end
@@ -242,7 +242,7 @@ module Memo2(S:Colibri2_popop_lib.Popop_stdlib.Datatype) : Memo2 with type key :
     match S.H.find_opt h.h k with
       | Some r -> r
       | None ->
-        let r = h.def (Egraph.context d) d k in
+        let r = h.def d k in
         S.H.add h.h k r;
         r
 
diff --git a/colibri2/core/datastructure.mli b/colibri2/core/datastructure.mli
index 3e99b1151..fe53cff20 100644
--- a/colibri2/core/datastructure.mli
+++ b/colibri2/core/datastructure.mli
@@ -72,7 +72,7 @@ module type Memo2 =  sig
   type ('a, 'b) t
   type key
 
-  val create: 'a Format.printer -> string -> (Context.creator -> 'b Egraph.t -> key -> 'a) -> ('a, 'b) t
+  val create: 'a Format.printer -> string -> ('b Egraph.t -> key -> 'a) -> ('a, 'b) t
   val find : ('a, 'b) t -> 'b Egraph.t -> key -> 'a
   val iter : (key -> 'a -> unit) -> ('a, 'b) t -> 'b Egraph.t -> unit
   val fold : (key -> 'a -> 'acc -> 'acc) -> ('a, 'b) t -> 'b Egraph.t -> 'acc -> 'acc
diff --git a/colibri2/theories/array/array.ml b/colibri2/theories/array/array.ml
index 0aa3d17d6..c944b7dbb 100644
--- a/colibri2/theories/array/array.ml
+++ b/colibri2/theories/array/array.ml
@@ -45,8 +45,12 @@ let default_values =
 
 let debug = Debug.register_info_flag ~desc:"For array theory" "Array"
 let stats = Debug.register_stats_int "Array.rule"
-let array_num = Debug.register_stats_int "Array.num"
-(* for now only 1 counter as there is only one index sort (Not used yet)*)
+
+module NHT = Datastructure.Hashtbl (Node)
+module GHT = Datastructure.Hashtbl (Ground)
+module GTHT = Datastructure.Hashtbl (Ground.Ty)
+
+let array_num_db = NHT.create Popop_stdlib.DInt.pp "array_num_db"
 
 let convert ~subst =
   Colibri2_theories_quantifiers.Subst.convert ~subst_old:Ground.Subst.empty
@@ -54,19 +58,28 @@ let convert ~subst =
 
 let ind_ty_var = Expr.Ty.Var.mk "ind_ty"
 let val_ty_var = Expr.Ty.Var.mk "val_ty"
+let alpha_ty_var = Expr.Ty.Var.mk "alpha"
+let a_ty_var = Expr.Ty.Var.mk "a"
+let b_ty_var = Expr.Ty.Var.mk "b"
+let c_ty_var = Expr.Ty.Var.mk "c"
 let ind_ty = Expr.Ty.of_var ind_ty_var
 let val_ty = Expr.Ty.of_var val_ty_var
-let alpha_ty_var = Expr.Ty.Var.mk "alpha"
+let a_ty = Expr.Ty.of_var a_ty_var
+let b_ty = Expr.Ty.of_var b_ty_var
+let c_ty = Expr.Ty.of_var c_ty_var
 let alpha_ty = Expr.Ty.of_var alpha_ty_var
 let array_ty = Expr.Ty.array ind_ty val_ty
-let bind_tys tyvl ty = Expr.Ty.pi tyvl ty
+let array_ty_ab = Expr.Ty.array a_ty b_ty
+let array_ty_ac = Expr.Ty.array a_ty c_ty
+let array_ty_alpha = Expr.Ty.array ind_ty alpha_ty
 let replicate n v = List.init n (fun _ -> v)
 let mk_store_term = Expr.Term.Array.store
 let mk_select_term = Expr.Term.Array.select
 let apply_cst = Expr.Term.apply_cst
 
-let array_ty_args : Expr.ty -> Expr.ty list = function
-  | { ty_descr = TyApp ({ builtin = Expr.Array; _ }, ty_args); _ } -> ty_args
+let array_ty_args : Expr.ty -> Expr.ty * Expr.ty = function
+  | { ty_descr = TyApp ({ builtin = Expr.Array; _ }, [ ind_ty; val_ty ]); _ } ->
+      (ind_ty, val_ty)
   | ty ->
       failwith (Format.asprintf "'%a' is not an array type!" Expr.Ty.print ty)
 
@@ -79,24 +92,30 @@ let array_gty_args : Ground.Ty.t -> Ground.Ty.t * Ground.Ty.t = function
 
 let node_tyl env n = Ground.Ty.S.elements (Ground.tys env n)
 
-(* Builtins *)
 module Builtin = struct
+  (** Additional array Builtins *)
   type _ Expr.t +=
     | Array_diff
-    | Array_const
+          (** [Array_diff: 'a 'b. ('a, 'b) Array -> ('a, 'b) Array -> 'a] *)
+    | Array_const  (** [Array_const: 'b. 'b -> (ind_ty, 'b) Array] *)
     | Array_map
+          (** [Array_map: 'a 'b 'c. (('b -> ... -> 'b -> 'c) -> ('a, 'b) Array -> ... -> ('a, 'b) Array)-> ('a, 'c) Array]  *)
     | Array_default_index
+          (** [Array_default_index: 'a 'b. ('a, 'b) Array -> 'a] *)
     | Array_default_value
+          (** [Array_default_value: 'a 'b. ('a, 'b) Array -> 'b] *)
 
   let array_diff : Dolmen_std.Expr.term_cst =
     Expr.Id.mk ~name:"colibri2_array_diff" ~builtin:Array_diff
       (Dolmen_std.Path.global "colibri2_array_diff")
-      (Expr.Ty.arrow [ array_ty; array_ty ] ind_ty)
+      (Expr.Ty.pi [ a_ty_var; b_ty_var ]
+         (Expr.Ty.arrow [ array_ty_ab; array_ty_ab ] a_ty))
 
   let array_const : Dolmen_std.Expr.term_cst =
     Expr.Id.mk ~name:"colibri2_array_const" ~builtin:Array_const
       (Dolmen_std.Path.global "colibri2_array_const")
-      (Expr.Ty.arrow [ val_ty ] array_ty)
+      (Expr.Ty.pi [ b_ty_var ]
+         (Expr.Ty.arrow [ b_ty ] (Expr.Ty.array ind_ty b_ty)))
 
   let array_map : int -> Dolmen_std.Expr.term_cst =
     let cache = Popop_stdlib.DInt.H.create 13 in
@@ -106,9 +125,8 @@ module Builtin = struct
       | exception Not_found ->
           let ty =
             Expr.Ty.arrow
-              (Expr.Ty.arrow (replicate i alpha_ty) val_ty
-              :: replicate i (Expr.Ty.array ind_ty alpha_ty))
-              val_ty
+              (Expr.Ty.arrow (replicate i b_ty) c_ty :: replicate i array_ty_ab)
+              array_ty_ac
           in
           Popop_stdlib.DInt.H.add cache i ty;
           ty
@@ -116,17 +134,42 @@ module Builtin = struct
     fun i ->
       Expr.Id.mk ~name:"colibri2_array_map" ~builtin:Array_map
         (Dolmen_std.Path.global "colibri2_array_map")
-        (get_ty i)
+        (Expr.Ty.pi [ a_ty_var; b_ty_var; c_ty_var ] (get_ty i))
 
   let array_default_index : Dolmen_std.Expr.term_cst =
     Expr.Id.mk ~name:"colibri2_array_default_index" ~builtin:Array_default_index
       (Dolmen_std.Path.global "colibri2_array_default_index")
-      (Expr.Ty.arrow [ array_ty ] ind_ty)
+      (Expr.Ty.pi [ a_ty_var; b_ty_var ] (Expr.Ty.arrow [ array_ty_ab ] a_ty))
 
   let array_default_value : Dolmen_std.Expr.term_cst =
     Expr.Id.mk ~name:"colibri2_array_default_value" ~builtin:Array_default_value
       (Dolmen_std.Path.global "colibri2_array_default_value")
-      (Expr.Ty.arrow [ array_ty ] val_ty)
+      (Expr.Ty.pi [ a_ty_var; b_ty_var ] (Expr.Ty.arrow [ array_ty_ab ] b_ty))
+
+  let apply_array_diff a b =
+    let ind_ty, val_ty = array_ty_args a.Expr.term_ty in
+    apply_cst array_diff [ ind_ty; val_ty ] [ a; b ]
+
+  let apply_array_const v = apply_cst array_const [ v.Expr.term_ty ] [ v ]
+  (* what's the type of the index in constant arrays? *)
+
+  let apply_array_def_index a =
+    let ind_ty, val_ty = array_ty_args a.Expr.term_ty in
+    apply_cst array_default_index [ ind_ty; val_ty ] [ a ]
+
+  let apply_array_def_value a =
+    let ind_ty, val_ty = array_ty_args a.Expr.term_ty in
+    apply_cst array_default_value [ ind_ty; val_ty ] [ a ]
+
+  let apply_array_map f_arity f_term bitl =
+    match (bitl, f_term) with
+    | h :: _, Expr.{ term_ty = { ty_descr = Arrow (_, ret_ty); _ }; _ } ->
+        let bi_ind_ty, bi_val_ty = array_ty_args h.Expr.term_ty in
+        apply_cst (array_map f_arity)
+          [ bi_ind_ty; bi_val_ty; ret_ty ]
+          (f_term :: bitl)
+    | _, _ ->
+        failwith "array_map needs to be applied to a function and n > 0 arrays"
 
   let () =
     let app1 env s f =
@@ -134,23 +177,28 @@ module Builtin = struct
         (Dolmen_type.Base.term_app1
            (module Dolmen_loop.Typer.T)
            env s
-           (fun a -> apply_cst f [] [ a ]))
-    in
-    let app2 env s f =
-      `Term
-        (Dolmen_type.Base.term_app2
-           (module Dolmen_loop.Typer.T)
-           env s
-           (fun a b -> apply_cst f [] [ a; b ]))
+           (fun a ->
+             let ind_ty, val_ty = array_ty_args a.term_ty in
+             apply_cst f [ ind_ty; val_ty ] [ a ]))
     in
     Expr.add_builtins (fun env s ->
         match s with
         | Dolmen_loop.Typer.T.Id
             { ns = Term; name = Simple "colibri2_array_diff" } ->
-            app2 env s array_diff
+            `Term
+              (Dolmen_type.Base.term_app2
+                 (module Dolmen_loop.Typer.T)
+                 env s
+                 (fun a b ->
+                   let ind_ty, val_ty = array_ty_args a.term_ty in
+                   apply_cst array_diff [ ind_ty; val_ty ] [ a; b ]))
         | Dolmen_loop.Typer.T.Id
             { ns = Term; name = Simple "colibri2_array_const" } ->
-            app1 env s array_const
+            `Term
+              (Dolmen_type.Base.term_app1
+                 (module Dolmen_loop.Typer.T)
+                 env s
+                 (fun a -> apply_cst array_const [ a.term_ty ] [ a ]))
         | Dolmen_loop.Typer.T.Id
             { ns = Term; name = Simple "colibri2_array_default_index" } ->
             app1 env s array_default_index
@@ -164,13 +212,11 @@ module Builtin = struct
                  (module Dolmen_loop.Typer.T)
                  env s
                  (function
-                   | _ :: t as l ->
-                       (* "t" should probably never be empty... *)
-                       apply_cst (array_map (List.length t)) [] l
+                   | f_term :: t -> apply_array_map (List.length t) f_term t
                    | _ ->
                        failwith
-                         "array_map needs to be applied to a function and n \
-                          arrays"))
+                         "array_map needs to be applied to a function and n > \
+                          0 arrays"))
         | _ -> `Not_found)
 end
 
@@ -178,62 +224,10 @@ let mk_subst term_l ty_l =
   Ground.Subst.
     { term = Expr.Term.Var.M.of_list term_l; ty = Expr.Ty.Var.M.of_list ty_l }
 
-let is_foreign env n =
-  let res =
-    match Egraph.get_dom env Foreign_dom.key n with
-    | Some IsForeign -> true
-    | _ -> false
-  in
-  Debug.dprintf3 debug "is_foreign %a: %b" Node.pp n res;
-  res
-
-let is_nonlinear env n =
-  let res =
-    match Egraph.get_dom env Linearity_dom.key n with
-    | Some NonLinear -> true
-    | _ -> false
-  in
-  Debug.dprintf3 debug "is_nonlinear %a: %b" Node.pp n res;
-  res
-
-let apply_array_diff a b = apply_cst Builtin.array_diff [] [ a; b ]
-let apply_array_const v = apply_cst Builtin.array_const [] [ v ]
-let apply_array_def_index a = apply_cst Builtin.array_default_index [] [ a ]
-let apply_array_def_value a = apply_cst Builtin.array_default_value [] [ a ]
-
 let distinct_arrays a b =
-  let diff = apply_array_diff a b in
+  let diff = Builtin.apply_array_diff a b in
   Expr.Term.distinct [ mk_select_term a diff; mk_select_term b diff ]
 
-let apply_ext env a b =
-  let va = Expr.Term.Var.mk "a" array_ty in
-  let vb = Expr.Term.Var.mk "b" array_ty in
-  let ta = Expr.Term.of_var va in
-  let tb = Expr.Term.of_var vb in
-  let gtya, gtyb = array_gty_args (List.hd (node_tyl env a)) in
-  (* TODO: fix *)
-  let n =
-    convert
-      ~subst:
-        (mk_subst
-           [ (va, a); (vb, b) ]
-           [ (ind_ty_var, gtya); (val_ty_var, gtyb) ])
-      env
-      (Expr.Term._or
-         [
-           Expr.Term.eq ta tb;
-           Expr.Term.distinct
-             [
-               Expr.Term.Array.select ta (apply_array_diff ta tb);
-               Expr.Term.Array.select tb (apply_array_diff ta tb);
-             ];
-         ])
-  in
-  Debug.dprintf4 debug "Application of the extensionality rule on %a and %a"
-    Node.pp a Node.pp b;
-  Egraph.register env n;
-  Boolean.set_true env n
-
 (* Generalized, Efficient Array Decision Procedures. de Moura, Bjorner *)
 module Theory = struct
   open Colibri2_theories_quantifiers
@@ -256,8 +250,8 @@ module Theory = struct
   let tb = term_of_var vb
 
   let distinct_term_node ~subst env ta tb =
-    let diff_term = apply_array_diff ta tb in
-    let diff_eq = Expr.Term.eq diff_term (apply_array_diff tb ta) in
+    let diff_term = Builtin.apply_array_diff ta tb in
+    let diff_eq = Expr.Term.eq diff_term (Builtin.apply_array_diff tb ta) in
     let diff_eq_node = convert ~subst env diff_eq in
     Egraph.register env diff_eq_node;
     Boolean.set_true env diff_eq_node;
@@ -335,17 +329,20 @@ module Theory = struct
         Debug.dprintf2 debug "Found raup2 with %a" Ground.Subst.pp subst;
         let subst = Ground.Subst.distinct_union subst_bis subst in
         let bn = convert ~subst env tb in
-        if is_nonlinear env bn then (
-          let v =
-            convert ~subst env
-              (Expr.Term._or
-                 [
-                   Expr.Term.eq ti tj;
-                   Expr.Term.eq (mk_select_term term tj) (mk_select_term tb tj);
-                 ])
-          in
-          Egraph.register env v;
-          Boolean.set_true env v)
+        match Egraph.get_dom env Linearity_dom.key bn with
+        | Some NonLinear ->
+            let v =
+              convert ~subst env
+                (Expr.Term._or
+                   [
+                     Expr.Term.eq ti tj;
+                     Expr.Term.eq (mk_select_term term tj)
+                       (mk_select_term tb tj);
+                   ])
+            in
+            Egraph.register env v;
+            Boolean.set_true env v
+        | _ -> ()
       in
       InvertedPath.add_callback env raup_pattern_bis raup_run_bis
     in
@@ -353,7 +350,7 @@ module Theory = struct
 
   (* K⇓: a = K(v), a[j] |> a[j] = v *)
   let const_read_pattern, const_read_run =
-    let term = mk_select_term (apply_array_const tv) tj in
+    let term = mk_select_term (Builtin.apply_array_const tv) tj in
     let const_read_pattern =
       Pattern.of_term_exn ~subst:Ground.Subst.empty term
     in
@@ -417,13 +414,37 @@ module Theory = struct
     | [] -> ()
 
   let apply_res_ext_2 env =
-    let module HT = Datastructure.Hashtbl (Ground.Ty) in
+    let apply_ext env a b =
+      let gtya, gtyb = array_gty_args (List.hd (node_tyl env a)) in
+      (* TODO: fix *)
+      let n =
+        convert
+          ~subst:
+            (mk_subst
+               [ (va, a); (vb, b) ]
+               [ (ind_ty_var, gtya); (val_ty_var, gtyb) ])
+          env
+          (Expr.Term._or
+             [
+               Expr.Term.eq ta tb;
+               Expr.Term.distinct
+                 [
+                   Expr.Term.Array.select ta (Builtin.apply_array_diff ta tb);
+                   Expr.Term.Array.select tb (Builtin.apply_array_diff ta tb);
+                 ];
+             ])
+      in
+      Debug.dprintf4 debug "Application of the extensionality rule on %a and %a"
+        Node.pp a Node.pp b;
+      Egraph.register env n;
+      Boolean.set_true env n
+    in
     (* Is this reliable? *)
-    let foreign_array_db = HT.create Node.S.pp "foreign_array_db" in
+    let foreign_array_db = GTHT.create Node.S.pp "foreign_array_db" in
     fun (aty : Ground.Ty.t) a ->
       match aty with
       | { app = { builtin = Expr.Array; _ }; _ } ->
-          HT.change
+          GTHT.change
             (fun opt ->
               match opt with
               | Some arr_set ->
@@ -484,8 +505,8 @@ module Theory = struct
             [ (va, Ground.node f) ]
             [ (ind_ty_var, ind_gty); (val_ty_var, val_gty) ]
         in
-        let epsilon_app = apply_array_def_index ta in
-        let delta_app = apply_array_def_value ta in
+        let epsilon_app = Builtin.apply_array_def_index ta in
+        let delta_app = Builtin.apply_array_def_value ta in
         let epsilon_delta_eq =
           Expr.Term.eq (mk_select_term ta epsilon_app) delta_app
         in
@@ -493,16 +514,6 @@ module Theory = struct
         Egraph.register env n;
         Boolean.set_true env n)
 
-  let mk_vt_list pref n ty =
-    let rec aux l n =
-      if n <= 0 then List.rev l
-      else
-        let v = Expr.Term.Var.mk (Format.sprintf "%s%n" pref n) ty in
-        let t = Expr.Term.of_var v in
-        aux (t :: l) (n - 1)
-    in
-    aux [] n
-
   (* map⇓: a = map(f, b1, ..., bn), a[j] |> a[j] = f(b1[j], ..., bn[j]) *)
   let map_adowm map_term f_term bitl =
     let map_read_pattern =
@@ -523,92 +534,137 @@ module Theory = struct
     in
     (map_read_pattern, map_read_run)
 
-  (** Returns a set of nodes, which represents the indexes of the array reads
-      that are applied on the array terms present in [bjl]. *)
-  let rec read_index_nodes acc env subst (bjl : Expr.Term.t list) =
-    match bjl with
-    | h :: t -> (
-        let n = convert ~subst env h in
-        match Egraph.get_dom env Info.dom n with
-        | Some v -> (
-            match
-              F_Pos.M.find_opt
-                { f = Expr.Term.Const.Array.select; pos = 0 }
-                v.parents
-            with
-            | Some s ->
-                Ground.S.fold
-                  (fun g acc ->
-                    match Ground.sem g with
-                    | { app = { builtin = Expr.Select; _ }; args } ->
-                        Node.S.add (IArray.get args 1) acc
-                    | _ -> acc)
-                  s acc
-            | None -> read_index_nodes acc env subst t)
-        | None -> read_index_nodes acc env subst t)
-    | [] -> acc
-
-  (* map⇑: a = map(f, b1, ..., bn), bk[j] |> a[j] = f(b1[j], ..., bn[j]) *)
-  (* map𝛿: a = map(f, b1, ..., bn)        |> 𝛿(a) = f(𝛿(b1), ..., 𝛿((bn)) *)
-  let map_aup map_term f_term bitl =
+  let apply_map_aup env index_n gt
+      Array_dom.{ bi_ind_ty; bi_val_ty; a_val_ty; f_arity } =
+    let argl = IArray.to_list (Ground.sem gt).args in
+    let f_node = List.hd argl in
+    let arg_nodes = List.tl argl in
+    let a_node = Ground.node gt in
+    let fvar =
+      Expr.Term.Var.mk "f" (Expr.Ty.arrow (replicate f_arity alpha_ty) val_ty)
+    in
+    let fterm = Expr.Term.of_var fvar in
+    let ty_subst =
+      [
+        (ind_ty_var, bi_ind_ty);
+        (alpha_ty_var, bi_val_ty);
+        (val_ty_var, a_val_ty);
+      ]
+    in
+    let t_subst = [ (vj, index_n); (fvar, f_node); (va, a_node) ] in
+    let _, bij_list, t_subst =
+      List.fold_left
+        (fun (n, t_acc, s_acc) node ->
+          let biv = Expr.Term.Var.mk (Format.sprintf "b%n" n) array_ty_alpha in
+          let bit = Expr.Term.of_var biv in
+          (n - 1, Expr.Term.Array.select bit tj :: t_acc, (biv, node) :: s_acc))
+        (f_arity, [], t_subst) (List.rev arg_nodes)
+    in
+    let n =
+      convert
+        ~subst:(mk_subst t_subst ty_subst)
+        env
+        (Expr.Term.eq
+           (Expr.Term.Array.select ta tj)
+           (Expr.Term.apply fterm [] bij_list))
+    in
+    Egraph.register env n;
+    Boolean.set_true env n
+
+  (**  map⇑: [a = map(f, b1, ..., bn), bk[j]] |> [a[j] = f(b1[j], ..., bn[j])] *)
+  let add_array_read_hook, add_array_map_hook =
+    let db = GHT.create Node.S.pp "array_map_read_on_arg" in
+    (* Whenever a bk[j] is encountered, apply the map_aup rule on every map
+       that is a parent of bk and for which the rule was not yet applied with j
+    *)
+    let add_array_read_hook env (index_n : Node.t)
+        (map_info_gm : Array_dom.map_info Ground.M.t) =
+      Ground.M.iter
+        (fun gt map_info ->
+          GHT.change
+            (fun ns_opt ->
+              match ns_opt with
+              | Some ns ->
+                  if Node.S.mem index_n ns then Some ns
+                  else (
+                    apply_map_aup env index_n gt map_info;
+                    Some (Node.S.add index_n ns))
+              | None ->
+                  apply_map_aup env index_n gt map_info;
+                  Some (Node.S.singleton index_n))
+            db env gt)
+        map_info_gm
+    in
+    (* Whenever a map function is encountered, apply the map_aup rule on
+       everyone of it's array children on which a read on a value j happens, if
+       the rule has not yet been applied on that j *)
+    let add_array_map_hook env (gt : Ground.t) (reads : Node.S.t) map_info =
+      Node.S.iter
+        (fun index_n ->
+          GHT.change
+            (fun ns_opt ->
+              match ns_opt with
+              | Some ns ->
+                  if Node.S.mem index_n ns then Some ns
+                  else (
+                    apply_map_aup env index_n gt map_info;
+                    Some (Node.S.add index_n ns))
+              | None ->
+                  apply_map_aup env index_n gt map_info;
+                  Some (Node.S.singleton index_n))
+            db env gt)
+        reads
+    in
+    (add_array_read_hook, add_array_map_hook)
+
+  (** [map𝛿: a = map(f, b1, ..., bn)] |> [𝛿(a) = f(𝛿(b1), ..., 𝛿(bn))] *)
+  let map_def map_term f_term bitl =
     let map_pattern = Pattern.of_term_exn ~subst:Ground.Subst.empty map_term in
     let map_run env subst =
       Debug.dprintf2 debug "Found array_map(f,b1, ..., bn) with %a"
         Ground.Subst.pp subst;
-      (* map⇑ *)
-      if Options.get env extended_comb then (
-        Debug.dprintf0 debug "Application of the map_aup rule";
-        Node.S.iter
-          (fun ind ->
-            let term =
-              Expr.Term.eq
-                (mk_select_term map_term tj)
-                (Expr.Term.apply f_term []
-                   (List.map (fun bi -> mk_select_term bi tj) bitl))
-            in
-            let n =
-              convert
-                ~subst:
-                  (Ground.Subst.distinct_union subst
-                     (mk_subst [ (vj, ind) ] []))
-                env term
-            in
-            Egraph.register env n;
-            Boolean.set_true env n)
-          (read_index_nodes Node.S.empty env subst bitl));
       (* map𝛿 *)
-      if Options.get env default_values then (
-        Debug.dprintf0 debug "Application of the map_delta rule";
-        let d_a = apply_cst Builtin.array_default_value [] [ ta ] in
-        let d_bil =
-          List.map
-            (fun bi -> apply_cst Builtin.array_default_value [] [ bi ])
-            bitl
-        in
-        let term = Expr.Term.eq d_a (Expr.Term.apply f_term [] d_bil) in
-        let n = convert ~subst env term in
-        Egraph.register env n;
-        Boolean.set_true env n)
+      Debug.dprintf0 debug "Application of the map_delta rule";
+      let d_bil = List.map (fun bi -> Builtin.apply_array_def_value bi) bitl in
+      let term =
+        Expr.Term.eq
+          (Builtin.apply_array_def_value ta)
+          (Expr.Term.apply f_term [] d_bil)
+      in
+      let n = convert ~subst env term in
+      Egraph.register env n;
+      Boolean.set_true env n
     in
     (map_pattern, map_run)
 
   let new_map =
     let module NM = Datastructure.Memo2 (Popop_stdlib.DInt) in
+    let mk_tlist l n ty =
+      let rec aux l n =
+        if n <= 0 then List.rev l
+        else
+          let v = Expr.Term.Var.mk (Format.sprintf "b%n" n) ty in
+          let t = Expr.Term.of_var v in
+          aux (t :: l) (n - 1)
+      in
+      aux l n
+    in
     let new_map_db =
-      NM.create Fmt.nop "new_map_db" (fun _ env f_arity ->
+      NM.create Fmt.nop "new_map_db" (fun env f_arity ->
           let b_ty = Expr.Ty.array ind_ty alpha_ty in
           let f_ty = Expr.Ty.arrow (replicate f_arity alpha_ty) val_ty in
-          let bitl = mk_vt_list "b" f_arity b_ty in
+          let bitl = mk_tlist [] f_arity b_ty in
           let f_var = Expr.Term.Var.mk "f" f_ty in
           let f_term = Expr.Term.of_var f_var in
-          let map_term =
-            apply_cst (Builtin.array_map f_arity) [] (f_term :: bitl)
-          in
-          let map_pattern, map_run = map_aup map_term f_term bitl in
-          let map_read_pattern, map_read_run = map_adowm map_term f_term bitl in
-          if Options.get env extended_comb then
-            InvertedPath.add_callback env map_read_pattern map_read_run;
-          InvertedPath.add_callback env map_pattern map_run)
+          let map_term = Builtin.apply_array_map f_arity f_term bitl in
+          (if Options.get env extended_comb then
+           let map_adown_pattern, map_adown_run =
+             map_adowm map_term f_term bitl
+           in
+           InvertedPath.add_callback env map_adown_pattern map_adown_run);
+          if Options.get env default_values then
+            let map_def_pattern, map_def_run = map_def map_term f_term bitl in
+            InvertedPath.add_callback env map_def_pattern map_def_run)
     in
     fun env mapf_t ->
       let mapf_s = Ground.sem mapf_t in
@@ -652,7 +708,7 @@ let converter env (f : Ground.t) =
               match gty with
               | { app = { builtin = Expr.Array; _ }; _ } ->
                   Ground.add_ty env n gty;
-                  Foreign_dom.set_dom_apply_hooks env gty n IsForeign
+                  Foreign_dom.set_dom env gty n IsForeign
               | _ -> ())
             args
   | {
@@ -668,22 +724,27 @@ let converter env (f : Ground.t) =
       if Options.get env restrict_ext && ind_gty.app.builtin == Expr.Array then (
         let gty = Ground.Ty.array ind_gty val_gty in
         Ground.add_ty env i gty;
-        Foreign_dom.set_dom_apply_hooks env gty i IsForeign);
-      (* application of the `not default` rule 𝝐≠: v = a[i], i ≠ 𝝐 |> i ≠ 𝝐 *)
-      (* should only be applied if the index sort is infinite,
-         otherwise the blast rule should be applied *)
+        Foreign_dom.set_dom env gty i IsForeign);
       if Options.get env extended_comb then (
+        (* when a new read is encountered, check if map⇑ can be applied *)
+        Array_dom.add_read ~hook:Theory.add_array_read_hook env a i;
+        (* 𝝐≠: v = a[i], i ≠ 𝝐 |> i ≠ 𝝐 *)
+        (* should only be applied if the index sort is infinite,
+           otherwise the blast rule should be applied *)
         let subst =
           mk_subst
             [ (Theory.va, a); (Theory.vi, i) ]
             [ (ind_ty_var, ind_gty); (val_ty_var, val_gty) ]
         in
-        let eps_term = apply_array_def_index Theory.ta in
+        let eps_term = Builtin.apply_array_def_index Theory.ta in
         let eps_node = convert ~subst env eps_term in
         Egraph.register env eps_node;
         if not (Node.equal i eps_node) then (
           (* is this the right equality test? *)
-          Debug.incr array_num;
+          NHT.change
+            (function Some v -> Some (v + 1) | None -> Some 1)
+            array_num_db env eps_node;
+          (* check if num(𝜎) >= size(𝜎) to determine which rule to apply *)
           let i_eps_neq_node =
             convert ~subst env (Expr.Term.neq eps_term Theory.ti)
           in
@@ -719,8 +780,8 @@ let converter env (f : Ground.t) =
       if Options.get env default_values then (
         let eq_term =
           Expr.Term.eq
-            (apply_array_def_value store_term)
-            (apply_array_def_value Theory.ta)
+            (Builtin.apply_array_def_value store_term)
+            (Builtin.apply_array_def_value Theory.ta)
         in
         let eq_node = convert ~subst env eq_term in
         Egraph.register env eq_node;
@@ -743,13 +804,25 @@ let converter env (f : Ground.t) =
         let eq_node =
           convert ~subst env
             (Expr.Term.eq
-               (apply_array_def_value (apply_array_const Theory.tv))
+               (Builtin.apply_array_def_value
+                  (Builtin.apply_array_const Theory.tv))
                Theory.tv)
         in
         Egraph.register env eq_node;
         Boolean.set_true env eq_node)
-  | { app = { builtin = Builtin.Array_map; _ }; _ }
+  | {
+   app = { builtin = Builtin.Array_map; _ };
+   args;
+   tyargs = [ bi_ind_ty; bi_val_ty; a_val_ty ];
+   _;
+  }
     when Options.get env extended_comb || Options.get env default_values ->
+      (if Options.get env extended_comb then
+       let f_arity = IArray.length args - 1 in
+       IArray.iteri args ~f:(fun i n ->
+           if i > 0 then
+             Array_dom.add_map_parent ~hook:Theory.add_array_map_hook env n f
+               { bi_ind_ty; bi_val_ty; a_val_ty; f_arity }));
       Theory.new_map env f
   | _ -> ()
 
diff --git a/colibri2/theories/array/array_dom.ml b/colibri2/theories/array/array_dom.ml
new file mode 100644
index 000000000..daf96b163
--- /dev/null
+++ b/colibri2/theories/array/array_dom.ml
@@ -0,0 +1,104 @@
+(*************************************************************************)
+(*  This file is part of Colibri2.                                       *)
+(*                                                                       *)
+(*  Copyright (C) 2014-2021                                              *)
+(*    CEA   (Commissariat à l'énergie atomique et aux énergies           *)
+(*           alternatives)                                               *)
+(*    OCamlPro                                                           *)
+(*                                                                       *)
+(*  you can redistribute it and/or modify it under the terms of the GNU  *)
+(*  Lesser General Public License as published by the Free Software      *)
+(*  Foundation, version 2.1.                                             *)
+(*                                                                       *)
+(*  It is distributed in the hope that it will be useful,                *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of       *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the        *)
+(*  GNU Lesser General Public License for more details.                  *)
+(*                                                                       *)
+(*  See the GNU Lesser General Public License version 2.1                *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).           *)
+(*************************************************************************)
+
+open Colibri2_core
+open Colibri2_popop_lib
+open Popop_stdlib
+
+type map_info = {
+  bi_ind_ty : Ground.Ty.t;
+  bi_val_ty : Ground.Ty.t;
+  a_val_ty : Ground.Ty.t;
+  f_arity : DInt.t;
+}
+[@@deriving eq, ord, hash, show]
+
+module AVal = struct
+  module T = struct
+    type t = {
+      reads : Node.S.t; (* read indexes *)
+      map_parents : map_info Ground.M.t; (* parents that are maps *)
+    }
+    [@@deriving eq, ord, hash, show]
+  end
+
+  include T
+  include MkDatatype (T)
+
+  let name = "Array.dom.value"
+end
+
+module D = struct
+  module Value = Value.Register (AVal)
+
+  type t = AVal.t = { reads : Node.S.t; map_parents : map_info Ground.M.t }
+  [@@deriving eq, show]
+
+  let is_singleton _ _ = None
+
+  let key =
+    Dom.Kind.create
+      (module struct
+        type nonrec t = t
+
+        let name = "Array.dom"
+      end)
+
+  let inter _ { reads = r1; map_parents = mp1 }
+      { reads = r2; map_parents = mp2 } =
+    Some
+      { reads = Node.S.union r1 r2; map_parents = Ground.M.set_union mp1 mp2 }
+end
+
+include D
+include Dom.Lattice (D)
+
+let add_read
+    ?(hook : Egraph.wt -> Node.t -> map_info Ground.M.t -> unit =
+      fun _ _ _ -> ()) env n r =
+  match Egraph.get_dom env key n with
+  | Some { reads; map_parents } ->
+      hook env n map_parents;
+      set_dom env n { reads = Node.S.add r reads; map_parents }
+  | None ->
+      set_dom env n { reads = Node.S.singleton r; map_parents = Ground.M.empty }
+
+let add_map_parent
+    ?(hook : Egraph.wt -> Ground.t -> Node.S.t -> map_info -> unit =
+      fun _ _ _ _ -> ()) env n gt { bi_ind_ty; bi_val_ty; a_val_ty; f_arity } =
+  match Egraph.get_dom env key n with
+  | Some { reads; map_parents } ->
+      hook env gt reads { bi_ind_ty; bi_val_ty; a_val_ty; f_arity };
+      set_dom env n
+        {
+          reads;
+          map_parents =
+            Ground.M.add gt
+              { bi_ind_ty; bi_val_ty; a_val_ty; f_arity }
+              map_parents;
+        }
+  | None ->
+      set_dom env n
+        {
+          reads = Node.S.empty;
+          map_parents =
+            Ground.M.singleton gt { bi_ind_ty; bi_val_ty; a_val_ty; f_arity };
+        }
diff --git a/colibri2/theories/array/array_dom.mli b/colibri2/theories/array/array_dom.mli
new file mode 100644
index 000000000..70a311e66
--- /dev/null
+++ b/colibri2/theories/array/array_dom.mli
@@ -0,0 +1,53 @@
+(*************************************************************************)
+(*  This file is part of Colibri2.                                       *)
+(*                                                                       *)
+(*  Copyright (C) 2014-2021                                              *)
+(*    CEA   (Commissariat à l'énergie atomique et aux énergies           *)
+(*           alternatives)                                               *)
+(*    OCamlPro                                                           *)
+(*                                                                       *)
+(*  you can redistribute it and/or modify it under the terms of the GNU  *)
+(*  Lesser General Public License as published by the Free Software      *)
+(*  Foundation, version 2.1.                                             *)
+(*                                                                       *)
+(*  It is distributed in the hope that it will be useful,                *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of       *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the        *)
+(*  GNU Lesser General Public License for more details.                  *)
+(*                                                                       *)
+(*  See the GNU Lesser General Public License version 2.1                *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).           *)
+(*************************************************************************)
+
+type map_info = {
+  bi_ind_ty : Ground.Ty.t;
+  bi_val_ty : Ground.Ty.t;
+  a_val_ty : Ground.Ty.t;
+  f_arity : int;
+}
+
+type t = { reads : Node.S.t; map_parents : map_info Ground.M.t }
+
+val equal : t -> t -> bool
+val pp : Format.formatter -> t -> unit
+val show : t -> string
+val is_singleton : 'a -> 'b -> 'c option
+val key : t Dom.Kind.t
+val inter : 'a -> t -> t -> t option
+val set_dom : Egraph.wt -> Node.t -> t -> unit
+val upd_dom : Egraph.wt -> Node.t -> t -> unit
+
+val add_read :
+  ?hook:(Egraph.wt -> Node.t -> map_info Ground.M.t -> unit) ->
+  Egraph.wt ->
+  Node.t ->
+  Node.t ->
+  unit
+
+val add_map_parent :
+  ?hook:(Egraph.wt -> Ground.t -> Node.S.t -> map_info -> unit) ->
+  Egraph.wt ->
+  Node.t ->
+  Ground.t ->
+  map_info ->
+  unit
diff --git a/colibri2/theories/array/foreign_dom.ml b/colibri2/theories/array/foreign_dom.ml
index f96d0397b..777dce3c6 100644
--- a/colibri2/theories/array/foreign_dom.ml
+++ b/colibri2/theories/array/foreign_dom.ml
@@ -27,7 +27,7 @@ module FVal = struct
   module T = struct
     type t = IsForeign [@@deriving eq, ord, hash]
 
-    let pp fmt = function IsForeign -> Fmt.pf fmt "IsForeign"
+    let pp fmt IsForeign = Fmt.pf fmt "IsForeign"
   end
 
   include T
@@ -62,13 +62,13 @@ let new_foreign_array_hooks :
     (Ground.Ty.t -> Node.t -> unit) Datastructure.Push.t =
   Datastructure.Push.create Fmt.nop "Array.new_foreign_array"
 
-let register_hook_new_foreign_array d (f : Ground.Ty.t -> Node.t -> unit) =
-  Datastructure.Push.push new_foreign_array_hooks d f
+let register_hook_new_foreign_array env (f : Ground.Ty.t -> Node.t -> unit) =
+  Datastructure.Push.push new_foreign_array_hooks env f
 
-let set_dom_apply_hooks env ty n d =
+let set_dom env ty n d =
   set_dom env n d;
   Datastructure.Push.iter ~f:(fun f -> f ty n) new_foreign_array_hooks env
 
-let upd_dom_apply_hooks env ty n d =
+let upd_dom env ty n d =
   upd_dom env n d;
   Datastructure.Push.iter ~f:(fun f -> f ty n) new_foreign_array_hooks env
diff --git a/colibri2/theories/array/foreign_dom.mli b/colibri2/theories/array/foreign_dom.mli
index da6893892..3a33ce41c 100644
--- a/colibri2/theories/array/foreign_dom.mli
+++ b/colibri2/theories/array/foreign_dom.mli
@@ -26,11 +26,9 @@ val is_singleton : Egraph.wt -> t -> Colibri2_core.Value.t option
 val key : t Dom.Kind.t
 val inter : Egraph.wt -> t -> t -> t option
 val pp : t Fmt.t
-val set_dom : Egraph.wt -> Node.t -> t -> unit
-val upd_dom : Egraph.wt -> Node.t -> t -> unit
 
 val register_hook_new_foreign_array :
   Egraph.wt -> (Ground.Ty.t -> Node.t -> unit) -> unit
 
-val set_dom_apply_hooks : Egraph.wt -> Ground.Ty.t -> Node.t -> t -> unit
-val upd_dom_apply_hooks : Egraph.wt -> Ground.Ty.t -> Node.t -> t -> unit
+val set_dom : Egraph.wt -> Ground.Ty.t -> Node.t -> t -> unit
+val upd_dom : Egraph.wt -> Ground.Ty.t -> Node.t -> t -> unit
-- 
GitLab