diff --git a/colibri2/theories/array/array.ml b/colibri2/theories/array/array.ml index a41d097e1529495d84d20876cb1668a9ea58cf20..4a753b9b2df8822a2d2d9176e9e13dc3f7229a34 100644 --- a/colibri2/theories/array/array.ml +++ b/colibri2/theories/array/array.ml @@ -30,20 +30,17 @@ let restrict_ext = let restrict_aup = Colibri2_stdlib.Flags.Solve.gen_register_flag ~rep:(fun flag -> (flag, "Restrict the ⇑ rule")) - "res-aup" + "Array.res-aup" let extended_comb = Colibri2_stdlib.Flags.Solve.gen_register_flag ~rep:(fun flag -> (flag, "Extended combinators")) - "ext-comb" + "Array.ext-comb" let default_values = Colibri2_stdlib.Flags.Solve.gen_register_flag ~rep:(fun flag -> (flag, "Default values")) - "def-values" - -let db = Datastructure.Push.create Ground.pp "Array.db" -(* Use one db per array type? *) + "Array.def-values" let debug = Debug.register_info_flag ~desc:"For array theory" "Array" let stats = Debug.register_stats_int "Array.rule" @@ -58,14 +55,18 @@ let ind_ty_var = Expr.Ty.Var.mk "ind_ty" let val_ty_var = Expr.Ty.Var.mk "val_ty" let ind_ty = Expr.Ty.of_var ind_ty_var let val_ty = Expr.Ty.of_var val_ty_var +let alpha_ty_var = Expr.Ty.Var.mk "alpha" +let alpha_ty = Expr.Ty.of_var alpha_ty_var let array_ty = Expr.Ty.array ind_ty val_ty let bind_tys tyvl ty = Expr.Ty.pi tyvl ty +let replicate n v = List.init n (fun _ -> v) (* Builtins *) module Builtin = struct type _ Expr.t += | Array_diff | Array_const + | Array_map | Array_default_index | Array_default_value @@ -79,6 +80,26 @@ module Builtin = struct (Dolmen_std.Path.global "array_const") (Expr.Ty.arrow [ val_ty ] array_ty) + let array_map : int -> Dolmen_std.Expr.term_cst = + let cache = Popop_stdlib.DInt.H.create 13 in + let get_ty i = + match Popop_stdlib.DInt.H.find cache i with + | ty -> ty + | exception Not_found -> + let ty = + Expr.Ty.arrow + (Expr.Ty.arrow (replicate i alpha_ty) val_ty + :: replicate i (Expr.Ty.array ind_ty alpha_ty)) + val_ty + in + Popop_stdlib.DInt.H.add cache i ty; + ty + in + fun i -> + Expr.Id.mk ~name:"array_map" ~builtin:Array_map + (Dolmen_std.Path.global "array_map") + (get_ty i) + let array_default_index : Dolmen_std.Expr.term_cst = Expr.Id.mk ~name:"array_default_index" ~builtin:Array_default_index (Dolmen_std.Path.global "array_default_index") @@ -116,6 +137,19 @@ module Builtin = struct | Dolmen_loop.Typer.T.Id { ns = Term; name = Simple "array_default_value" } -> app1 env s array_default_value + | Dolmen_loop.Typer.T.Id { ns = Term; name = Simple "array_map" } -> + `Term + (Dolmen_type.Base.term_app_list + (module Dolmen_loop.Typer.T) + env s + (function + | _ :: t as l -> + (* "t" should probably never be empty... *) + Expr.Term.apply_cst (array_map (List.length t)) [] l + | _ -> + failwith + "array_map needs to be applied to a function and n \ + arrays")) | _ -> `Not_found) end @@ -328,35 +362,139 @@ module Theory = struct let l = if !extended_comb then (const_pattern, const_run) :: l else l in List.iter (fun (p, r) -> InvertedPath.add_callback env p r) l - let new_array env s_index_ty s_value_ty f = - (* Extensionality rule ext: a, b ⇒ (a = b) â‹ (a[k] ≠b[k]) *) - if not !restrict_ext then ( - Datastructure.Push.iter db env ~f:(fun f2 -> - let subst = - mk_subst - [ (va, Ground.node f2); (vb, Ground.node f) ] - [ (ind_ty_var, s_index_ty); (val_ty_var, s_value_ty) ] - in - Debug.dprintf2 debug "Found ext with %a" Ground.Subst.pp subst; - let n = distinct_term_node ~subst env ta tb in - Egraph.register env n; - Boolean.set_true env n); - Datastructure.Push.push db env f); - (* ðð›¿: a |> a[ð] = 𛿠*) - if !extended_comb then ( - let subst = - mk_subst - [ (va, Ground.node f) ] - [ (ind_ty_var, s_index_ty); (val_ty_var, s_value_ty) ] - in - let epsilon_app = apply_def_index ta in - let delta_app = apply_def_value ta in - let epsilon_delta_eq = - Expr.Term.eq (Expr.Term.select ta epsilon_app) delta_app + let new_array = + let db = Datastructure.Push.create Ground.pp "Array.db" in + fun env s_index_ty s_value_ty f -> + (* Extensionality rule ext: a, b ⇒ (a = b) â‹ (a[k] ≠b[k]) *) + if not !restrict_ext then ( + Datastructure.Push.iter db env ~f:(fun f2 -> + let subst = + mk_subst + [ (va, Ground.node f2); (vb, Ground.node f) ] + [ (ind_ty_var, s_index_ty); (val_ty_var, s_value_ty) ] + in + Debug.dprintf2 debug "Found ext with %a" Ground.Subst.pp subst; + let n = distinct_term_node ~subst env ta tb in + Egraph.register env n; + Boolean.set_true env n); + Datastructure.Push.push db env f); + (* ðð›¿: a |> a[ð] = 𛿠*) + if !extended_comb then ( + Debug.dprintf0 debug "Application of the epsilon_delta rule"; + let subst = + mk_subst + [ (va, Ground.node f) ] + [ (ind_ty_var, s_index_ty); (val_ty_var, s_value_ty) ] + in + let epsilon_app = apply_def_index ta in + let delta_app = apply_def_value ta in + let epsilon_delta_eq = + Expr.Term.eq (Expr.Term.select ta epsilon_app) delta_app + in + let n = convert ~subst env epsilon_delta_eq in + Egraph.register env n; + Boolean.set_true env n) + + let mk_vt_list pref n ty = + let rec aux tl n = + if n <= 0 then List.rev tl + else + let v = Expr.Term.Var.mk (Format.sprintf "%s%n" pref n) ty in + let t = Expr.Term.of_var v in + aux (t :: tl) (n - 1) + in + aux [] n + + (* map⇓: a = map(f, b1, ..., bn), a[j] |> a[j] = f(b1[j], ..., bn[j]) *) + let map_adowm map_term f_term bitl = + let map_read_pattern = + Pattern.of_term_exn ~subst:Ground.Subst.empty + (Expr.Term.select map_term tj) + in + let map_read_run env subst = + Debug.dprintf2 debug "Found array_map(f,b1, ..., bn)[j] with %a" + Ground.Subst.pp subst; + let term = + Expr.Term.eq + (Expr.Term.select map_term tj) + (Expr.Term.apply f_term [] + (List.map (fun bi -> Expr.Term.select bi tj) bitl)) in - let n = convert ~subst env epsilon_delta_eq in + let n = convert ~subst env term in Egraph.register env n; - Boolean.set_true env n) + Boolean.set_true env n + in + (map_read_pattern, map_read_run) + + (* map⇑: a = map(f, b1, ..., bn), bk[j] |> a[j] = f(b1[j], ..., bn[j]) *) + (* mapð›¿: a = map(f, b1, ..., bn) |> ð›¿(a) = f(ð›¿(b1), ..., ð›¿((bn)) *) + let map_aup map_term f_term bitl = + let map_pattern = Pattern.of_term_exn ~subst:Ground.Subst.empty map_term in + let map_run env subst = + Debug.dprintf2 debug "Found array_map(f,b1, ..., bn) with %a" + Ground.Subst.pp subst; + if !extended_comb then ( + Debug.dprintf0 debug "Application of the map_aup rule"; + let bkjl = List.map (fun bi -> Expr.Term.select bi tj) bitl in + let bkjl_patterns = List.map (Pattern.of_term_exn ~subst) bkjl in + let bkj_run = + let seen = ref false in + (* TODO: find one k for each j that is encountered, if there are + different k's for one j, only one of them needs to be matched. *) + fun env subst -> + if not !seen then ( + seen := true; + let term = + Expr.Term.eq + (Expr.Term.select map_term tj) + (Expr.Term.apply f_term [] + (List.map (fun bi -> Expr.Term.select bi tj) bitl)) + in + let n = convert ~subst env term in + Egraph.register env n; + Boolean.set_true env n) + in + List.iter + (fun pattern -> InvertedPath.add_callback env pattern bkj_run) + bkjl_patterns); + if !default_values then ( + Debug.dprintf0 debug "Application of the map_delta rule"; + let d_a = Expr.Term.apply_cst Builtin.array_default_value [] [ ta ] in + let d_bil = + List.map + (fun bi -> + Expr.Term.apply_cst Builtin.array_default_value [] [ bi ]) + bitl + in + let term = Expr.Term.eq d_a (Expr.Term.apply f_term [] d_bil) in + let n = convert ~subst env term in + Egraph.register env n; + Boolean.set_true env n) + in + (map_pattern, map_run) + + let new_map = + (* Does the type matter? should it be cached? *) + let cache = ref Popop_stdlib.DInt.S.empty in + let seen i = Popop_stdlib.DInt.S.mem i !cache in + let add i = cache := Popop_stdlib.DInt.S.add i !cache in + fun env mapf_t -> + let mapf_s = Ground.sem mapf_t in + let f_arity = IArray.length mapf_s.args - 1 in + if not (seen f_arity) then ( + add f_arity; + let b_ty = Expr.Ty.array ind_ty alpha_ty in + let f_ty = Expr.Ty.arrow (replicate f_arity alpha_ty) val_ty in + let bitl = mk_vt_list "b" f_arity b_ty in + let f_var = Expr.Term.Var.mk "f" f_ty in + let f_term = Expr.Term.of_var f_var in + let map_term = + Expr.Term.apply_cst (Builtin.array_map f_arity) [] (f_term :: bitl) + in + let map_pattern, map_run = map_adowm map_term f_term bitl in + let map_read_pattern, map_read_run = map_adowm map_term f_term bitl in + if !extended_comb then InvertedPath.add_callback env map_pattern map_run; + InvertedPath.add_callback env map_read_pattern map_read_run) end let converter env (f : Ground.t) = @@ -486,6 +624,9 @@ let converter env (f : Ground.t) = in Egraph.register env eq_node; Boolean.set_true env eq_node) + | { app = { builtin = Builtin.Array_map; _ }; _ } + when !extended_comb || !default_values -> + Theory.new_map env f | _ -> () let init env : unit =