diff --git a/src/plugins/wp/Lang.mli b/src/plugins/wp/Lang.mli index 4053e2e390833cd0c422266090946e3c6d300f0e..409bbedae1e94713f81b136a0655f6ddc058ec15 100644 --- a/src/plugins/wp/Lang.mli +++ b/src/plugins/wp/Lang.mli @@ -351,6 +351,7 @@ sig val add : sigma -> term -> term -> unit val add_map : sigma -> term Tmap.t -> unit val add_fun : sigma -> (term -> term) -> unit + val add_filter : sigma -> (term -> bool) -> unit end val e_subst : sigma -> term -> term diff --git a/src/plugins/wp/Letify.ml b/src/plugins/wp/Letify.ml index 66b14fda0524576c744bfb9c1e68eef3f7dd0351..309456f6f916e91fb745642862c3c81d29ddf7cb 100644 --- a/src/plugins/wp/Letify.ml +++ b/src/plugins/wp/Letify.ml @@ -237,7 +237,7 @@ struct def : term Vmap.t ; (* Definitions *) ceq : Ceq.t ; (* Variable Classes *) cst : term Tmap.t ; (* Constants *) - mutable mem : term Tmap.t array ; (* Memoization *) + mutable cache : F.sigma option ; } let empty = { @@ -247,7 +247,7 @@ struct ceq = Ceq.empty ; def = Vmap.empty ; cst = Tmap.empty ; - mem = Array.make 5 Tmap.empty ; + cache = None ; } let equal s1 s2 = @@ -257,35 +257,26 @@ struct let find x sigma = Vmap.find x sigma.def let iter f sigma = Vmap.iter f sigma.def - let rec m_apply sigma n (e:term) = + let lookup def (e:term) = match F.repr e with - | Fvar x -> - begin - try Vmap.find x sigma.def - with Not_found -> e - end - | _ -> - let ys = F.vars e in - if not (Vars.is_empty ys || Vars.intersect ys sigma.dall) - then e (* no subst *) - else if n < 5 then - begin - (* memoization *) - try Tmap.find e sigma.mem.(n) - with Not_found -> - let r = - try - if n > 0 then raise Not_found ; - Tmap.find e sigma.cst - with Not_found -> - F.QED.f_map (m_apply sigma) n e - in - sigma.mem.(n) <- Tmap.add e r sigma.mem.(n) ; r - end - else F.QED.f_map (m_apply sigma) n e + | Fvar x -> Vmap.find x def + | _ -> raise Not_found + + let filter domain (e:term) = + Vars.intersect (F.vars e) domain - let e_apply sigma e = m_apply sigma 0 e - let p_apply sigma p = F.p_bool (e_apply sigma (F.e_prop p)) + let subst sigma = + match sigma.cache with + | Some s -> s + | None -> + let s = Lang.sigma () in + F.Subst.add_fun s (lookup sigma.def) ; + F.Subst.add_map s sigma.cst ; + F.Subst.add_filter s (filter sigma.dall) ; + sigma.cache <- Some s ; s + + let e_apply sigma e = F.e_subst (subst sigma) e + let p_apply sigma p = F.p_subst (subst sigma) p (* Returns true if [x:=a] applied to [y:=b] raises a circularity *) let occur_check sigma x a = @@ -309,7 +300,7 @@ struct def = Vmap.add x e Vmap.empty ; ceq = add_ceq x e Ceq.empty ; cst = Tmap.empty ; - mem = [| Tmap.empty |] ; + cache = None ; } let add x e sigma = @@ -325,16 +316,14 @@ struct (fun e c cst -> if vmem x e then Tmap.add (e_apply sx e) c cst else cst) cst0 sigma.cst in - let cache = Array.make (Array.length sigma.mem) Tmap.empty in - cache.(0) <- cst1 ; { - mem = cache ; cst = cst1 ; def = def ; ceq = add_ceq x e sigma.ceq ; dvar = Vars.add x sigma.dvar ; dall = Vars.add x sigma.dall ; dcod = Vars.union (F.vars e) sigma.dcod ; + cache = None ; } let domain sigma = sigma.dvar @@ -351,20 +340,18 @@ struct with Not_found -> let cst = Tmap.add e c sigma.cst in let all = Vars.union (F.vars e) sigma.dall in - let cache = Array.make (Array.length sigma.mem) Tmap.empty in - cache.(0) <- cst ; { - mem = cache ; cst = cst ; dall = all ; dvar = sigma.dvar ; dcod = sigma.dcod ; def = sigma.def ; ceq = sigma.ceq ; + cache = None ; } let mem_lit l sigma = - try Tmap.find l sigma.mem.(0) == e_true + try F.Subst.get (subst sigma) l == e_true with Not_found -> false let add_lit l sigma = @@ -464,14 +451,11 @@ struct Format.fprintf fmt "@ @[%a := %a ;@]" F.pp_term (F.e_var x) F.pp_term e ) def ; - Array.iteri - (fun i w -> - Tmap.iter - (fun e m -> - Format.fprintf fmt "@ C%d: @[%a := %a ;@]" i - F.pp_term e F.pp_term m - ) w - ) sigma.mem ; + Tmap.iter + (fun e m -> + Format.fprintf fmt "@ C @[%a := %a ;@]" + F.pp_term e F.pp_term m + ) sigma.cst ; Format.fprintf fmt "@ @]}@]" ; end