Skip to content
Snippets Groups Projects
RWRules.ml 28.64 KiB
(*************************************************************************)
(*  This file is part of Colibri2.                                       *)
(*                                                                       *)
(*  Copyright (C) 2014-2021                                              *)
(*    CEA   (Commissariat à l'énergie atomique et aux énergies           *)
(*           alternatives)                                               *)
(*    OCamlPro                                                           *)
(*                                                                       *)
(*  you can redistribute it and/or modify it under the terms of the GNU  *)
(*  Lesser General Public License as published by the Free Software      *)
(*  Foundation, version 2.1.                                             *)
(*                                                                       *)
(*  It is distributed in the hope that it will be useful,                *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of       *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the        *)
(*  GNU Lesser General Public License for more details.                  *)
(*                                                                       *)
(*  See the GNU Lesser General Public License version 2.1                *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).           *)
(*************************************************************************)

open Colibri2_core
open Colibri2_popop_lib
open Popop_stdlib
open Common
open Colibri2_theories_quantifiers
(* Hash tables accessed through the environment: [GHT] is keyed by ground
   terms, [GTHT] by ground types. *)
module GHT = Datastructure.Hashtbl (Ground)
module GTHT = Datastructure.Hashtbl (Ground.Ty)

(** Table of first-kind disjunctions already handled, keyed by
    [(a, b, i, j, ind_ty, val_ty)]:
    (a,b,i,j) -> (i = j) \/ (a[j] = b[j]).
    [sort] canonicalises a key by putting both the array pair [(a, b)] and the
    index pair [(i, j)] in ascending order. *)
module D1db =
  SHT
    (struct
      include I4

      let sort (a, b, i, j, ty1, ty2) =
        let ascending x y = if x <= y then (x, y) else (y, x) in
        let lo_arr, hi_arr = ascending a b in
        let lo_ind, hi_ind = ascending i j in
        (lo_arr, hi_arr, lo_ind, hi_ind, ty1, ty2)
    end)
    (struct
      type t = unit

      let name = "D1db"
      let pp = Pp.unit
    end)

(** Table of array pairs for which an extensionality constraint was already
    added, keyed by [(a, b, ind_ty, val_ty)]:
    (a,b) -> (a = b) ⋁ (a[k] ≠ b[k]).
    [sort] canonicalises a key by ordering the two array ids. *)
module D2db =
  SHT
    (struct
      include I2

      let sort (a, b, ty1, ty2) =
        if a <= b then (a, b, ty1, ty2) else (b, a, ty1, ty2)
    end)
    (struct
      type t = unit

      let name = "D2db"
      let pp = Pp.unit
    end)

(** Reverse index from an array or index id to the [D1db] keys that mention
    it. For a key (a,b,i,j):
    a -> (true,b,i,j) | b -> (false,a,i,j) |
    i -> (true,a,b,j) | j -> (false,a,b,i)
    The boolean tells whether the id is the first of its pair. *)
module D1ids = struct
  include MkIHT (struct
    include I3.S

    let name = "D1Ids"
  end)

  (* Records the four reverse entries of the key
     [(id1, id2, id3, id4, ty1, ty2)], in the order id1, id2, id3, id4. *)
  let add env (id1, id2, id3, id4, ty1, ty2) =
    let push id entry =
      change
        ~f:(fun prev ->
          Some
            (match prev with
            | None -> I3.S.singleton entry
            | Some s -> I3.S.add entry s))
        env id
    in
    List.iter
      (fun (id, entry) -> push id entry)
      [
        (id1, (true, id2, id3, id4, ty1, ty2));
        (id2, (false, id1, id3, id4, ty1, ty2));
        (id3, (true, id1, id2, id4, ty1, ty2));
        (id4, (false, id1, id2, id3, ty1, ty2));
      ]
end

(** Reverse index from an array id to the [D2db] keys that mention it.
    For a key (a,b): a -> (true,b) | b -> (false,a) *)
module D2ids = struct
  include MkIHT (struct
    include I1.S

    let name = "D2Ids"
  end)

  (* Records both reverse entries of the key [(id1, id2, ty1, ty2)]. *)
  let add env (id1, id2, ty1, ty2) =
    let push id entry =
      change
        ~f:(fun prev ->
          Some
            (match prev with
            | None -> I1.S.singleton entry
            | Some s -> I1.S.add entry s))
        env id
    in
    push id1 (true, id2, ty1, ty2);
    push id2 (false, id1, ty1, ty2)
end

(** Extracts [(a, b, i, j, ind_gty, val_gty)] from a match substitution.
    [a] is the node bound to [va] when present; otherwise it is built as
    [store b i v] from the node bound to [vv].
    @raise Not_found if one of the other required variables is unbound. *)
let get_disj1_nodes env (subst : Ground.Subst.t) =
  let term_of v = Expr.Term.Var.M.find v subst.term in
  let ty_of v = Expr.Ty.Var.M.find v subst.ty in
  let b_n = term_of STV.vb in
  let i_n = term_of STV.vi in
  let j_n = term_of STV.vj in
  let ind_gty = ty_of STV.ind_ty_var in
  let val_gty = ty_of STV.val_ty_var in
  let a_n =
    match term_of STV.va with
    | bound -> bound
    | exception Not_found ->
        let v_n = term_of STV.vv in
        ground_apply env Expr.Term.Const.Array.store [ ind_gty; val_gty ]
          [ b_n; i_n; v_n ]
  in
  (a_n, b_n, i_n, j_n, ind_gty, val_gty)

(** Asserts, under [subst], the disjunction
    [(i = j) \/ (b[i <- v][j] = b[j])], unless the key built from the ids of
    [(a_n, b_n, i_n, j_n)] and the two ground types is already in [D1db].
    If one of the nodes has no id yet ([NoIdFound]), the disjunction is
    asserted as well. Afterwards [a_n] is given an array id and the key is
    recorded in both [D1db] and [D1ids]. *)
let new_disj1_aux env subst a_n b_n i_n j_n ind_gty val_gty =
  (* Ids are looked up in the fixed order a, b, i, j; may raise [NoIdFound]. *)
  let current_key () =
    let a_id = Id.Array.get_id env a_n in
    let b_id = Id.Array.get_id env b_n in
    let i_id = Id.Index.get_id env i_n in
    let j_id = Id.Index.get_id env j_n in
    (a_id, b_id, i_id, j_id, ind_gty, val_gty)
  in
  match D1db.find_opt env (current_key ()) with
  | Some () -> ()
  | None | (exception NoIdFound _) ->
      let disj =
        convert ~subst env
          (Expr.Term._or
             [
               Expr.Term.eq STV.ti STV.tj;
               Expr.Term.eq
                 (mk_select_term (mk_store_term STV.tb STV.ti STV.tv) STV.tj)
                 (mk_select_term STV.tb STV.tj);
             ])
      in
      Egraph.register env disj;
      Boolean.set_true env disj;
      Id.Array.set_id env a_n;
      let key = current_key () in
      D1db.set env key ();
      D1ids.add env key

(** Extracts the nodes of a match from [subst] and asserts the corresponding
    (i = j) \/ (a[j] = b[j]) disjunction if it is new. *)
let new_disj1 env subst =
  match get_disj1_nodes env subst with
  | a_n, b_n, i_n, j_n, ind_gty, val_gty ->
      new_disj1_aux env subst a_n b_n i_n j_n ind_gty val_gty

(** Like [new_disj1], but only instantiates the disjunction when the array
    [b] carries the [NonLinear] value of [Linearity_dom]. Otherwise it only
    logs a debug message. *)
let new_disj1_raup2 env subst =
  let a_n, b_n, i_n, j_n, ind_gty, val_gty = get_disj1_nodes env subst in
  let b_is_non_linear =
    match Egraph.get_dom env Linearity_dom.key b_n with
    | Some NonLinear -> true
    | _ -> false
  in
  if b_is_non_linear then (
    Debug.dprintf2 debug "Apply raup2 with %a" Ground.Subst.pp subst;
    new_disj1_aux env subst a_n b_n i_n j_n ind_gty val_gty)
  else
    Debug.dprintf2 debug "Do not apply raup2: %a is not non-linear" Node.pp
      b_n

(** Unless the pair is already in [D2db], asserts the disequality built by
    [mk_distinct_arrays] for [a_n] and [b_n] and records the pair in [D2db]
    and [D2ids]. *)
let new_dist_arrays env (a_n, a_id) (b_n, b_id) ind_gty val_gty =
  let key = (a_id, b_id, ind_gty, val_gty) in
  let already_done =
    match D2db.find_opt env key with Some () -> true | None -> false
  in
  if not already_done then (
    let diseq = mk_distinct_arrays env a_n b_n ind_gty val_gty in
    Egraph.register env diseq;
    Boolean.set_true env diseq;
    D2db.set env key ();
    D2ids.add env key)

(** Extensionality between arrays [a] and [b]. Unless the pair is already in
    [D2db], it registers [a = b] and the disequality from
    [mk_distinct_arrays]. It then adds a global decision that, while neither
    is known true, tries [eq := true, diseq := false] and then the converse.
    Finally the pair is recorded in [D2db] and [D2ids]. *)
let new_disj2 env (a, a_id) (b, b_id) ind_gty val_gty =
  let key = (a_id, b_id, ind_gty, val_gty) in
  match D2db.find_opt env key with
  | Some () -> ()
  | None ->
      let eq = Equality.equality env [ a; b ] in
      let diseq = mk_distinct_arrays env a b ind_gty val_gty in
      Debug.dprintf4 debug "Application of the extensionality rule on %a and %a"
        Node.pp a Node.pp b;
      Egraph.register env eq;
      Egraph.register env diseq;
      let choose_equal env =
        Boolean.set_true env eq;
        Boolean.set_false env diseq
      in
      let choose_distinct env =
        Boolean.set_false env eq;
        Boolean.set_true env diseq
      in
      Choice.register_global env
        {
          print_cho = "Decision from extensionality application.";
          prio = 1;
          choice =
            (fun env ->
              match (Boolean.is env eq, Boolean.is env diseq) with
              | Some true, _ | _, Some true -> DecNo
              | _ -> DecTodo [ choose_equal; choose_distinct ]);
        };
      D2db.set env key ();
      D2ids.add env key

(** Renames the array id [rid] to [kid] in the disjunction bookkeeping.
    Every [D1db]/[D2db] key and every [D1ids]/[D2ids] reverse entry that
    mentions [rid] is rewritten to mention [kid]. [rid]'s reverse entries are
    moved onto [kid] and then removed.
    NOTE(review): presumably called when the class of [rid] is merged into the
    class of [kid] — confirm against the registration site. The first
    component of each pair and the last argument are ignored here. *)
let eq_arrays_norm env (_, kid) (_, rid) _ =
  (* D1 part: each entry of [rid]'s reverse set is
     (b, oid, iid, jid, ty1, ty2), where [b] tells whether [rid] is the first
     array of the key and [oid] is the other array. *)
  (match D1ids.find_opt env rid with
  | None -> ()
  | Some s ->
      let ns =
        I3.S.fold
          (fun (b, oid, iid, jid, ty1, ty2) s ->
            (* (ofst, osnd) is the old array pair of the key and (nfst, nsnd)
               the new one. When [oid = rid] both arrays were [rid], so the
               reverse entry itself names [rid] and must be renamed in [ns].
               Otherwise the entry only names [oid] and stays valid for
               [kid]. *)
            let ns, (ofst, osnd, nfst, nsnd) =
              if oid = rid then
                ( I3.S.add
                    (b, kid, iid, jid, ty1, ty2)
                    (I3.S.remove (b, oid, iid, jid, ty1, ty2) s),
                  if b then (rid, oid, kid, kid) else (oid, rid, kid, kid) )
              else (s, if b then (rid, oid, kid, oid) else (oid, rid, oid, kid))
            in
            D1db.remove env (ofst, osnd, iid, jid, ty1, ty2);
            D1db.set env (nfst, nsnd, iid, jid, ty1, ty2) ();
            (* Update the reverse entries held by the two indices and by the
               other array; they must exist because [D1ids.add] registered
               all four ids of the key. *)
            D1ids.change env iid ~f:(function
              | None -> assert false
              | Some s ->
                  Some
                    (I3.S.add
                       (true, nfst, nsnd, jid, ty1, ty2)
                       (I3.S.remove (true, ofst, osnd, jid, ty1, ty2) s)));
            D1ids.change env jid ~f:(function
              | None -> assert false
              | Some s ->
                  Some
                    (I3.S.add
                       (false, nfst, nsnd, iid, ty1, ty2)
                       (I3.S.remove (false, ofst, osnd, iid, ty1, ty2) s)));
            D1ids.change env oid ~f:(function
              | None -> assert false
              | Some s ->
                  Some
                    (I3.S.add
                       (not b, kid, iid, jid, ty1, ty2)
                       (I3.S.remove (not b, rid, iid, jid, ty1, ty2) s)));
            ns)
          s s
      in
      (* Move the (renamed) reverse entries of [rid] onto [kid]. *)
      D1ids.change env kid ~f:(function
        | None -> Some ns
        | Some s' -> Some (I3.S.union ns s')));
  (* D2 part: same renaming for entries (b, oid, ty1, ty2) of pairs of
     arrays. *)
  (match D2ids.find_opt env rid with
  | None -> ()
  | Some s ->
      let ns =
        I1.S.fold
          (fun (b, oid, ty1, ty2) s ->
            let ns, (ofst, osnd, nfst, nsnd) =
              if oid = rid then
                ( I1.S.add (b, kid, ty1, ty2) (I1.S.remove (b, oid, ty1, ty2) s),
                  if b then (rid, oid, kid, kid) else (oid, rid, kid, kid) )
              else (s, if b then (rid, oid, kid, oid) else (oid, rid, oid, kid))
            in
            D2db.remove env (ofst, osnd, ty1, ty2);
            D2db.set env (nfst, nsnd, ty1, ty2) ();
            D2ids.change env oid ~f:(function
              | None -> assert false
              | Some s ->
                  Some
                    (I1.S.add (not b, kid, ty1, ty2)
                       (I1.S.remove (not b, rid, ty1, ty2) s)));
            ns)
          s s
      in
      D2ids.change env kid ~f:(function
        | None -> Some ns
        | Some s' -> Some (I1.S.union ns s')));
  (* [rid] no longer names a class: drop its reverse entries. *)
  D2ids.remove env rid;
  D1ids.remove env rid

(** Renames the index id [rid] to [kid] in the first-kind disjunction
    bookkeeping. Every [D1db] key and every [D1ids] reverse entry that
    mentions [rid] is rewritten to mention [kid]. [rid]'s reverse entries are
    moved onto [kid] and then removed. Indices never appear in
    [D2db]/[D2ids], so those tables are left alone.
    NOTE(review): presumably called when the class of index [rid] is merged
    into that of [kid] — confirm against the registration site. *)
let eq_indices_norm env (_, kid) (_, rid) _ =
  (* Each entry of [rid]'s reverse set is (b, aid, bid, oid, ty1, ty2), where
     [b] tells whether [rid] is the first index of the key and [oid] is the
     other index. *)
  (match D1ids.find_opt env rid with
  | None -> ()
  | Some s ->
      let ns =
        I3.S.fold
          (fun (b, aid, bid, oid, ty1, ty2) s ->
            (* (ofst, osnd) is the old index pair and (nfst, nsnd) the new
               one. When [oid = rid] both indices were [rid], so the entry
               itself must be renamed in [ns]. *)
            let ns, (ofst, osnd, nfst, nsnd) =
              if oid = rid then
                ( I3.S.add
                    (b, aid, bid, kid, ty1, ty2)
                    (I3.S.remove (b, aid, bid, oid, ty1, ty2) s),
                  if b then (rid, oid, kid, kid) else (oid, rid, kid, kid) )
              else (s, if b then (rid, oid, kid, oid) else (oid, rid, oid, kid))
            in
            D1db.remove env (aid, bid, ofst, osnd, ty1, ty2);
            D1db.set env (aid, bid, nfst, nsnd, ty1, ty2) ();
            (* Update the reverse entries held by both arrays and by the
               other index; they exist because [D1ids.add] registered all
               four ids of the key. *)
            D1ids.change env aid ~f:(function
              | None -> assert false
              | Some s ->
                  Some
                    (I3.S.add
                       (true, bid, nfst, nsnd, ty1, ty2)
                       (I3.S.remove (true, bid, ofst, osnd, ty1, ty2) s)));
            D1ids.change env bid ~f:(function
              | None -> assert false
              | Some s ->
                  Some
                    (I3.S.add
                       (false, aid, nfst, nsnd, ty1, ty2)
                       (I3.S.remove (false, aid, ofst, osnd, ty1, ty2) s)));
            D1ids.change env oid ~f:(function
              | None -> assert false
              | Some s ->
                  Some
                    (I3.S.add
                       (not b, aid, bid, kid, ty1, ty2)
                       (I3.S.remove (not b, aid, bid, rid, ty1, ty2) s)));
            ns)
          s s
      in
      (* Move the (renamed) reverse entries of [rid] onto [kid]. *)
      D1ids.change env kid ~f:(function
        | None -> Some ns
        | Some s' -> Some (I3.S.union ns s')));
  D1ids.remove env rid

(* Cardinality information for a ground type, stored per type by
   [check_gty_num_size]. [Inf] is used for types treated as unbounded.
   [Finite { num; size }] carries the type's size (2 for [Prop]) in [size],
   and in [num] a counter that [check_gty_num_size] increments until it
   reaches [size]. *)
type size = Inf | Finite of { num : int; size : int } [@@deriving show]

(** [check_gty_num_size env gty] maintains a per-type counter and returns
    whether it has reached the size of [gty].
    - The first call for a type only records its size info and returns
      [false].
    - Infinite types always yield [false].
    - For a finite type, a call below the bound increments the counter and
      returns [false]; once the counter equals the size, [true] is returned. *)
let check_gty_num_size =
  (* TODO: sizes for all types.
     Will be more useful if there are more bounded types. *)
  let initial_size (gty : Ground.Ty.t) =
    match gty with
    | { app = { builtin = Expr.Prop; _ }; _ } -> Finite { num = 1; size = 2 }
    | _ -> Inf
  in
  let gty_ns_db = GTHT.create pp_size "array_size_db" in
  fun (env : Egraph.rw Egraph.t) gty ->
    match GTHT.find_opt gty_ns_db env gty with
    | None ->
        GTHT.set gty_ns_db env gty (initial_size gty);
        false
    | Some Inf -> false
    | Some (Finite { num; size }) ->
        let saturated = num >= size in
        if not saturated then
          GTHT.set gty_ns_db env gty (Finite { num = num + 1; size });
        saturated

(* ⇓: a ≡ b[i <- v], a[j] |> (i = j) \/ a[j] = b[j]
   The pattern matches reads b[i <- v][j]; the subst binds (b, i, v, j). *)
let adown_pattern, adown_run =
  let read_of_store =
    mk_select_term (mk_store_term STV.tb STV.ti STV.tv) STV.tj
  in
  let pattern =
    Pattern.of_term_exn ~subst:Ground.Subst.empty read_of_store
  in
  let run env subst =
    Debug.dprintf2 debug "Found adown with %a" Ground.Subst.pp subst;
    new_disj1 env subst
  in
  (pattern, run)

(* ⇑: a ≡ b[i <- v], b[j]  |> (i = j) \/ a[j] = b[j]
   Two stages: a match on the store b[i <- v] registers it, then installs a
   callback that waits for a read b[j] on the same b. *)
let aup_pattern, aup_run =
  let store = mk_store_term STV.tb STV.ti STV.tv in
  let pattern = Pattern.of_term_exn ~subst:Ground.Subst.empty store in
  let run env subst =
    let store_n = convert ~subst env store in
    Egraph.register env store_n;
    Debug.dprintf2 debug "Found aup1 with %a" Ground.Subst.pp subst;
    let read_pattern =
      Pattern.of_term_exn ~subst (mk_select_term STV.tb STV.tj)
    in
    let on_read env subst_bis =
      (* (a,b,i,j) *)
      let full_subst = Ground.Subst.distinct_union subst_bis subst in
      Debug.dprintf2 debug "Found aup2 with %a" Ground.Subst.pp full_subst;
      new_disj1 env full_subst
    in
    InvertedPath.add_callback env read_pattern on_read
  in
  (pattern, run)

(* ⇑ᵣ: a ≡ b[i <- v], b[j], b ∈ non-linear |> (i = j) \/ a[j] = b[j]
   Same two-stage matching as ⇑, but the second stage goes through
   [new_disj1_raup2], which checks that b is non-linear. *)
let raup_pattern, raup_run =
  (* (a,b,i,j) *)
  let store = mk_store_term STV.tb STV.ti STV.tv in
  let pattern = Pattern.of_term_exn ~subst:Ground.Subst.empty store in
  let run env subst =
    let store_n = convert ~subst env store in
    Egraph.register env store_n;
    Debug.dprintf2 debug "Found raup1 with %a" Ground.Subst.pp subst;
    let read_pattern =
      Pattern.of_term_exn ~subst (mk_select_term STV.tb STV.tj)
    in
    let on_read env subst_bis =
      let full_subst = Ground.Subst.distinct_union subst_bis subst in
      Debug.dprintf2 debug "Found raup2 with %a" Ground.Subst.pp full_subst;
      new_disj1_raup2 env full_subst
    in
    InvertedPath.add_callback env read_pattern on_read
  in
  (pattern, run)

(* K⇓: a = K(v), a[j] |> a[j] = v
   Matches reads of a constant array and asserts the read equals v. *)
let const_read_pattern, const_read_run =
  let read = mk_select_term (Builtin.apply_array_const STV.tv) STV.tj in
  let pattern = Pattern.of_term_exn ~subst:Ground.Subst.empty read in
  let run env subst =
    Debug.dprintf2 debug "Found const_read with %a" Ground.Subst.pp subst;
    let eq_node = convert ~subst env (Expr.Term.eq read STV.tv) in
    Egraph.register env eq_node;
    Boolean.set_true env eq_node
  in
  (pattern, run)

(** res-ext-1-1: calls [new_dist_arrays] on every pair [(l_i, l_k)] with
    [i < k] of the array nodes [l]. *)
let apply_res_ext_1_1_aux env ind_gty val_gty l =
  Debug.dprintf2 debug "Application of the res-ext-1-1 rule on %a"
    (Fmt.list ~sep:Fmt.comma Node.pp)
    l;
  let rec each_pair = function
    | [] -> ()
    | first :: rest ->
        let first_p = (first, Id.Array.get_id env first) in
        rest
        |> List.iter (fun other ->
               (* (a,b) *)
               let other_p = (other, Id.Array.get_id env other) in
               new_dist_arrays env first_p other_p ind_gty val_gty);
        each_pair rest
  in
  each_pair l

(** res-ext-1-2: for every pair [(l_i, l_k)] with [i < k] of the array nodes
    [l] such that [l_k]'s id is a neighbour of [l_i]'s id in the WEGraph,
    calls [new_dist_arrays]. Nodes without a WEGraph entry are skipped as
    first element of a pair. *)
let apply_res_ext_1_2_aux env ind_gty val_gty l =
  Debug.dprintf2 debug "Application of the res-ext-1-2 rule on %a"
    (Fmt.list ~sep:Fmt.comma Node.pp)
    l;
  (* Same as [aux1] on nodes whose ids were already computed. It must keep
     going after handling [n1]: stopping there would never check the pairs
     among the remaining nodes [t]. *)
  let rec aux2 l =
    match l with
    | [] -> ()
    | (n1, id1) :: t ->
        (match WEGraph.WEG.find_opt env id1 with
        | None -> ()
        | Some (_, m) ->
            List.iter
              (fun (n2, id2) ->
                if DInt.M.mem id2 m then
                  (* (a,b) *)
                  new_dist_arrays env (n1, id1) (n2, id2) ind_gty val_gty)
              t);
        aux2 t
  in
  (* Walks [l] until the first node with a WEGraph entry. It handles that
     node's pairs while computing the ids of the remaining nodes once, then
     hands over to [aux2]. *)
  let rec aux1 l =
    match l with
    | [] -> ()
    | n1 :: t -> (
        let id1 = Id.Array.get_id env n1 in
        match WEGraph.WEG.find_opt env id1 with
        | Some (_, m) ->
            let t' =
              List.rev_map
                (fun n2 ->
                  let id2 = Id.Array.get_id env n2 in
                  if DInt.M.mem id2 m then
                    (* (a,b) *)
                    new_dist_arrays env (n1, id1) (n2, id2) ind_gty val_gty;
                  (n2, id2))
                t
            in
            aux2 t'
        | None -> aux1 t)
  in
  aux1 l

(* (a = b) ≡ false |> (a[k] ≠ b[k]) *)
let apply_res_ext_1_1, apply_res_ext_1_2 =
  (* Calls [f env ind_gty val_gty l] on the elements [l] of [s], where the
     index/value types are read off the first element only. This assumes all
     elements share the same array type. Does nothing when [s] is empty
     (previously [List.hd] raised [Failure "hd"]) or when the first element
     is not an array. *)
  let apply_res_ext_1 f env s =
    match Node.S.elements s with
    | [] -> ()
    | first :: _ as l -> (
        match get_array_gty_args env first with
        | ind_gty, val_gty -> f env ind_gty val_gty l
        | exception Not_An_Array _ -> ())
  in
  let apply_res_ext_1_1 = apply_res_ext_1 apply_res_ext_1_1_aux in
  let apply_res_ext_1_2 = apply_res_ext_1 apply_res_ext_1_2_aux in
  (apply_res_ext_1_1, apply_res_ext_1_2)

(** [get_foreign_neighbours env a arr_set] returns the subset of [arr_set]
    whose array ids are neighbours of [a]'s array id in the WEGraph. It does
    not apply any rule itself; res-ext-2-2 uses it to restrict the foreign
    arrays paired with a new foreign array [a]. *)
let get_foreign_neighbours env a arr_set =
  let neighbours = WEGraph.get_neighbours env (Id.Array.get_id env a) in
  let is_neighbour node = DInt.S.mem (Id.Array.get_id env node) neighbours in
  Node.S.fold
    (fun node kept -> if is_neighbour node then Node.S.add node kept else kept)
    arr_set Node.S.empty

(* a, b, {a,b} ⊆ foreign |> (a = b) ⋁ (a[k] ≠ b[k]) *)
let apply_res_ext_2_1, apply_res_ext_2_2 =
  (* Foreign arrays seen so far, grouped by their array ground type. *)
  let foreign_array_db = GTHT.create Node.S.pp "foreign_array_db" in
  (* Records the foreign array [a] of type [aty]; non-array types are
     ignored. When arrays of the same type were already recorded,
     [f env a fa_set] is called first on that previous set. *)
  let apply_res_ext_2 f env aty a =
    match aty with
    | Ground.Ty.{ app = { builtin = Expr.Array; _ }; _ } ->
        GTHT.change
          (function
            | None -> Some (Node.S.singleton a)
            | Some fa_set ->
                Debug.dprintf2 debug
                  "Found new foreign array (%a) on which to apply \
                   new_foreign_array the hook"
                  Node.pp a;
                f env a fa_set;
                Some (Node.S.add a fa_set))
          foreign_array_db env aty
    | _ -> ()
  in
  (* Calls [new_disj2] between [a] and every array of [others a]. [others]
     is only evaluated after [a]'s types and id have been looked up. *)
  let disj2_with_all env a others =
    let ind_gty, val_gty = array_gty_args (get_array_gty env a) in
    let id1 = Id.Array.get_id env a in
    Node.S.iter
      (fun b ->
        let id2 = Id.Array.get_id env b in
        (* (a, b) *)
        new_disj2 env (a, id1) (b, id2) ind_gty val_gty)
      (others a)
  in
  let apply_res_ext_2_1 env aty a =
    Debug.dprintf2 debug "Application of the res-ext-2-1 rule on %a" Node.pp a;
    apply_res_ext_2
      (fun env a fa_set -> disj2_with_all env a (fun _ -> fa_set))
      env aty a
  in
  let apply_res_ext_2_2 env aty a =
    Debug.dprintf2 debug "Application of the res-ext-2-2 rule on %a" Node.pp a;
    apply_res_ext_2
      (fun env a fa_set ->
        disj2_with_all env a (fun a -> get_foreign_neighbours env a fa_set))
      env aty a
  in
  (apply_res_ext_2_1, apply_res_ext_2_2)

(** Hook for a new array term [f] with index type [ind_gty] and value type
    [val_gty]:
    - when [no_res_ext] is set, registers one global decision per array of
      the same type seen before, splitting on [a = b] versus the disequality
      from [mk_distinct_arrays];
    - records [f] among the known arrays of its type;
    - when [extended_comb] is set, asserts the 𝝐𝛿 rule. *)
let new_array =
  (* Known arrays, grouped by array ground type. *)
  let db_gty = GTHT.create Ground.S.pp "known_array_ht" in
  fun env ind_gty val_gty f ->
    let agty = Ground.Ty.array ind_gty val_gty in
    (* Extensionality rule ext: a, b ⇒ (a = b) ⋁ (a[k] ≠ b[k]) *)
    (if Options.get env no_res_ext then
       match GTHT.find_opt db_gty env agty with
       | None -> ()
       | Some known ->
           let a = Ground.node f in
           known
           |> Ground.S.iter (fun f2 ->
                  let b = Ground.node f2 in
                  Choice.register_global env
                    {
                      print_cho = "Decision from the application of raup.";
                      prio = 1;
                      choice =
                        (fun env ->
                          let abneq =
                            mk_distinct_arrays env a b ind_gty val_gty
                          in
                          let abeq = Equality.equality env [ a; b ] in
                          Egraph.register env abneq;
                          Egraph.register env abeq;
                          match (Boolean.is env abeq, Boolean.is env abneq) with
                          | Some true, _ | _, Some true -> DecNo
                          | _ ->
                              DecTodo
                                [
                                  (fun env ->
                                    Debug.dprintf4 debug
                                      "Apply Ext.1: set %a to true; %a to false"
                                      Node.pp abeq Node.pp abneq;
                                    Boolean.set_true env abeq;
                                    Boolean.set_false env abneq);
                                  (fun env ->
                                    Debug.dprintf4 debug
                                      "Apply Ext.2: set %a to false; %a to true"
                                      Node.pp abeq Node.pp abneq;
                                    Boolean.set_false env abeq;
                                    Boolean.set_true env abneq);
                                ]);
                    }));
    GTHT.change
      (fun known ->
        Some
          (match known with
          | Some s -> Ground.S.add f s
          | None -> Ground.S.singleton f))
      db_gty env agty;
    (* 𝝐𝛿: a |> a[𝝐] = 𝛿 *)
    if Options.get env extended_comb then (
      Debug.dprintf0 debug "Application of the epsilon_delta rule";
      let a = Ground.node f in
      let gtys = [ ind_gty; val_gty ] in
      let def_ind = ground_apply env Builtin.array_default_index gtys [ a ] in
      let def_val = ground_apply env Builtin.array_default_value gtys [ a ] in
      let read = mk_select env a def_ind ind_gty val_gty in
      let eq = Equality.equality env [ read; def_val ] in
      Egraph.register env eq;
      Boolean.set_true env eq)

(* map⇓: a = map(f, b1, ..., bn), a[j] |> a[j] = f(b1[j], ..., bn[j])
   Returns the pattern matching reads of [map_term] together with the
   callback asserting the equation under the match substitution. *)
let map_adowm map_term f_term bitl =
  let read_at_j arr = mk_select_term arr STV.tj in
  let map_read_pattern =
    Pattern.of_term_exn ~subst:Ground.Subst.empty (read_at_j map_term)
  in
  let map_read_run env subst =
    Debug.dprintf2 debug "Found array_map(f,b1, ..., bn)[j] with %a"
      Ground.Subst.pp subst;
    let rhs = Expr.Term.apply f_term [] (List.map read_at_j bitl) in
    let lhs = read_at_j map_term in
    let n = convert ~subst env (Expr.Term.eq lhs rhs) in
    Egraph.register env n;
    Boolean.set_true env n
  in
  (map_read_pattern, map_read_run)

(** map⇑ instance. Given the ground map term [gt], whose arguments are
    [f; b1; ...; bn], and an index node [index_n] (j), asserts
    [a[j] = f(b1[j], ..., bn[j])], where [a] is the node of [gt]. The map_info
    record supplies the ground types that instantiate the index, argument
    value and result value type variables. *)
let apply_map_aup env index_n gt
    Array_dom.{ bi_ind_ty; bi_val_ty; a_val_ty; f_arity } =
  let argl = IArray.to_list (Ground.sem gt).args in
  let f_node = List.hd argl in
  let arg_nodes = List.tl argl in
  let a_node = Ground.node gt in
  let fvar =
    Expr.Term.Var.mk "f"
      (Expr.Ty.arrow (replicate f_arity STV.alpha_ty) STV.val_ty)
  in
  let fterm = Expr.Term.of_var fvar in
  let ty_subst =
    [
      (STV.ind_ty_var, bi_ind_ty);
      (STV.alpha_ty_var, bi_val_ty);
      (STV.val_ty_var, a_val_ty);
    ]
  in
  let t_subst = [ (STV.vj, index_n); (fvar, f_node); (STV.va, a_node) ] in
  (* One fresh variable b<k> per array argument, bound to that argument's
     node. The arguments are folded from the last one down (assuming
     [f_arity] equals their number), so [bij_list] ends up as
     [b1[j]; ...; bn[j]] in argument order. *)
  let _, bij_list, t_subst =
    List.fold_left
      (fun (n, t_acc, s_acc) node ->
        (* %d, not %n: Printf's n/l/L/N conversions are reserved for Scanf. *)
        let biv =
          Expr.Term.Var.mk (Format.sprintf "b%d" n) STV.array_ty_alpha
        in
        let bit = Expr.Term.of_var biv in
        (n - 1, Expr.Term.Array.select bit STV.tj :: t_acc, (biv, node) :: s_acc))
      (f_arity, [], t_subst) (List.rev arg_nodes)
  in
  let n =
    convert
      ~subst:(mk_subst t_subst ty_subst)
      env
      (Expr.Term.eq
         (Expr.Term.Array.select STV.ta STV.tj)
         (Expr.Term.apply fterm [] bij_list))
  in
  Egraph.register env n;
  Boolean.set_true env n

(**  map⇑: [a = map(f, b1, ..., bn), bk[j]] |> [a[j] = f(b1[j], ..., bn[j])] *)
let add_array_read_hook, add_array_map_hook =
  (* For each map term, the index nodes on which map⇑ was already applied. *)
  let db = GHT.create Node.S.pp "array_map_read_on_arg" in
  (* Applies map⇑ to the map term [gt] and index [index_n] unless [db]
     shows it was already done, and records the index in [db]. *)
  let apply_map_aup_once env index_n gt map_info =
    GHT.change
      (function
        | Some ns when Node.S.mem index_n ns -> Some ns
        | Some ns ->
            apply_map_aup env index_n gt map_info;
            Some (Node.S.add index_n ns)
        | None ->
            apply_map_aup env index_n gt map_info;
            Some (Node.S.singleton index_n))
      db env gt
  in
  (* Whenever a bk[j] is encountered, apply the map_aup rule on every map
     that is a parent of bk and for which the rule was not yet applied with j
  *)
  let add_array_read_hook env (index_n : Node.t)
      (map_info_gm : Array_dom.map_info Ground.M.t) =
    Ground.M.iter
      (fun gt map_info -> apply_map_aup_once env index_n gt map_info)
      map_info_gm
  in
  (* Whenever a map function is encountered, apply the map_aup rule on every
     one of its array children on which a read on a value j happens, if the
     rule has not yet been applied on that j *)
  let add_array_map_hook env (gt : Ground.t) (reads : Node.S.t) map_info =
    Node.S.iter
      (fun index_n -> apply_map_aup_once env index_n gt map_info)
      reads
  in
  (add_array_read_hook, add_array_map_hook)

(** [map𝛿: a = map(f, b1, ..., bn)] |> [𝛿(a) = f(𝛿(b1), ..., 𝛿(bn))]
    Returns the pattern matching [map_term] and the callback asserting the
    default-value equation under the match substitution. *)
let map_def map_term f_term bitl =
  let map_pattern = Pattern.of_term_exn ~subst:Ground.Subst.empty map_term in
  let map_run env subst =
    Debug.dprintf2 debug "Found array_map(f,b1, ..., bn) with %a"
      Ground.Subst.pp subst;
    (* map𝛿 *)
    Debug.dprintf0 debug "Application of the map_delta rule";
    let arg_defaults =
      List.map (fun bi -> Builtin.apply_array_def_value bi) bitl
    in
    let rhs = Expr.Term.apply f_term [] arg_defaults in
    let lhs = Builtin.apply_array_def_value STV.ta in
    let n = convert ~subst env (Expr.Term.eq lhs rhs) in
    Egraph.register env n;
    Boolean.set_true env n
  in
  (map_pattern, map_run)

(** Hook for a ground [array_map] term [mapf_t] with arguments
    [f; b1; ...; bn]. Through the [NM] memo table (one entry per arity n), it
    installs the map⇓ callback when [extended_comb] is set and the map𝛿
    callback when [default_values] is set, for a generic map term of that
    arity. *)
let new_map =
  let module NM = Datastructure.Memo2 (DInt) in
  (* Builds [n] variable terms of type [ty] named "b1" ... "bn" (appended to
     [l], reversed). With [l = []] the result is ordered [bn; ...; b1]. The
     same list is used for the map pattern and for the rule right-hand sides,
     so only the naming is reversed, not the semantics.
     %d rather than %n: Printf's %n is reserved for Scanf. *)
  let mk_tlist l n ty =
    let rec aux l n =
      if n <= 0 then List.rev l
      else
        let v = Expr.Term.Var.mk (Format.sprintf "b%d" n) ty in
        let t = Expr.Term.of_var v in
        aux (t :: l) (n - 1)
    in
    aux l n
  in
  let new_map_db =
    NM.create Fmt.nop "new_map_db" (fun env f_arity ->
        let b_ty = Expr.Ty.array STV.ind_ty STV.alpha_ty in
        let f_ty = Expr.Ty.arrow (replicate f_arity STV.alpha_ty) STV.val_ty in
        let bitl = mk_tlist [] f_arity b_ty in
        let f_var = Expr.Term.Var.mk "f" f_ty in
        let f_term = Expr.Term.of_var f_var in
        let map_term = Builtin.apply_array_map f_arity f_term bitl in
        (* map⇓ *)
        (if Options.get env extended_comb then
           let map_adown_pattern, map_adown_run =
             map_adowm map_term f_term bitl
           in
           InvertedPath.add_callback env map_adown_pattern map_adown_run);
        (* map𝛿 *)
        if Options.get env default_values then
          let map_def_pattern, map_def_run = map_def map_term f_term bitl in
          InvertedPath.add_callback env map_def_pattern map_def_run)
  in
  fun env mapf_t ->
    let mapf_s = Ground.sem mapf_t in
    (* The first argument is the mapped function, the rest are the arrays. *)
    let f_arity = IArray.length mapf_s.args - 1 in
    NM.find new_map_db env f_arity

(** Returns all inhabitants of the finite type [gty] as terms. Only [Prop] is
    supported, giving [true; false].
    @raise Failure for any other type. *)
let get_gty_inhabitants (gty : Ground.Ty.t) =
  match gty with
  | { app = { builtin = Expr.Prop; _ }; _ } ->
      Expr.Term.[ _true; _false ]
  | _ ->
      let msg =
        Fmt.str "get_gty_inhabitants: unimplemented for the type %a"
          Ground.Ty.pp gty
      in
      failwith msg

(** Blast rule for an array [array_node] over a finite index type
    [ind_gty]. For each inhabitant sigma_k of [ind_gty] (k = 1, 2, ...), it
    builds [a[sigma_k] = delta_k] and asserts the conjunction of these
    equations, with [a] bound to [array_node].
    NOTE(review): the [delta_*] variables are created fresh here and are not
    bound by the substitution given to [convert] — confirm this is intended.
    @raise Failure (from [get_gty_inhabitants]) if [ind_gty] is unsupported. *)
let apply_blast_rule =
  (* Fresh value-typed variable term named "delta_<gty>_<k>". *)
  let mk_delta_fv gty k =
    let name = Fmt.str "delta_%a_%d" Ground.Ty.pp gty k in
    Expr.Term.of_var (Expr.Term.Var.mk name STV.val_ty)
  in
  let mk_conjonction_node env array_node val_gty sigmas =
    (* Conjuncts are accumulated in reverse order, numbered from 1. *)
    let rec build acc k = function
      | [] -> acc
      | sigma :: rest ->
          let conjunct =
            Expr.Term.eq
              (Expr.Term.Array.select STV.ta sigma)
              (mk_delta_fv val_gty k)
          in
          build (conjunct :: acc) (k + 1) rest
    in
    let conjuncts = build [] 1 sigmas in
    convert
      ~subst:(mk_subst [ (STV.va, array_node) ] [ (STV.val_ty_var, val_gty) ])
      env
      (Expr.Term._and conjuncts)
  in
  fun env array_node (ind_gty : Ground.Ty.t) (val_gty : Ground.Ty.t) ->
    let inhabitants = get_gty_inhabitants ind_gty in
    let n = mk_conjonction_node env array_node val_gty inhabitants in
    Egraph.register env n;
    Boolean.set_true env n