diff --git a/src_colibri2/core/interp.ml b/src_colibri2/core/interp.ml index 4cff4487c10d6d9137b3f2bcc7e1643b9241f266..c93668b6e3cb79ff335d191fa7c525a0f07f38d1 100644 --- a/src_colibri2/core/interp.ml +++ b/src_colibri2/core/interp.ml @@ -58,30 +58,59 @@ module Data = struct end module SeqLim = struct - type 'a t = 'a Sequence.t + type 'a t = int -> 'a Sequence.t let of_seq d s = match Egraph.get_env d Data.env with - | Before -> s + | Before -> fun _ -> s | During d -> - Sequence.unfold_with s ~init:0 ~f:(fun i x -> - if i <= d.limit then Sequence.Step.Yield (x, i + 1) - else ( - d.limitreached := true; - Sequence.Step.Done)) + fun j -> + if j >= d.limit then ( + d.limitreached := true; + Sequence.take s 1) + else + Sequence.unfold_with s ~init:0 ~f:(fun i x -> + if i <= d.limit then Sequence.Step.Yield (x, i + 1) + else ( + d.limitreached := true; + Sequence.Step.Done)) + + let map (x : 'a t) ~f i = Sequence.map ~f (x i) - let map = Sequence.map + let unfold_with x ~init ~f i = Sequence.unfold_with (x i) ~init ~f - let unfold_with = Sequence.unfold_with + let ( let+ ) x y i = Sequence.( >>| ) (x i) y - let ( let+ ) = Sequence.( >>| ) + let ( let* ) t f i = Sequence.bind (t i) ~f:(fun x -> f x i) - let ( let* ) t f = Sequence.bind t ~f + let ( and* ) x y i = Sequence.cartesian_product (x i) (y i) - let ( and* ) = Sequence.cartesian_product + let incr_depth d x i = + match Egraph.get_env d Data.env with + | Before -> x i + | During d -> + if i >= d.limit then ( + d.limitreached := true; + Sequence.take (x (i + 1)) 1) + else x (i + 1) end -module Register = struct +module Register : sig + val check : Egraph.t -> (Egraph.t -> Ground.t -> bool option) -> unit + + val node : + Egraph.t -> + ((Egraph.t -> Nodes.Node.t -> Nodes.Value.t SeqLim.t) -> + Egraph.t -> + Nodes.Node.t -> + Nodes.Value.t SeqLim.t option) -> + unit + + val ty : + Egraph.t -> + (Egraph.t -> Ground.Ty.t -> Nodes.Value.t SeqLim.t option) -> + unit +end = struct open Data let check d f = Datastructure.Queue.push check d f @@ -91,6 +120,15 @@ module Register = struct let ty d f = Datastructure.Queue.push ty d f end +let spy_sequence msg s i = + if false then + Sequence.mapi + ~f:(fun i x -> + Fmt.epr "@.%s %i: %a@." msg i Value.pp x; + x) + (s i) + else s i + let get_registered (type a) exn db call d x = let exception Found of a in match @@ -105,11 +143,14 @@ let get_registered (type a) exn db call d x = exception CantInterpretTy of Ground.Ty.t let ty d ty = - get_registered - (fun ty -> raise (CantInterpretTy ty)) - Data.ty - (fun f d x -> f d x) - d ty + let seq = + get_registered + (fun ty -> raise (CantInterpretTy ty)) + Data.ty + (fun f d x -> f d x) + d ty + in + spy_sequence "ty" (SeqLim.incr_depth d seq) exception CantInterpretNode of Node.t @@ -117,20 +158,22 @@ let node d n = let parent = Node.H.create 5 in let rec aux d n' = if Node.H.mem parent n' then ( - Debug.dprintf4 debug "Interp.node recursive value: %a starting at %a" - Node.pp n' Node.pp n; - Egraph.contradiction d) + Debug.dprintf6 debug "Interp.node recursive value: %a starting at %a (%a)" + Node.pp n' Node.pp n + Fmt.(iter_bindings ~sep:comma Node.H.iter (pair Node.pp nop)) + parent; + raise Impossible) else ( Node.H.replace parent n' (); let r = get_registered - (fun n -> raise (CantInterpretNode n)) + (fun n' -> raise (CantInterpretNode n')) Data.node (fun f d x -> f aux d x) - d n + d n' in Node.H.remove parent n'; - r) + SeqLim.incr_depth d r) in aux d n @@ -251,7 +294,11 @@ module Fix_model = struct | CantInterpretNode _ when not (Ground.Ty.S.is_empty tys) -> ty d (Ground.Ty.S.choose tys) in - let seq = Sequence.of_list (Sequence.to_list seq) in + let seq = + seq |> spy_sequence "got" + |> (fun s -> s 0) + |> Sequence.to_list |> Sequence.of_list + in let seq = Sequence.map seq ~f:(fun v d -> Egraph.set_env d Data.env (During { r with nextnode }); diff --git a/src_colibri2/tests/solve/smt_adt/sat/dune.inc b/src_colibri2/tests/solve/smt_adt/sat/dune.inc index 3088caeac4bf5fef84dba86b6aa2150003519920..5c9df097ee6d07ceb7f346a751875c47e26d4ab2 100644 --- a/src_colibri2/tests/solve/smt_adt/sat/dune.inc +++ b/src_colibri2/tests/solve/smt_adt/sat/dune.inc @@ -5,3 +5,9 @@ (rule (alias runtest) (action (diff oracle list0.smt2.res))) (rule (action (with-stdout-to list1.smt2.res (run %{bin:colibri2} --max-step 1300 %{dep:list1.smt2})))) (rule (alias runtest) (action (diff oracle list1.smt2.res))) +(rule (action (with-stdout-to tree1.smt2.res (run %{bin:colibri2} --max-step 1300 %{dep:tree1.smt2})))) +(rule (alias runtest) (action (diff oracle tree1.smt2.res))) +(rule (action (with-stdout-to tree2.smt2.res (run %{bin:colibri2} --max-step 1300 %{dep:tree2.smt2})))) +(rule (alias runtest) (action (diff oracle tree2.smt2.res))) +(rule (action (with-stdout-to tree3.smt2.res (run %{bin:colibri2} --max-step 1300 %{dep:tree3.smt2})))) +(rule (alias runtest) (action (diff oracle tree3.smt2.res))) diff --git a/src_colibri2/tests/solve/smt_adt/sat/tree1.smt2 b/src_colibri2/tests/solve/smt_adt/sat/tree1.smt2 new file mode 100644 index 0000000000000000000000000000000000000000..0c4025d01601f669ca522c2d94e6ed22f99677d1 --- /dev/null +++ b/src_colibri2/tests/solve/smt_adt/sat/tree1.smt2 @@ -0,0 +1,13 @@ +;; produced by colibri.drv ;; +(set-logic ALL) + +(declare-datatype tree ( par (X) ( + ( leaf ) + ( node ( head X) ( left ( tree X )) ( right ( tree X )))))) + +(declare-fun a () (tree Int)) + +(assert ((_ is node) a)) +(assert ((_ is leaf) (left a))) + +(check-sat) diff --git a/src_colibri2/tests/solve/smt_adt/sat/tree2.smt2 b/src_colibri2/tests/solve/smt_adt/sat/tree2.smt2 new file mode 100644 index 0000000000000000000000000000000000000000..33d66c808e602d9d291b17af978c6b9f5f3f1fac --- /dev/null +++ b/src_colibri2/tests/solve/smt_adt/sat/tree2.smt2 @@ -0,0 +1,15 @@ +;; produced by colibri.drv ;; +(set-logic ALL) + +(declare-datatype tree ( par (X) ( + ( leaf ) + ( node ( head X) ( left ( tree X )) ( right ( tree X )))))) + +(declare-fun a () (tree Int)) + +(assert ((_ is node) a)) +(assert ((_ is node) (left a))) +(assert ((_ is node) (left (left a)))) +(assert ((_ is node) (left (left (left a))))) + +(check-sat) diff --git a/src_colibri2/tests/solve/smt_adt/sat/tree3.smt2 b/src_colibri2/tests/solve/smt_adt/sat/tree3.smt2 new file mode 100644 index 0000000000000000000000000000000000000000..b8e10b6024f2a9a653802403761edd65df4ea8e2 --- /dev/null +++ b/src_colibri2/tests/solve/smt_adt/sat/tree3.smt2 @@ -0,0 +1,18 @@ +;; produced by colibri.drv ;; +(set-logic ALL) + +(declare-datatype tree ( par (X) ( + ( leaf ) + ( node ( head X) ( left ( tree X )) ( right ( tree X )))))) + +(declare-fun a () (tree Int)) + +(assert ((_ is node) a)) +(assert ((_ is node) (left a))) +(assert ((_ is node) (left (left a)))) +(assert ((_ is leaf) (left (left (left a))))) +(assert ((_ is leaf) (left (left (left (left a)))))) +(assert ((_ is leaf) (left (left (left (left (left a))))))) +(assert ((_ is leaf) (left (left (left (left (left (left a)))))))) + +(check-sat) diff --git a/src_colibri2/theories/ADT/adt.ml b/src_colibri2/theories/ADT/adt.ml index b5dc6e7f4580ab0fee75127125cace3e77834237..c00d1838eb33e5b5fb9bb7ec2eab2ff9b4cc06ee 100644 --- a/src_colibri2/theories/ADT/adt.ml +++ b/src_colibri2/theories/ADT/adt.ml @@ -1,7 +1,10 @@ +open Colibri2_popop_lib open Colibri2_popop_lib.Popop_stdlib module Case = DInt module Field = DInt +let debug = Debug.register_info_flag "adt" ~desc:"Algebraic@ Datatype" + module D = struct type t = | Unk of Case.S.t @@ -59,12 +62,11 @@ end let () = DomKind.register (module D) let upd_dom d n d2 = - let s = - match Egraph.get_dom d D.key n with - | None -> d2 - | Some d1 -> D.merge' d d1 d2 - in - Egraph.set_dom d D.key n s + match Egraph.get_dom d D.key n with + | None -> Egraph.set_dom d D.key n d2 + | Some d1 -> + let d' = D.merge' d d1 d2 in + if not (D.equal d' d1) then Egraph.set_dom d D.key n d' let case_of_adt ty = match Ground.Ty.definition ty with @@ -205,9 +207,85 @@ let converter d (f : Ground.t) = Adt_value.propagate_value d f | _ -> () +let init_node d = + Interp.Register.node d (fun interp_node d n -> + match Egraph.get_dom d D.key n with + | None -> None + | Some adt -> ( + let sty = Ground.tys d n in + assert (Ground.Ty.S.is_num_elt 1 sty); + let ty = Ground.Ty.S.choose sty in + match ty with + | { + app = { builtin = Dolmen_std.Expr.Base; _ } as sym; + args = tyargs; + } -> ( + match Ground.Ty.definition sym with + | Abstract -> assert false + | Adt { ty; record = _; cases = all_cases } -> + Debug.dprintf4 debug "[ADT] node %a: %a" Node.pp n D.pp adt; + Some + (match adt with + | D.Unk cases -> + let seq = + Base.Sequence.of_list (Case.S.elements cases) + in + Adt_value.sequence_of_cases d ty tyargs all_cases seq + | D.One { case; fields } -> + let open Interp.SeqLim in + let { Expr.Ty.cstr; _ } = all_cases.(case) in + let subst = + Base.List.fold2_exn + ~f:(fun acc v ty -> Expr.Ty.Var.M.add v ty acc) + ~init:Expr.Ty.Var.M.empty cstr.ty.fun_vars tyargs + in + let args_ty = + Base.List.mapi + ~f:(fun i ty -> (i, Ground.Ty.convert subst ty)) + cstr.ty.fun_args + in + let fields = + Field.M.merge + (fun _ l r -> + match (l, r) with + | Some l, _ -> Some (`Node l) + | None, Some r -> Some (`Ty r) + | None, None -> assert false) + fields (Field.M.of_list args_ty) + in + let rec aux seq = function + | [] -> seq + | (i, `Node arg) :: args -> + let seq = + let+ l = seq + and* a = + Debug.dprintf4 debug "[ADT] interp_node %a:%a" + Node.pp n Node.pp arg; + interp_node d arg + in + Field.M.add i a l + in + aux seq args + | (i, `Ty arg) :: args -> + let seq = + let+ l = seq and* a = Interp.ty d arg in + Field.M.add i a l + in + aux seq args + in + let+ fields = + aux + (of_seq d (Base.Sequence.singleton Field.M.empty)) + (Field.M.bindings fields) + in + Adt_value.nodevalue + (Adt_value.index { tyargs; adt = ty; case; fields }))) + | _ -> assert false)) + let init env : unit = Adt_value.th_register env; Ground.register_converter env converter; - init_got_value_bool env + init_got_value_bool env; + init_node env let () = Egraph.add_default_theory init diff --git a/src_colibri2/theories/ADT/adt_value.ml b/src_colibri2/theories/ADT/adt_value.ml index 5b7b6cfb12a88496fb0dd47aea3b62443b948183..9882da027f603b11c35e1806ada56fa9525c5d4c 100644 --- a/src_colibri2/theories/ADT/adt_value.ml +++ b/src_colibri2/theories/ADT/adt_value.ml @@ -5,9 +5,16 @@ open Colibri2_popop_lib module Case = DInt module Field = DInt +type ts = { + adt : Expr.Ty.Const.t; + tyargs : Ground.Ty.t list; + case : Case.t; + fields : Value.t Field.M.t; +} + module T' = struct module T = struct - type t = { + type t = ts = { adt : Expr.Ty.Const.t; tyargs : Ground.Ty.t list; case : Case.t; @@ -30,16 +37,15 @@ module T' = struct end) end -module V = ValueKind.Register (T') -include T' +include ValueKind.Register (T') let interp d n = Opt.get_exn Impossible (Egraph.get_value d n) let compute d g = match Ground.sem g with | { app = { builtin = Expr.Tester { case; _ }; _ }; args = [ e ]; _ } -> - let v = V.coerce_nodevalue (interp d e) in - let v = V.value v in + let v = coerce_nodevalue (interp d e) in + let v = value v in `Some (Colibri2_theories_bool.Boolean.values_of_bool (Case.equal case v.case)) | { @@ -53,14 +59,14 @@ let compute d g = Field.M.add field (interp d a) acc) in let v = { tyargs; adt; case; fields } in - `Some (V.nodevalue (V.index v)) + `Some (nodevalue (index v)) | { app = { builtin = Expr.Destructor { case; field; _ }; _ }; args = [ e ]; _; } -> - let v = V.coerce_nodevalue (interp d e) in - let v = V.value v in + let v = coerce_nodevalue (interp d e) in + let v = value v in if Case.equal case v.case then let v = Field.M.find_opt field v.fields in match v with None -> raise Impossible | Some v -> `Some v @@ -88,6 +94,28 @@ let propagate_value d g = in Interp.TwoWatchLiteral.create d f g +let sequence_of_cases d ty tyargs all_cases cases = + let open Interp.SeqLim in + let* case = of_seq d cases in + let { Expr.Ty.cstr; _ } = all_cases.(case) in + let subst = + List.fold2_exn + ~f:(fun acc v ty -> Expr.Ty.Var.M.add v ty acc) + ~init:Expr.Ty.Var.M.empty cstr.ty.fun_vars tyargs + in + let args = List.map ~f:(Ground.Ty.convert subst) cstr.ty.fun_args in + let rec aux seq i = function + | [] -> seq + | ty :: args -> + let seq = + let+ l = seq and* a = Interp.ty d ty in + Field.M.add i a l + in + aux seq (i + 1) args + in + let+ fields = aux (of_seq d (Sequence.singleton Field.M.empty)) 0 args in + nodevalue (index { tyargs = args; adt = ty; case; fields }) + let init_ty d = Interp.Register.ty d (fun d ty -> match ty with @@ -95,36 +123,11 @@ let init_ty d = match Ground.Ty.definition sym with | Abstract -> None | Adt { ty; record = _; cases } -> - Some - (let open Interp.SeqLim in - let cases = - Sequence.unfold ~init:0 ~f:(fun i -> - if i < Array.length cases then Some ((i, cases.(i)), i + 1) - else None) - in - let* case, { cstr; _ } = of_seq d cases in - let subst = - List.fold2_exn - ~f:(fun acc v ty -> Expr.Ty.Var.M.add v ty acc) - ~init:Expr.Ty.Var.M.empty cstr.ty.fun_vars args - in - let args = - List.map ~f:(Ground.Ty.convert subst) cstr.ty.fun_args - in - let rec aux seq i = function - | [] -> seq - | ty :: args -> - let seq = - let+ l = seq and* a = Interp.ty d ty in - Field.M.add i a l - in - aux seq (i + 1) args - in - let+ fields = - aux (of_seq d (Sequence.singleton Field.M.empty)) 0 args - in - V.nodevalue (V.index { tyargs = args; adt = ty; case; fields })) - ) + let seq = + Sequence.unfold ~init:0 ~f:(fun i -> + if i < Array.length cases then Some (i, i + 1) else None) + in + Some (sequence_of_cases d ty args cases seq)) | _ -> None) let th_register d = diff --git a/src_colibri2/theories/ADT/adt_value.mli b/src_colibri2/theories/ADT/adt_value.mli index 3f2f3f0ce9ab6429787cebf0a302b1cd807726ce..7ebbb8b70b43ea56ddcba85d26960002ae30d950 100644 --- a/src_colibri2/theories/ADT/adt_value.mli +++ b/src_colibri2/theories/ADT/adt_value.mli @@ -1,3 +1,24 @@ +open Colibri2_popop_lib.Popop_stdlib +module Case = DInt +module Field = DInt + +type ts = { + adt : Expr.Ty.Const.t; + tyargs : Ground.Ty.t list; + case : Case.t; + fields : Value.t Field.M.t; +} + +include ValueKind.Registered with type s := ts + val th_register : Egraph.t -> unit val propagate_value : Egraph.t -> Ground.t -> unit + +val sequence_of_cases : + Egraph.t -> + Term.ty_const -> + Ground.All.ty list -> + Ty.adt_case array -> + int Base.Sequence.t -> + Value.t Interp.SeqLim.t