Skip to content
Snippets Groups Projects
Commit 32966d53 authored by François Bobot's avatar François Bobot
Browse files

[Quant] reduce useless instantiation by storing previous substitution

parent 7ef73509
No related branches found
No related tags found
1 merge request!25[Quant] Substitute by prefering existing terms to new equal one
......@@ -251,3 +251,51 @@ module Ref = struct
let get t d = Context.Ref.get (Egraph.get_unsaved_env d t)
let set t d v = Context.Ref.set (Egraph.get_unsaved_env d t) v
end
module type Trie = sig
type 'a t
type key
val create: 'a Format.printer -> string -> 'a t
module List : sig
val set : 'a t -> _ Egraph.t ->key list -> 'a -> unit
val find_def : default:(Context.creator -> 'a) -> 'a t -> _ Egraph.t ->key list -> 'a
end
module Set : sig
type set
val set : 'a t -> _ Egraph.t -> set -> 'a -> unit
val find_def : default:(Context.creator -> 'a) -> 'a t -> _ Egraph.t -> set -> 'a
end
end
module Trie (S:Colibri2_popop_lib.Popop_stdlib.Datatype) : Trie
with type key := S.t and type Set.set := S.S.t = struct
module Trie = Context.Trie(S)
type 'a t = 'a Trie.t Env.Unsaved.t
let create : type a. a Colibri2_popop_lib.Pp.pp -> _ -> a t = fun pp name ->
let module M = struct
type t = a Trie.t
let name = name end
in
let key = Env.Unsaved.create (module M) in
let init d = Trie.create d in
let pp = Trie.pp pp in
Env.Unsaved.register ~init ~pp key;
key
module List = struct
let set t d l v = Trie.List.set (Egraph.get_unsaved_env d t) l v
let find_def ~default t d l = Trie.List.find_def ~default (Egraph.get_unsaved_env d t) l
end
module Set = struct
let set t d l v = Trie.Set.set (Egraph.get_unsaved_env d t) l v
let find_def ~default t d l = Trie.Set.find_def ~default (Egraph.get_unsaved_env d t) l
end
end
......@@ -87,3 +87,25 @@ module Ref: sig
val get: 'a t -> _ Egraph.t -> 'a
val set: 'a t -> _ Egraph.t -> 'a -> unit
end
module type Trie = sig
type 'a t
type key
val create: 'a Format.printer -> string -> 'a t
module List : sig
val set : 'a t -> _ Egraph.t ->key list -> 'a -> unit
val find_def : default:(Context.creator -> 'a) -> 'a t -> _ Egraph.t ->key list -> 'a
end
module Set : sig
type set
val set : 'a t -> _ Egraph.t -> set -> 'a -> unit
val find_def : default:(Context.creator -> 'a) -> 'a t -> _ Egraph.t -> set -> 'a
end
end
module Trie (S:Colibri2_popop_lib.Popop_stdlib.Datatype) : Trie
with type key := S.t and type Set.set := S.S.t
......@@ -460,17 +460,28 @@ module type HashtblWithDefault = sig
type key
val create : creator -> (creator -> 'a) -> 'a t
val pp : 'a Fmt.t -> 'a t Fmt.t
val set : 'a t -> key -> 'a -> unit
val find : 'a t -> key -> 'a
val change : ('a -> 'a) -> 'a t -> key -> unit
end
module HashtblWithDefault (S : Colibri2_popop_lib.Popop_stdlib.Datatype) :
HashtblWithDefault with type key := S.t = struct
module HashtblWithDefault (S : Colibri2_popop_lib.Popop_stdlib.Datatype) : sig
include HashtblWithDefault
val find_aux : 'a t -> key -> 'a Ref.t
end
with type key := S.t = struct
type 'a t = { h : 'a Ref.t S.H.t; def : creator -> 'a; creator : creator }
let create creator def = { h = S.H.create 5; def; creator }
let pp pp =
Fmt.(
iter_bindings ~sep:comma
(fun f t -> S.H.iter (fun k v -> f k (Ref.get v)) t.h)
(pair S.pp pp))
let find_aux t k =
match S.H.find_opt t.h k with
| Some r -> r
......@@ -566,3 +577,121 @@ module Clicket = struct
let pp ?sep pp fmt v = Vector.pp ?pp_sep:sep pp fmt v.v
end
type 'k fold = { fold : 'a. ('a -> 'k -> 'a) -> 'a -> 'a }
module type Trie = sig
type 'a t
type key
val create : creator -> 'a t
val pp : 'a Fmt.t -> 'a t Fmt.t
module List : sig
val set : 'a t -> key list -> 'a -> unit
val find_def : default:(creator -> 'a) -> 'a t -> key list -> 'a
end
module Set : sig
type set
val set : 'a t -> set -> 'a -> unit
val find_def : default:(creator -> 'a) -> 'a t -> set -> 'a
end
module Fold : sig
val set : 'a t -> key fold -> 'a -> unit
val find_def : default:(creator -> 'a) -> 'a t -> key fold -> 'a
val memo : default:(creator -> 'a) -> 'a t -> key fold -> 'a
end
end
module Trie (S : Colibri2_popop_lib.Popop_stdlib.Datatype) :
Trie with type key := S.t and type Set.set := S.S.t = struct
module H = HashtblWithDefault (S)
type 'a node =
| Empty
| Value of 'a
| Node of 'a node H.t
| NodeValue of 'a * 'a node H.t
type 'a t = 'a node Ref.t
let create c = Ref.create c Empty
let pp pp fmt t =
let rec aux fmt c =
match c with
| Empty -> ()
| Value v -> pp fmt v
| Node h -> H.pp aux fmt h
| NodeValue (v, h) ->
pp fmt v;
H.pp aux fmt h
in
aux fmt (Ref.get t)
module Fold = struct
let find_aux t { fold } =
let aux acc k =
match Ref.get acc with
| Empty ->
let h = H.create (Ref.creator acc) (fun _ -> Empty) in
Ref.set acc (Node h);
H.find_aux h k
| Value x ->
let h = H.create (Ref.creator acc) (fun _ -> Empty) in
Ref.set acc (NodeValue (x, h));
H.find_aux h k
| Node h | NodeValue (_, h) -> H.find_aux h k
in
let acc = fold aux t in
acc
[@@inline]
let set t fold v =
let r = find_aux t fold in
match Ref.get r with
| Empty -> Ref.set r (Value v)
| Value v' -> if Base.phys_equal v v' then Ref.set r (Value v)
| Node h | NodeValue (_, h) -> Ref.set r (NodeValue (v, h))
[@@inline]
let find_def ~default t fold =
let r = find_aux t fold in
match Ref.get r with
| Empty | Node _ -> default (Ref.creator r)
| Value v | NodeValue (v, _) -> v
[@@inline]
let memo ~default t fold =
let r = find_aux t fold in
match Ref.get r with
| Empty ->
let v = default (Ref.creator r) in
Ref.set r (Value v);
v
| Node h ->
let v = default (Ref.creator r) in
Ref.set r (NodeValue (v, h));
v
| Value v | NodeValue (v, _) -> v
[@@inline]
end
module List = struct
let set t l v =
Fold.set t { fold = (fun f acc -> List.fold_left f acc l) } v
let find_def ~default t l =
Fold.find_def ~default t { fold = (fun f acc -> List.fold_left f acc l) }
end
module Set = struct
let set t l v = Fold.set t { fold = (fun f acc -> S.S.fold_left f acc l) } v
let find_def ~default t l =
Fold.find_def ~default t { fold = (fun f acc -> S.S.fold_left f acc l) }
end
end
......@@ -227,6 +227,7 @@ module type HashtblWithDefault = sig
type key
val create : creator -> (creator -> 'a) -> 'a t
val pp : 'a Fmt.t -> 'a t Fmt.t
val set : 'a t -> key -> 'a -> unit
val find : 'a t -> key -> 'a
val change : ('a -> 'a) -> 'a t -> key -> unit
......@@ -266,3 +267,36 @@ module type Clicket = sig
end
module Clicket : Clicket
type 'k fold = { fold : 'a. ('a -> 'k -> 'a) -> 'a -> 'a }
module type Trie = sig
type 'a t
type key
val create : creator -> 'a t
val pp : 'a Fmt.t -> 'a t Fmt.t
module List : sig
val set : 'a t -> key list -> 'a -> unit
val find_def : default:(creator -> 'a) -> 'a t -> key list -> 'a
end
module Set : sig
type set
val set : 'a t -> set -> 'a -> unit
val find_def : default:(creator -> 'a) -> 'a t -> set -> 'a
end
module Fold : sig
val set : 'a t -> key fold -> 'a -> unit
val find_def : default:(creator -> 'a) -> 'a t -> key fold -> 'a
val memo : default:(creator -> 'a) -> 'a t -> key fold -> 'a
(** find and add default if not present *)
end
end
module Trie (S : Colibri2_popop_lib.Popop_stdlib.Datatype) :
Trie with type key := S.t and type Set.set := S.S.t
......@@ -32,7 +32,7 @@ let add_trigger d t =
(fun acc pat -> Pattern.match_any_term d acc pat)
Pattern.init (t.pat :: t.pats)
in
Ground.Subst.S.iter (Trigger.instantiate d t) substs
Trigger.instantiate_many d t substs
let find_new_event d n (info : Info.t) n' (info' : Info.t) =
Debug.dprintf8 debug "Find_new_event %a %a %a %a" Node.pp n Info.pp info
......@@ -138,9 +138,7 @@ let process_inverted_path d n acc =
(fun acc ip -> InvertedPath.exec d acc Pattern.init n ip)
Trigger.M.empty acc
in
Trigger.M.iter
(fun tri substs -> Ground.Subst.S.iter (Trigger.instantiate d tri) substs)
acc
Trigger.M.iter (fun tri substs -> Trigger.instantiate_many d tri substs) acc
module Delayed_find_new_event = struct
let key =
......@@ -219,8 +217,8 @@ let attach d th =
Egraph.merge d (skolemize d e) n
| { binder = Exists; _ }, false | { binder = Forall; _ }, true ->
let triggers =
match Trigger.get_user_triggers th with
| [] -> Trigger.compute_top_triggers th
match Trigger.get_user_triggers d th with
| [] -> Trigger.compute_top_triggers d th
| triggers -> triggers
in
Debug.dprintf4 debug "[Quant] For %a adds %a"
......
......@@ -21,17 +21,96 @@
open Colibri2_popop_lib
open Common
module T = struct
(** Only for substitutions that have the same domain *)
module SubstTrie = struct
module TTerm = Context.Trie (Node)
module TTy = Context.Trie (Ground.Ty)
type 'a t = 'a TTy.t TTerm.t
let create c = TTerm.create c
let fold_of_term_map s : _ Context.fold =
{
fold =
(fun f acc -> Expr.Term.Var.M.fold_left (fun acc _ v -> f acc v) acc s);
}
let fold_of_ty_map s : _ Context.fold =
{
fold =
(fun f acc -> Expr.Ty.Var.M.fold_left (fun acc _ v -> f acc v) acc s);
}
let find_and_add (t : bool t) (s : Ground.Subst.t) =
let tv =
TTerm.Fold.memo
~default:(fun c -> TTy.create c)
t (fold_of_term_map s.term)
in
let not_set = ref false in
ignore
(TTy.Fold.memo
~default:(fun _ ->
not_set := true;
true)
tv (fold_of_ty_map s.ty));
!not_set
let find_def ~default (t : _ t) (s : Ground.Subst.t) =
let tv =
TTerm.Fold.memo
~default:(fun c -> TTy.create c)
t (fold_of_term_map s.term)
in
TTy.Fold.find_def ~default:(fun _ -> default) tv (fold_of_ty_map s.ty)
let set (t : _ t) (s : Ground.Subst.t) v =
let tv =
TTerm.Fold.memo
~default:(fun c -> TTy.create c)
t (fold_of_term_map s.term)
in
TTy.Fold.set tv (fold_of_ty_map s.ty) v
end
type inst_step = NotSeen | Delayed | Instantiated
module T : sig
type t = private {
id : int;
pat : Pattern.t;
pats : Pattern.t list;
checks : Pattern.t list;
form : Ground.ClosedQuantifier.t;
eager : bool;
substs : inst_step SubstTrie.t;
}
include Popop_stdlib.TaggedType with type t := t
val mk :
pat:Pattern.t ->
pats:Pattern.t list ->
checks:Pattern.t list ->
form:Ground.ClosedQuantifier.t ->
eager:bool ->
_ Egraph.t ->
t
end = struct
open! Base
type t = {
id : int;
pat : Pattern.t;
pats : Pattern.t list;
checks : Pattern.t list;
form : Ground.ClosedQuantifier.t;
eager : bool;
substs : inst_step SubstTrie.t;
}
[@@deriving eq, ord, hash]
let tag t = t.id
let pp fmt t =
Fmt.pf fmt "[%a, %a ( %a ) -> %a]" Pattern.pp t.pat
......@@ -39,17 +118,31 @@ module T = struct
t.pats
Fmt.(list ~sep:comma Pattern.pp)
t.checks Ground.ClosedQuantifier.pp t.form
let id_get, id_incr = Util.get_counter ()
let mk ~pat ~pats ~checks ~form ~eager d =
id_incr ();
{
id = id_get ();
pat;
pats;
checks;
form;
eager;
substs = SubstTrie.create (Egraph.context d);
}
end
include T
include Popop_stdlib.MkDatatype (T)
include (Popop_stdlib.MakeMSH (T) : Popop_stdlib.Datatype with type t := t)
let register_builtin_skipped_for_trigger, builtin_skipped_for_trigger =
let q = Base.Queue.create () in
( Base.Queue.enqueue q,
fun builtin -> Base.Queue.exists q ~f:(fun p -> p builtin) )
let compute_top_triggers (cq : Ground.ClosedQuantifier.t) =
let compute_top_triggers d (cq : Ground.ClosedQuantifier.t) =
let cq' = Ground.ClosedQuantifier.sem cq in
let tyvs = cq'.ty_vars in
let tvs = cq'.term_vars in
......@@ -157,13 +250,8 @@ let compute_top_triggers (cq : Ground.ClosedQuantifier.t) =
| [] -> acc
| a :: l ->
aux
({
pat = a;
pats = other @ l;
checks = pats_partial;
form = cq;
eager = true;
}
(mk d ~pat:a ~pats:(other @ l) ~checks:pats_partial ~form:cq
~eager:true
:: acc)
(a :: other) l
in
......@@ -184,13 +272,9 @@ let compute_top_triggers (cq : Ground.ClosedQuantifier.t) =
| [] -> acc
| a :: l ->
aux
({
pat = a;
pats = [];
checks = other @ l @ pats_partial;
form = cq;
eager = true;
}
(mk d ~pat:a ~pats:[]
~checks:(other @ l @ pats_partial)
~form:cq ~eager:true
:: acc)
(a :: other) l
in
......@@ -207,7 +291,7 @@ let compute_top_triggers (cq : Ground.ClosedQuantifier.t) =
pats_full_with_others;
pats_full_with_others
let compute_all_triggers (cq : Ground.ClosedQuantifier.t) =
let compute_all_triggers d (cq : Ground.ClosedQuantifier.t) =
let cq' = Ground.ClosedQuantifier.sem cq in
let tyvs = cq'.ty_vars in
let tvs = cq'.term_vars in
......@@ -252,19 +336,15 @@ let compute_all_triggers (cq : Ground.ClosedQuantifier.t) =
(fun (c, (sty, st)) ->
if Expr.Ty.Var.S.subset tyvs sty && Expr.Term.Var.S.subset tvs st then
Some
{
pat = Pattern.of_term ~subst:cq'.subst c;
pats = [];
form = cq;
eager = true;
checks = [];
}
(mk d
~pat:(Pattern.of_term ~subst:cq'.subst c)
~pats:[] ~form:cq ~eager:true ~checks:[])
else None)
pats
in
pats
let get_user_triggers (cq : Ground.ClosedQuantifier.t) =
let get_user_triggers d (cq : Ground.ClosedQuantifier.t) =
let cq' = Ground.ClosedQuantifier.sem cq in
let pats = Expr.Term.get_tag_list cq'.body Expr.Tags.triggers in
let tyvs = Expr.Ty.Var.S.of_list cq'.ty_vars in
......@@ -283,13 +363,7 @@ let get_user_triggers (cq : Ground.ClosedQuantifier.t) =
| [] -> acc
| a :: l ->
aux
({
pat = a;
pats = other @ l;
checks = [];
form = cq;
eager = true;
}
(mk d ~pat:a ~pats:(other @ l) ~checks:[] ~form:cq ~eager:true
:: acc)
(a :: other) l
in
......@@ -311,13 +385,15 @@ let add_trigger d t =
| Node _ -> ()
let instantiate_aux d tri subst =
SubstTrie.set tri.substs subst Instantiated;
let form = Ground.ClosedQuantifier.sem tri.form in
Debug.incr nb_instantiation;
let subst = Ground.Subst.distinct_union subst form.subst in
let n = Ground.convert ~subst d form.body in
if
Colibri2_stdlib.Debug.test_flag Colibri2_stdlib.Debug.stats
&& not (Egraph.is_registered d n)
&& not (Egraph.is_equal d n (Ground.ClosedQuantifier.node tri.form))
(* not (Egraph.is_registered d n) *)
then Debug.incr nb_new_instantiation;
Egraph.register d n;
Egraph.merge d n (Ground.ClosedQuantifier.node tri.form)
......@@ -359,34 +435,55 @@ end
let () = Events.register (module Delayed_instantiation)
let instantiate d tri subst =
let subst =
{
subst with
Ground.Subst.term =
Expr.Term.Var.M.map (Egraph.find_def d) subst.Ground.Subst.term;
}
let instantiate' d tri subst =
let show_debug debug =
Debug.dprintf9 debug
"[Quant] %a instantiation found %a, pat %a, checks:%a, eager:%b"
Ground.Subst.pp subst Ground.ClosedQuantifier.pp tri.form
Fmt.(list ~sep:comma Pattern.pp)
tri.pats
Fmt.(list ~sep:comma Pattern.pp)
tri.checks tri.eager
in
Debug.dprintf9 debug
"[Quant] %a instantiation found %a, pat %a, checks:%a, eager:%b"
Ground.Subst.pp subst Ground.ClosedQuantifier.pp tri.form
Fmt.(list ~sep:comma Pattern.pp)
tri.pats
Fmt.(list ~sep:comma Pattern.pp)
tri.checks tri.eager;
if
tri.eager
&& List.for_all
(fun pat ->
not (Node.S.is_empty (Pattern.check_term_exists d subst pat)))
tri.checks
then (
Debug.incr nb_eager_instantiation;
instantiate_aux d tri subst)
else (
Debug.dprintf0 debug "[Quant] Delayed";
Debug.incr nb_delayed_instantiation;
Events.new_pending_daemon d Delayed_instantiation.key (tri, subst))
match SubstTrie.find_def ~default:NotSeen tri.substs subst with
| Instantiated -> show_debug debug_full
| level_inst -> (
show_debug debug;
if
tri.eager
&& List.for_all
(fun pat ->
not (Node.S.is_empty (Pattern.check_term_exists d subst pat)))
tri.checks
then (
Debug.incr nb_eager_instantiation;
instantiate_aux d tri subst)
else
match level_inst with
| NotSeen ->
SubstTrie.set tri.substs subst Delayed;
Debug.dprintf0 debug "[Quant] Delayed";
Debug.incr nb_delayed_instantiation;
Events.new_pending_daemon d Delayed_instantiation.key (tri, subst)
| Delayed | Instantiated ->
Debug.dprintf0 debug "[Quant] Already delayed")
let find_repr_subst d subst =
{
subst with
Ground.Subst.term =
Expr.Term.Var.M.map (Egraph.find_def d) subst.Ground.Subst.term;
}
let instantiate d t subst = instantiate' d t (find_repr_subst d subst)
let instantiate_many d t substs =
let substs =
Ground.Subst.S.fold_left
(fun acc s -> Ground.Subst.S.add (find_repr_subst d s) acc)
Ground.Subst.S.empty substs
in
Ground.Subst.S.iter (instantiate' d t) substs
let match_ d tri n =
Debug.dprintf4 debug "[Quant] match %a %a" pp tri Node.pp n;
......
......@@ -20,7 +20,23 @@
(** Trigger *)
type t = {
module SubstTrie : sig
type 'a t
val find_and_add : bool t -> Ground.Subst.t -> bool
(** Return if the element was present before adding it *)
val find_def : default:'a -> 'a t -> Ground.Subst.t -> 'a
(** Return if the element was present before adding it *)
val set : 'a t -> Ground.Subst.t -> 'a -> unit
(** Return if the element was present before adding it *)
end
type inst_step = NotSeen | Delayed | Instantiated
type t = private {
id : int;
pat : Pattern.t; (** The pattern on which to wait for a substitution *)
pats : Pattern.t list;
(** The other ones used to obtain a complete
......@@ -29,18 +45,19 @@ type t = {
form : Ground.ClosedQuantifier.t; (** the body of the formula *)
eager : bool;
(** If it should be eagerly applied, otherwise wait for LastEffort *)
substs : inst_step SubstTrie.t;
}
include Colibri2_popop_lib.Popop_stdlib.Datatype with type t := t
val compute_top_triggers : Ground.ClosedQuantifier.t -> t list
val compute_top_triggers : _ Egraph.t -> Ground.ClosedQuantifier.t -> t list
(** Compute triggers, that should only add logical connective or equalities are
new terms *)
val compute_all_triggers : Ground.ClosedQuantifier.t -> t list
val compute_all_triggers : _ Egraph.t -> Ground.ClosedQuantifier.t -> t list
(** Compute all the triggers whose patterns contain all the variables of the formula *)
val get_user_triggers : Ground.ClosedQuantifier.t -> t list
val get_user_triggers : _ Egraph.t -> Ground.ClosedQuantifier.t -> t list
(** return the triggers given by the user *)
val env_vars : t Datastructure.Push.t
......@@ -62,6 +79,8 @@ val instantiate : Egraph.wt -> t -> Ground.Subst.t -> unit
* at last effort otherwise
*)
val instantiate_many : Egraph.wt -> t -> Ground.Subst.S.t -> unit
val match_ : Egraph.wt -> t -> Node.t -> unit
(** [match_ d tri n] match the pattern of [tri] with [n] and instantiate [tri]
with the resulting substitutions *)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment