diff --git a/src/arith.ml b/src/arith.ml index 1632fb6626199acef7935bea29ece2466018d575..5d06469a449edd5cce65d7bf958f70cf40aa0670 100644 --- a/src/arith.ml +++ b/src/arith.ml @@ -38,16 +38,35 @@ let mult_cst c p1 = if Q.equal Q.zero c then cst Q.zero else {cst = Q.mul c p1.cst; poly = poly_mult c p1.poly} -let poly_apply2 f m1 m2 = - Cl.M.union (fun _ c1 c2 -> - let c3 = f c1 c2 in - if Q.equal Q.zero c3 then None else Some c3) m1 m2 -let apply2 f p1 p2 = - {cst = f p1.cst p2.cst; poly = poly_apply2 f p1.poly p2.poly} - -let add = apply2 Q.add -let sub = apply2 Q.sub +let none_zero c = if Q.equal Q.zero c then None else Some c + +(** Warning Cl.M.union can be used only for defining an operation [op] + that verifies [op 0 p = p] and [op p 0 = p] *) +let add p1 p2 = + let poly_add m1 m2 = + Cl.M.union (fun _ c1 c2 -> none_zero (Q.add c1 c2)) m1 m2 + in + {cst = Q.add p1.cst p2.cst; poly = poly_add p1.poly p2.poly} + +let sub p1 p2 = + let poly_sub m1 m2 = + Cl.M.union_merge (fun _ c1 c2 -> + match c1 with + | None -> Some (Q.neg c2) + | Some c1 -> none_zero (Q.sub c1 c2)) + m1 m2 in + {cst = Q.sub p1.cst p2.cst; poly = poly_sub p1.poly p2.poly} + +let x_p_cy p1 c p2 = + let f a b = Q.add a (Q.mul c b) in + let poly m1 m2 = + Cl.M.union_merge (fun _ c1 c2 -> + match c1 with + | None -> Some (Q.mul c c2) + | Some c1 -> none_zero (f c1 c2)) + m1 m2 in + {cst = f p1.cst p2.cst; poly = poly p1.poly p2.poly} module Th = struct type repr = t @@ -63,28 +82,28 @@ module Th = struct n.poly (Hashtbl.hash n.cst * 27) let print print_cl fmt v = - Format.fprintf fmt "(@["; + Format.fprintf fmt "@["; let print_mono k v fmt = if Q.equal Q.one v then Format.fprintf fmt "@[%a+@]@," print_cl k else Format.fprintf fmt "@[%a%a+@]@," Q.pp_print v print_cl k; fmt in ignore (Cl.M.fold print_mono v.poly fmt); - Format.fprintf fmt "%a@])" Q.pp_print v.cst + Format.fprintf fmt "%a@]" Q.pp_print v.cst let leaves t = Cl.M.map (fun _ -> ()) t.poly let embed d cl = + Debug.dprintf debug "[Arith] @[embed cl=%a@]@." + (UnionFind.print_cl (E.Delayed.hack_env d)) cl; match X.extract (E.Delayed.repr d cl) with | None -> monome Q.one cl | Some r -> r - (* let subst d s t = *) - (* let t' = {cst = t.cst; poly = Cl.M.set_diff t.poly s} in *) - (* let t' = Cl.M.fold2_inter (fun _ c p poly -> *) - (* apply2 (fun c1 c2 -> Q.add c1 (Q.mul c c2)) poly (embed d p) *) - (* ) t.poly s t' in *) - (* E.Delayed.normalize d (X.make t') *) + let add d (v,cl) = + Cl.S.iter (fun leave -> + E.Delayed.depend_repr d cl leave) (leaves v) + let normalize d x = (** cst = 0 and one empty monome *) if Q.equal Q.zero x.cst && Cl.M.is_num_elt 1 x.poly then @@ -95,9 +114,12 @@ module Th = struct let subst d s t = let t' = {cst = t.cst; poly = Cl.M.set_diff t.poly s} in - let t' = Cl.M.fold2_inter (fun _ c p poly -> - apply2 (fun c1 c2 -> Q.add c1 (Q.mul c c2)) poly (embed d p) - ) t.poly s t' in + let t' = Cl.M.fold2_inter + (fun _ c p poly -> x_p_cy poly c (embed d p)) + t.poly s t' in + Debug.dprintf debug "[Arith] @[subst t=%a@ t'=%a@]@." + (X.print (E.Delayed.hack_env d)) t + (X.print (E.Delayed.hack_env d)) t'; normalize d t' @@ -112,11 +134,14 @@ module Th = struct else raise E.Contradiction else let x,q = Cl.M.choose p.poly in + Debug.dprintf debug "[Arith] @[solve p=%a@]@." + (X.print (E.Delayed.hack_env d)) p; let p = {p with poly = Cl.M.remove x p.poly} in assert ( not (Q.equal Q.zero q) ); - let sr = normalize d (mult_cst (Q.inv q) p) in + let sr = normalize d (mult_cst (Q.inv (Q.neg q)) p) in let s = Cl.M.singleton x sr in - Debug.dprintf debug "[Arith] @[solve x=%a@, sr=%a@]@." + Debug.dprintf debug "[Arith] @[solve q=%a x=%a@, sr=%a@]@." + Q.pp_print q (UnionFind.print_cl (E.Delayed.hack_env d)) x (UnionFind.print_cl (E.Delayed.hack_env d)) sr; let sp1 = subst d s p1 in @@ -126,6 +151,10 @@ module Th = struct (UnionFind.print_cl (E.Delayed.hack_env d)) sp2; Cl.equal sp1 sp2 ); s, sp1 + + let subst d s t = + assert (not (Cl.M.is_empty (Cl.M.set_inter t.poly s))); + subst d s t end end @@ -133,24 +162,24 @@ end module M = E.MTerm.Make(Th) -let basecl t r = E.MTerm.UnionFind.basecl t (M.make r) +let embed t cl = + match M.extract (E.repr t cl) with + | None -> monome Q.one cl + | Some r -> r + +let basecl t x = (** cst = 0 and one empty monome *) + if Q.equal Q.zero x.cst && Cl.M.is_num_elt 1 x.poly then + let cl,k = Cl.M.choose x.poly in + if Q.equal Q.one k then cl, t + else E.MTerm.UnionFind.basecl t (M.make x) + else E.MTerm.UnionFind.basecl t (M.make x) type env = UnionFind.t let cst t c = basecl t (cst c) let add t cl1 cl2 = - let r = match M.extract (E.repr t cl1), M.extract (E.repr t cl2) with - | None , None -> - let r1 = monome Q.one cl1 in - let r2 = monome Q.one cl2 in - add r1 r2 - | Some r1, None -> - let r2 = monome Q.one cl2 in - add r1 r2 - | None, Some r2 -> - let r1 = monome Q.one cl1 in - add r1 r2 - | Some r1, Some r2 -> - add r1 r2 in - basecl t r + basecl t (add (embed t cl1) (embed t cl2)) + +let mult_cst t cst cl = + basecl t (mult_cst cst (embed t cl)) diff --git a/src/arith.mli b/src/arith.mli index 6205e58dc2d3be513a46d4de047ed39ff23d0bac..8a62c1f34c8d2930dd2b0a79dc1c018e09db9afa 100644 --- a/src/arith.mli +++ b/src/arith.mli @@ -30,4 +30,5 @@ type env = UnionFind.t val cst : env -> Q.t -> cl * env val add : env -> cl -> cl -> cl * env +val mult_cst : env -> Q.t -> cl -> cl * env diff --git a/src/egraph_simple.ml b/src/egraph_simple.ml index 29b8333f613f9e04ea981ffcf39eabbe9075b592..121990dcbf841cd3a6783a75168b8593355c0366 100644 --- a/src/egraph_simple.ml +++ b/src/egraph_simple.ml @@ -40,6 +40,7 @@ include Del let dont_clean_use = false +exception Unimplemented of string open Stdlib open Shuffle @@ -98,19 +99,39 @@ let rec normalize env cl = else let arg,env = normalize env arg in (false,env), arg in - let (already,env),leaves = + let (_already,env),leaves = Cl.M.mapi_fold is_normalized leaves (true,env) in - if already then - (** the leaves are already normalized *) - let env = Cl.M.fold (fun _ norm env -> - UnionFind.add_use env cl v norm) leaves env in - let env = UnionFind.set_repr env cl v in - cl,env - else - (** subst must return normalized term *) - let t = ref env in - let cl = MTerm.subst t leaves v in - cl,!t + let env = UnionFind.set_repr env cl v in + (** subst must return a normalized term, except if its cl *) + let t = ref env in + let cl' = MTerm.subst t leaves v in + let env = !t in + if UnionFind.fold_use (fun _ _ _ -> true) env cl false then + raise (Unimplemented "normalization with child that use itself"); + let env = + if Cl.equal cl cl' + then + (* Cl.M.fold (fun _ norm env -> *) + (* UnionFind.add_use env cl v norm) leaves env *) + let env = Cl.M.fold (fun _ norm env -> + UnionFind.add_use env cl v norm) leaves env in + Debug.dprintf debug "[EGraph] AddUse @[%a@]@." Cl.print cl; + let t = ref env in + MTerm.add t (v,cl); + !t + else env in + cl',env + (* if already then *) + (* (\** the leaves are already normalized *\) *) + (* let env = Cl.M.fold (fun _ norm env -> *) + (* UnionFind.add_use env cl v norm) leaves env in *) + (* let env = UnionFind.set_repr env cl v in *) + (* cl,env *) + (* else *) + (* (\** subst must return normalized term *\) *) + (* let t = ref env in *) + (* let cl = MTerm.subst t leaves v in *) + (* cl,!t *) in let env = UnionFind.union_force env cl cl' in Debug.dprintf debug @@ -153,9 +174,11 @@ let unuse env cl v = (* no leaves so can't have been used in the first place *) else let remove arg env = - Debug.dprintf debug " Remove in @[%a@]@\n" + Debug.dprintf debug "[EGraph] @[Remove in @[%a@]@]@." (print_cl env) arg; - UnionFind.remove_use env cl v arg + let env = UnionFind.remove_use_repr env cl v arg in + let env = UnionFind.remove_use env cl v arg in + env in let env = Cl.S.fold remove leaves env in env @@ -183,10 +206,10 @@ exception ExitEnv of UnionFind.t let rec equal_aux queue env cl1 cl2 = let cl1 = UnionFind.find env cl1 in let cl2 = UnionFind.find env cl2 in - if Cl.equal cl1 cl2 then env else begin Debug.dprintf debug "[EGraph] Equal_Aux @[@[%a@] =>@ @[%a@]@]@." (print_cl env) cl1 (print_cl env) cl2; + if Cl.equal cl1 cl2 then env else begin (* s1 -> s2 *) let fold s1 s2 clparent_old vparent env = Debug.dprintf debug @@ -197,10 +220,11 @@ let rec equal_aux queue env cl1 cl2 = (** We use the version in the use map because it must have s1 as leaves *) let clparent',env = subst env s1 s2 vparent in - Debug.dprintf debug - "[EGraph] @[%a[@[%a@]->@[%a@]]@]@." - (print_cl env) clparent (print_cl env) s1 (print_cl env) s2; let clparent',env = normalize env clparent' in + Debug.dprintf debug + "[EGraph] @[%a[@[%a@]->@[%a@]]@] = @[%a@]@." + (print_cl env) clparent (print_cl env) s1 (print_cl env) s2 + (print_cl env) clparent'; let env = if dont_clean_use || MTerm.equal vparent (repr env clparent') then env else @@ -234,30 +258,45 @@ let rec equal_aux queue env cl1 cl2 = cl1,cl2,env end in + Debug.dprintf debug "[EGraph] @[Use dep %a@]@." (print_cl env) cl1; let env = UnionFind.fold_use (fold cl1 cl2) env cl1 env in + Debug.dprintf debug "[EGraph] @[Use repr dep %a@]@." (print_cl env) cl2; + let env = UnionFind.fold_depend_repr (fold cl2 cl2) env cl2 env in if Queue.is_empty queue then env else let (cl1,cl2) = Queue.pop queue in equal_aux queue env cl1 cl2 end -let rec equal env t1 t2 = +let rec equal_solve env t1 t2 = let t1,t2 = shufflep (t1,t2) in Debug.dprintf debug - "[EGraph] Equal @[@[%a@] =>@ @[%a@]@]@." + "[EGraph] Equal_solve @[@[%a@] =>@ @[%a@]@]@." (print_cl env) t1 (print_cl env) t2; let queue = Queue.create () in let d = ref env in let s, cl = MTerm.solve d (repr env t1,t1) (repr env t2,t2) in let env = !d in + Debug.dprintf debug + "[EGraph] Equal_solve 1)@."; let env = equal_aux queue env t1 cl in + Debug.dprintf debug + "[EGraph] Equal_solve 2)@."; let env = equal_aux queue env t2 cl in - MTerm.Cl.M.fold (fun t1 t2 env -> equal env t1 t2) s env + Debug.dprintf debug + "[EGraph] Equal_solve 3)@."; + MTerm.Cl.M.fold (fun t1 t2 env -> equal_solve env t1 t2) s env -let equal env t1 t2 = - assert (UnionFind.is_added env t1); - assert (UnionFind.is_added env t2); - equal env t1 t2 +let equal env cl1 cl2 = + assert (UnionFind.is_added env cl1); + assert (UnionFind.is_added env cl2); + Debug.dprintf debug + "[EGraph] Equal @[@[%a@] ==@ @[%a@]@]@." + (print_cl env) cl1 (print_cl env) cl2; + let cl1 = UnionFind.find env cl1 in + let cl2 = UnionFind.find env cl2 in + if Cl.equal cl1 cl2 then env + else equal_solve env cl1 cl2 let () = Exn_printer.register (fun fmt exn -> match exn with @@ -274,6 +313,14 @@ module Delayed = struct let cl,env = normalize env cl in t := env; cl + + exception BadTheory + let depend_repr t cl1 cl2 = + let v = try repr t cl1 with NoRepr -> raise BadTheory in + let env = !t in + let env = UnionFind.depend_repr env cl1 v cl2 in + t := env + let hack_env d = !d end diff --git a/src/egraph_simple.mli b/src/egraph_simple.mli index d52adaf6ce563cfcbc8791c25bf5be429e747957..2930fc462e6d7663db68e5dcbd2362a0227f9b79 100644 --- a/src/egraph_simple.mli +++ b/src/egraph_simple.mli @@ -41,6 +41,8 @@ exception Contradiction module Delayed : sig val repr : delayed -> MTerm.cl -> MTerm.value val normalize : delayed -> MTerm.value -> MTerm.cl + val depend_repr : delayed -> MTerm.cl -> MTerm.cl -> unit + (** cl1 depend on cl2 *) (** Hack *) val hack_env: delayed -> env diff --git a/src/term.ml b/src/term.ml index d8d7a37c23aa4aeaa39c36d88f66e6a938053c5d..2be8c0144990fc1c6ed3ecca6e3ec86303a72c05 100644 --- a/src/term.ml +++ b/src/term.ml @@ -46,6 +46,10 @@ sig val set_repr: t -> cl -> value -> t val print_cl: t -> Format.formatter -> cl -> unit val print_value: t -> Format.formatter -> value -> unit + val depend_repr: t -> cl -> value -> cl -> t + val fold_depend_repr : (cl -> value -> 'a -> 'a) -> t -> cl -> 'a -> 'a + val remove_use_repr : t -> cl -> value -> cl -> t + val add_use : t -> cl -> value -> cl -> t val remove_use : t -> cl -> value -> cl -> t @@ -60,7 +64,9 @@ end val leaves: value -> Cl.S.t val subst: delayed -> cl Cl.M.t -> value -> cl +val add: delayed -> value * cl -> unit val solve: delayed -> value * cl -> value * cl -> cl Cl.M.t * cl +(* val normalize: cl Cl.M.t -> value -> value *) type thkind = @@ -87,6 +93,9 @@ module type Th = sig val subst: delayed -> cl Cl.M.t -> repr -> cl (** todo cl? et UnionFind.delayed? *) (** Normally substitute only existing leaves *) + (* val normalize: cl Cl.M.t -> repr -> repr *) + + val add: delayed -> repr * cl -> unit val solve: delayed -> repr * cl -> cl -> cl Cl.M.t * cl end @@ -143,6 +152,8 @@ and ModType : sig val leaves: repr -> Cl.S.t val subst: delayed -> Cl.t Cl.M.t -> repr -> Cl.t + val add: delayed -> repr * Cl.t -> unit + (* val normalize: Cl.t Cl.M.t -> repr -> repr *) val solve: delayed -> repr * Cl.t -> Cl.t -> Cl.t Cl.M.t * Cl.t val id: int end @@ -160,6 +171,8 @@ end = struct val leaves: repr -> Cl.S.t val subst: delayed -> Cl.t Cl.M.t -> repr -> Cl.t + val add: delayed -> repr * Cl.t -> unit + (* val normalize: Cl.t Cl.M.t -> repr -> repr *) val solve: delayed -> repr * Cl.t -> Cl.t -> Cl.t Cl.M.t * Cl.t val id: int end @@ -225,6 +238,19 @@ let subst d s n = let module Th = (val th) in Th.subst d s v +let add d (n,cl) = + match n with + | Value(th,v) -> + let module Th = (val th) in + Th.add d (v,cl) + + +(* let normalize s n = *) +(* match n with *) +(* | Value(th,v) -> *) +(* let module Th = (val th) in *) +(* Value(th,Th.normalize s v) *) + let solve d (n,ncl) (m,mcl) = match n with | Value(nth,nv) -> @@ -263,6 +289,7 @@ struct (** stocking value for others *) repr : value Cl.M.t; use : value Cl.M.t Cl.M.t; + use_repr : value Cl.M.t Cl.M.t; added : Cl.S.t; } @@ -277,7 +304,7 @@ struct and print_cl' nb env fmt k = if nb = 0 then Cl.print fmt k else try - Format.fprintf fmt "%a" + Format.fprintf fmt "[%a]" (print_value (nb-1) env) (repr Exit env k) with Exit -> Format.pp_print_string fmt "<NoRepr>" @@ -291,6 +318,7 @@ struct { eq = Cl.M.empty; repr = Cl.M.empty; use = Cl.M.empty; + use_repr = Cl.M.empty; added = Cl.M.empty; } @@ -330,6 +358,9 @@ struct let fold_use f t cl acc = Cl.M.fold f (Cl.M.find_def Cl.M.empty cl t.use) acc + let fold_depend_repr f t cl acc = + Cl.M.fold f (Cl.M.find_def Cl.M.empty cl t.use_repr) acc + let add_use t cl v arg = {t with use = Cl.M.change (function @@ -337,6 +368,14 @@ struct | Some m -> assert( not (Cl.M.mem cl m) ); Some (Cl.M.add cl v m)) arg t.use} + + let depend_repr t cl v arg = + {t with use_repr = + Cl.M.change (function + | None -> Some (Cl.M.singleton cl v) + | Some m -> assert( not (Cl.M.mem cl m) ); + Some (Cl.M.add cl v m)) arg t.use_repr} + let remove_use t cl v arg = {t with use = Cl.M.change (function @@ -345,6 +384,17 @@ struct let m = Cl.M.remove cl m in if Cl.M.is_empty m then None else Some m) arg t.use} + (** not every one has a use repr that's why its different from remove use + just used for cleaning + *) + let remove_use_repr t cl v arg = + {t with use_repr = + Cl.M.change (function + | None -> None (*invalid_arg "can't remove what is not here"*) + | Some m -> (* assert( equal (Cl.M.find cl m) v ); *) + let m = Cl.M.remove cl m in + if Cl.M.is_empty m then None else Some m) arg t.use_repr} + let is_added t cl = Cl.S.mem cl t.added let added t cl = {t with added = Cl.S.add cl t.added} @@ -367,6 +417,14 @@ struct color=\"gray\"];@\n" Cl.print k Cl.print r) s) env.use; + Cl.M.iter + (fun k s -> + Cl.M.iter (fun r _ -> + Format.fprintf fmt "@[<h>\"%a\"@]@ -> @[<h>\"%a\"@]@ \ + [@[style=\"dashed\",@ constraint=false@],@ \ + color=\"blue\"];@\n" + Cl.print k Cl.print r) s) + env.use_repr; Cl.M.iter (fun k _ -> Format.fprintf fmt "@[<h>\"%a\"@]@ [@[label=@[<h>\"%a\"@]@]];@\n" @@ -410,6 +468,8 @@ module type Th = sig val leaves: repr -> Cl.S.t val subst: delayed -> Cl.t Cl.M.t -> repr -> cl + val add: delayed -> repr * Cl.t -> unit + (* val normalize: cl Cl.M.t -> repr -> repr *) val solve: delayed -> repr * cl -> cl -> Cl.t Cl.M.t * cl end diff --git a/src/term.mli b/src/term.mli index f41023633358b6332cac50dada0fabf999206fbe..0b6662133d9df6b4d2a828d10a3ac0a85109cb69 100644 --- a/src/term.mli +++ b/src/term.mli @@ -47,6 +47,9 @@ sig val set_repr: t -> cl -> value -> t val print_cl: t -> Format.formatter -> cl -> unit val print_value: t -> Format.formatter -> value -> unit + val depend_repr: t -> cl -> value -> cl -> t + val fold_depend_repr : (cl -> value -> 'a -> 'a) -> t -> cl -> 'a -> 'a + val remove_use_repr : t -> cl -> value -> cl -> t val add_use : t -> cl -> value -> cl -> t val remove_use : t -> cl -> value -> cl -> t @@ -61,8 +64,9 @@ end val leaves: value -> Cl.S.t val subst: delayed -> cl Cl.M.t -> value -> cl +val add: delayed -> value * cl -> unit val solve: delayed -> value * cl -> value * cl -> cl Cl.M.t * cl - +(* val normalize: cl Cl.M.t -> value -> value *) type thkind = | Polite (** really? don't work with the representant *) @@ -85,9 +89,11 @@ module type Th = sig val equal: repr -> repr -> bool val leaves: repr -> Cl.S.t - val subst: delayed -> cl Cl.M.t -> repr -> cl - (** todo cl? et UnionFind.delayed? *) + val subst: delayed -> cl Cl.M.t -> repr -> cl (** Normally substitute only existing leaves *) + (* val normalize: cl Cl.M.t -> repr -> repr *) + + val add: delayed -> repr * cl -> unit val solve: delayed -> repr * cl -> cl -> cl Cl.M.t * cl end diff --git a/src/uninterp.ml b/src/uninterp.ml index 8d80a6d49f324021fdd61e52bce8e28895e390c0..e335ee32d16d34ede407819b5b6780239ec67964 100644 --- a/src/uninterp.ml +++ b/src/uninterp.ml @@ -62,6 +62,8 @@ module Th = struct | App(f,g) -> Cl.S.add g (Cl.S.singleton f) + let add d _ = () + (* No more classes to merge than the two given class themselves *) let solve _ _ v = Cl.M.empty,v diff --git a/src/util/extmap.ml b/src/util/extmap.ml index 1ea3dc46709167ea07354967ba56e725d844382b..0aae9c9c047126147d30201edced77e4d02943fc 100644 --- a/src/util/extmap.ml +++ b/src/util/extmap.ml @@ -77,6 +77,9 @@ module type S = val add_new : exn -> key -> 'a -> 'a t -> 'a t val keys: 'a t -> key list val values: 'a t -> 'a list + val union_merge: + (key -> 'a option -> 'b -> 'a option) -> 'a t -> 'b t -> 'a t + val height: 'a t -> int type 'a enumeration val val_enum : 'a enumeration -> (key * 'a) option val start_enum : 'a t -> 'a enumeration @@ -480,6 +483,25 @@ struct (union f r1 r2) end + let rec union_merge f s1 s2 = + match (s1, s2) with + (Empty, Empty) -> Empty + | (t1,Empty) -> t1 + | (Node (l1, v1, d1, r1, h1), _) when h1 >= height s2 -> + let (l2, d2, r2) = split v1 s2 in + begin match d2 with + | None -> join (union_merge f l1 l2) v1 d1 (union_merge f r1 r2) + | Some d2 -> + concat_or_join (union_merge f l1 l2) v1 (f v1 (Some d1) d2) + (union_merge f r1 r2) + end + | (_, Node (l2, v2, d2, r2, _h2)) -> + let (l1, d1, r1) = split v2 s1 in + concat_or_join (union_merge f l1 l2) v2 (f v2 d1 d2) + (union_merge f r1 r2) + | _ -> + assert false + let rec inter f s1 s2 = match (s1, s2) with @@ -647,6 +669,10 @@ struct fold (fun _ _ n -> if n < 0 then raise Exit else n-1) m n = 0 with Exit -> false + let height = function + | Empty -> 0 + | Node(_,_,_,_,h) -> h + let start_enum s = cons_enum s End let val_enum = function diff --git a/src/util/extmap.mli b/src/util/extmap.mli index b0b3ef73e76d143f490f818812ca4b72d378d72d..81fec59fd7397f359ff33a4441cbc461aa673dbd 100644 --- a/src/util/extmap.mli +++ b/src/util/extmap.mli @@ -278,6 +278,14 @@ module type Map = sig to the ordering [Ord.compare] of the keys, where [Ord] is the argument given to {!Map.Make}. *) + val union_merge: + (key -> 'a option -> 'b -> 'a option) -> 'a t -> 'b t -> 'a t + (** Between union for the first argument and merge for the second + argument *) + + val height: 'a t -> int + (** height of the underlying tree, can be used for optimisations *) + (** enumeration: zipper style *) type 'a enumeration diff --git a/tests/tests_arith.ml b/tests/tests_arith.ml index 04acfd648e289898ff3bcf933ad2cb60bd10d084..e1ce701d56adf3c357c3ac08a517cd7729c0302f 100644 --- a/tests/tests_arith.ml +++ b/tests/tests_arith.ml @@ -16,19 +16,64 @@ let b,env = add env b let c,env = add env c let _1,env = Arith.cst env Q.one +let _2,env = Arith.cst env (Q.of_int 2) let a1,env = Arith.add env a _1 let b1,env = Arith.add env b _1 let _1,env = add env _1 +let _2,env = add env _2 let a1,env = add env a1 let b1,env = add env b1 -let solve () = +let a2,env = Arith.add env _1 a1 +let b2,env = Arith.add env b _2 + +let a2,env = add env a2 +let b2,env = add env b2 + +let _2a2,env = Arith.add env a a2 +let _2b2,env = Arith.mult_cst env (Q.of_int 2) b1 + +let _2a2,env = add env _2a2 +let _2b2,env = add env _2b2 + + +let solve1 () = let env = equal env a1 b1 in - Uf.exportdot "arith.dot" env; assert_bool "" (is_equal env a b) -let basic = "Basic" >::: ["a+1 = b+1 => a = b" >:: solve;] +let solve2 () = + let env = equal env a1 b1 in + assert_bool "" (is_equal env a2 b2) + +let solve3 () = + let env = equal env a2 b1 in + assert_bool "" (is_equal env a1 b) + +let solve4 () = + Format.eprintf "[Test] 0@."; + Uf.exportdot "solve4_0.dot" env; + let env = equal env a2 b1 in + Format.eprintf "[Test] 1@."; + Uf.exportdot "solve4_1.dot" env; + let env = equal env a _2 in + Format.eprintf "[Test] 2@."; + Uf.exportdot "solve4_2.dot" env; + assert_bool "" (not (is_equal env b _2)); + let _3,env = Arith.cst env (Q.of_int 3) in + Format.eprintf "[Test] 3@."; + Uf.exportdot "solve4_3.dot" env; + let _3,env = add env _3 in + Format.eprintf "[Test] 4@."; + Uf.exportdot "solve4_4.dot" env; + assert_bool "" (is_equal env b _3) + + +let basic = "Basic" >::: ["a+1 = b+1 => a = b " >:: solve1; + "a+1 = b+1 => a+2 = b+2" >:: solve2; + "a+2 = b+1 => a+1 = b" >:: solve3; + "a+2 = b+1 => a = 2 => b = 3" >:: solve4] + let tests = TestList [basic]