Skip to content
Snippets Groups Projects
Commit 8c10546a authored by François Bobot's avatar François Bobot
Browse files

Iterative deepening must limit the depth

 And fix generation of ADT
parent d08ce377
No related branches found
No related tags found
1 merge request!1Iterative deepening must limit the depth
Pipeline #34045 passed
...@@ -58,30 +58,59 @@ module Data = struct ...@@ -58,30 +58,59 @@ module Data = struct
end end
module SeqLim = struct module SeqLim = struct
type 'a t = 'a Sequence.t type 'a t = int -> 'a Sequence.t
let of_seq d s = let of_seq d s =
match Egraph.get_env d Data.env with match Egraph.get_env d Data.env with
| Before -> s | Before -> fun _ -> s
| During d -> | During d ->
Sequence.unfold_with s ~init:0 ~f:(fun i x -> fun j ->
if i <= d.limit then Sequence.Step.Yield (x, i + 1) if j >= d.limit then (
else ( d.limitreached := true;
d.limitreached := true; Sequence.take s 1)
Sequence.Step.Done)) else
Sequence.unfold_with s ~init:0 ~f:(fun i x ->
if i <= d.limit then Sequence.Step.Yield (x, i + 1)
else (
d.limitreached := true;
Sequence.Step.Done))
let map (x : 'a t) ~f i = Sequence.map ~f (x i)
let map = Sequence.map let unfold_with x ~init ~f i = Sequence.unfold_with (x i) ~init ~f
let unfold_with = Sequence.unfold_with let ( let+ ) x y i = Sequence.( >>| ) (x i) y
let ( let+ ) = Sequence.( >>| ) let ( let* ) t f i = Sequence.bind (t i) ~f:(fun x -> f x i)
let ( let* ) t f = Sequence.bind t ~f let ( and* ) x y i = Sequence.cartesian_product (x i) (y i)
let ( and* ) = Sequence.cartesian_product let incr_depth d x i =
match Egraph.get_env d Data.env with
| Before -> x i
| During d ->
if i >= d.limit then (
d.limitreached := true;
Sequence.take (x (i + 1)) 1)
else x (i + 1)
end end
module Register = struct module Register : sig
val check : Egraph.t -> (Egraph.t -> Ground.t -> bool option) -> unit
val node :
Egraph.t ->
((Egraph.t -> Nodes.Node.t -> Nodes.Value.t SeqLim.t) ->
Egraph.t ->
Nodes.Node.t ->
Nodes.Value.t SeqLim.t option) ->
unit
val ty :
Egraph.t ->
(Egraph.t -> Ground.Ty.t -> Nodes.Value.t SeqLim.t option) ->
unit
end = struct
open Data open Data
let check d f = Datastructure.Queue.push check d f let check d f = Datastructure.Queue.push check d f
...@@ -91,6 +120,15 @@ module Register = struct ...@@ -91,6 +120,15 @@ module Register = struct
let ty d f = Datastructure.Queue.push ty d f let ty d f = Datastructure.Queue.push ty d f
end end
let spy_sequence msg s i =
if false then
Sequence.mapi
~f:(fun i x ->
Fmt.epr "@.%s %i: %a@." msg i Value.pp x;
x)
(s i)
else s i
let get_registered (type a) exn db call d x = let get_registered (type a) exn db call d x =
let exception Found of a in let exception Found of a in
match match
...@@ -105,11 +143,14 @@ let get_registered (type a) exn db call d x = ...@@ -105,11 +143,14 @@ let get_registered (type a) exn db call d x =
exception CantInterpretTy of Ground.Ty.t exception CantInterpretTy of Ground.Ty.t
let ty d ty = let ty d ty =
get_registered let seq =
(fun ty -> raise (CantInterpretTy ty)) get_registered
Data.ty (fun ty -> raise (CantInterpretTy ty))
(fun f d x -> f d x) Data.ty
d ty (fun f d x -> f d x)
d ty
in
spy_sequence "ty" (SeqLim.incr_depth d seq)
exception CantInterpretNode of Node.t exception CantInterpretNode of Node.t
...@@ -117,20 +158,22 @@ let node d n = ...@@ -117,20 +158,22 @@ let node d n =
let parent = Node.H.create 5 in let parent = Node.H.create 5 in
let rec aux d n' = let rec aux d n' =
if Node.H.mem parent n' then ( if Node.H.mem parent n' then (
Debug.dprintf4 debug "Interp.node recursive value: %a starting at %a" Debug.dprintf6 debug "Interp.node recursive value: %a starting at %a (%a)"
Node.pp n' Node.pp n; Node.pp n' Node.pp n
Egraph.contradiction d) Fmt.(iter_bindings ~sep:comma Node.H.iter (pair Node.pp nop))
parent;
raise Impossible)
else ( else (
Node.H.replace parent n' (); Node.H.replace parent n' ();
let r = let r =
get_registered get_registered
(fun n -> raise (CantInterpretNode n)) (fun n' -> raise (CantInterpretNode n'))
Data.node Data.node
(fun f d x -> f aux d x) (fun f d x -> f aux d x)
d n d n'
in in
Node.H.remove parent n'; Node.H.remove parent n';
r) SeqLim.incr_depth d r)
in in
aux d n aux d n
...@@ -251,7 +294,11 @@ module Fix_model = struct ...@@ -251,7 +294,11 @@ module Fix_model = struct
| CantInterpretNode _ when not (Ground.Ty.S.is_empty tys) -> | CantInterpretNode _ when not (Ground.Ty.S.is_empty tys) ->
ty d (Ground.Ty.S.choose tys) ty d (Ground.Ty.S.choose tys)
in in
let seq = Sequence.of_list (Sequence.to_list seq) in let seq =
seq |> spy_sequence "got"
|> (fun s -> s 0)
|> Sequence.to_list |> Sequence.of_list
in
let seq = let seq =
Sequence.map seq ~f:(fun v d -> Sequence.map seq ~f:(fun v d ->
Egraph.set_env d Data.env (During { r with nextnode }); Egraph.set_env d Data.env (During { r with nextnode });
......
...@@ -5,3 +5,9 @@ ...@@ -5,3 +5,9 @@
(rule (alias runtest) (action (diff oracle list0.smt2.res))) (rule (alias runtest) (action (diff oracle list0.smt2.res)))
(rule (action (with-stdout-to list1.smt2.res (run %{bin:colibri2} --max-step 1300 %{dep:list1.smt2})))) (rule (action (with-stdout-to list1.smt2.res (run %{bin:colibri2} --max-step 1300 %{dep:list1.smt2}))))
(rule (alias runtest) (action (diff oracle list1.smt2.res))) (rule (alias runtest) (action (diff oracle list1.smt2.res)))
(rule (action (with-stdout-to tree1.smt2.res (run %{bin:colibri2} --max-step 1300 %{dep:tree1.smt2}))))
(rule (alias runtest) (action (diff oracle tree1.smt2.res)))
(rule (action (with-stdout-to tree2.smt2.res (run %{bin:colibri2} --max-step 1300 %{dep:tree2.smt2}))))
(rule (alias runtest) (action (diff oracle tree2.smt2.res)))
(rule (action (with-stdout-to tree3.smt2.res (run %{bin:colibri2} --max-step 1300 %{dep:tree3.smt2}))))
(rule (alias runtest) (action (diff oracle tree3.smt2.res)))
;; produced by colibri.drv ;;
(set-logic ALL)
(declare-datatype tree ( par (X) (
( leaf )
( node ( head X) ( left ( tree X )) ( right ( tree X ))))))
(declare-fun a () (tree Int))
(assert ((_ is node) a))
(assert ((_ is leaf) (left a)))
(check-sat)
;; produced by colibri.drv ;;
(set-logic ALL)
(declare-datatype tree ( par (X) (
( leaf )
( node ( head X) ( left ( tree X )) ( right ( tree X ))))))
(declare-fun a () (tree Int))
(assert ((_ is node) a))
(assert ((_ is node) (left a)))
(assert ((_ is node) (left (left a))))
(assert ((_ is node) (left (left (left a)))))
(check-sat)
;; produced by colibri.drv ;;
(set-logic ALL)
(declare-datatype tree ( par (X) (
( leaf )
( node ( head X) ( left ( tree X )) ( right ( tree X ))))))
(declare-fun a () (tree Int))
(assert ((_ is node) a))
(assert ((_ is node) (left a)))
(assert ((_ is node) (left (left a))))
(assert ((_ is leaf) (left (left (left a)))))
(assert ((_ is leaf) (left (left (left (left a))))))
(assert ((_ is leaf) (left (left (left (left (left a)))))))
(assert ((_ is leaf) (left (left (left (left (left (left a))))))))
(check-sat)
open Colibri2_popop_lib
open Colibri2_popop_lib.Popop_stdlib open Colibri2_popop_lib.Popop_stdlib
module Case = DInt module Case = DInt
module Field = DInt module Field = DInt
let debug = Debug.register_info_flag "adt" ~desc:"Algebraic@ Datatype"
module D = struct module D = struct
type t = type t =
| Unk of Case.S.t | Unk of Case.S.t
...@@ -59,12 +62,11 @@ end ...@@ -59,12 +62,11 @@ end
let () = DomKind.register (module D) let () = DomKind.register (module D)
let upd_dom d n d2 = let upd_dom d n d2 =
let s = match Egraph.get_dom d D.key n with
match Egraph.get_dom d D.key n with | None -> Egraph.set_dom d D.key n d2
| None -> d2 | Some d1 ->
| Some d1 -> D.merge' d d1 d2 let d' = D.merge' d d1 d2 in
in if not (D.equal d' d1) then Egraph.set_dom d D.key n d'
Egraph.set_dom d D.key n s
let case_of_adt ty = let case_of_adt ty =
match Ground.Ty.definition ty with match Ground.Ty.definition ty with
...@@ -205,9 +207,85 @@ let converter d (f : Ground.t) = ...@@ -205,9 +207,85 @@ let converter d (f : Ground.t) =
Adt_value.propagate_value d f Adt_value.propagate_value d f
| _ -> () | _ -> ()
let init_node d =
Interp.Register.node d (fun interp_node d n ->
match Egraph.get_dom d D.key n with
| None -> None
| Some adt -> (
let sty = Ground.tys d n in
assert (Ground.Ty.S.is_num_elt 1 sty);
let ty = Ground.Ty.S.choose sty in
match ty with
| {
app = { builtin = Dolmen_std.Expr.Base; _ } as sym;
args = tyargs;
} -> (
match Ground.Ty.definition sym with
| Abstract -> assert false
| Adt { ty; record = _; cases = all_cases } ->
Debug.dprintf4 debug "[ADT] node %a: %a" Node.pp n D.pp adt;
Some
(match adt with
| D.Unk cases ->
let seq =
Base.Sequence.of_list (Case.S.elements cases)
in
Adt_value.sequence_of_cases d ty tyargs all_cases seq
| D.One { case; fields } ->
let open Interp.SeqLim in
let { Expr.Ty.cstr; _ } = all_cases.(case) in
let subst =
Base.List.fold2_exn
~f:(fun acc v ty -> Expr.Ty.Var.M.add v ty acc)
~init:Expr.Ty.Var.M.empty cstr.ty.fun_vars tyargs
in
let args_ty =
Base.List.mapi
~f:(fun i ty -> (i, Ground.Ty.convert subst ty))
cstr.ty.fun_args
in
let fields =
Field.M.merge
(fun _ l r ->
match (l, r) with
| Some l, _ -> Some (`Node l)
| None, Some r -> Some (`Ty r)
| None, None -> assert false)
fields (Field.M.of_list args_ty)
in
let rec aux seq = function
| [] -> seq
| (i, `Node arg) :: args ->
let seq =
let+ l = seq
and* a =
Debug.dprintf4 debug "[ADT] interp_node %a:%a"
Node.pp n Node.pp arg;
interp_node d arg
in
Field.M.add i a l
in
aux seq args
| (i, `Ty arg) :: args ->
let seq =
let+ l = seq and* a = Interp.ty d arg in
Field.M.add i a l
in
aux seq args
in
let+ fields =
aux
(of_seq d (Base.Sequence.singleton Field.M.empty))
(Field.M.bindings fields)
in
Adt_value.nodevalue
(Adt_value.index { tyargs; adt = ty; case; fields })))
| _ -> assert false))
let init env : unit = let init env : unit =
Adt_value.th_register env; Adt_value.th_register env;
Ground.register_converter env converter; Ground.register_converter env converter;
init_got_value_bool env init_got_value_bool env;
init_node env
let () = Egraph.add_default_theory init let () = Egraph.add_default_theory init
...@@ -5,9 +5,16 @@ open Colibri2_popop_lib ...@@ -5,9 +5,16 @@ open Colibri2_popop_lib
module Case = DInt module Case = DInt
module Field = DInt module Field = DInt
type ts = {
adt : Expr.Ty.Const.t;
tyargs : Ground.Ty.t list;
case : Case.t;
fields : Value.t Field.M.t;
}
module T' = struct module T' = struct
module T = struct module T = struct
type t = { type t = ts = {
adt : Expr.Ty.Const.t; adt : Expr.Ty.Const.t;
tyargs : Ground.Ty.t list; tyargs : Ground.Ty.t list;
case : Case.t; case : Case.t;
...@@ -30,16 +37,15 @@ module T' = struct ...@@ -30,16 +37,15 @@ module T' = struct
end) end)
end end
module V = ValueKind.Register (T') include ValueKind.Register (T')
include T'
let interp d n = Opt.get_exn Impossible (Egraph.get_value d n) let interp d n = Opt.get_exn Impossible (Egraph.get_value d n)
let compute d g = let compute d g =
match Ground.sem g with match Ground.sem g with
| { app = { builtin = Expr.Tester { case; _ }; _ }; args = [ e ]; _ } -> | { app = { builtin = Expr.Tester { case; _ }; _ }; args = [ e ]; _ } ->
let v = V.coerce_nodevalue (interp d e) in let v = coerce_nodevalue (interp d e) in
let v = V.value v in let v = value v in
`Some `Some
(Colibri2_theories_bool.Boolean.values_of_bool (Case.equal case v.case)) (Colibri2_theories_bool.Boolean.values_of_bool (Case.equal case v.case))
| { | {
...@@ -53,14 +59,14 @@ let compute d g = ...@@ -53,14 +59,14 @@ let compute d g =
Field.M.add field (interp d a) acc) Field.M.add field (interp d a) acc)
in in
let v = { tyargs; adt; case; fields } in let v = { tyargs; adt; case; fields } in
`Some (V.nodevalue (V.index v)) `Some (nodevalue (index v))
| { | {
app = { builtin = Expr.Destructor { case; field; _ }; _ }; app = { builtin = Expr.Destructor { case; field; _ }; _ };
args = [ e ]; args = [ e ];
_; _;
} -> } ->
let v = V.coerce_nodevalue (interp d e) in let v = coerce_nodevalue (interp d e) in
let v = V.value v in let v = value v in
if Case.equal case v.case then if Case.equal case v.case then
let v = Field.M.find_opt field v.fields in let v = Field.M.find_opt field v.fields in
match v with None -> raise Impossible | Some v -> `Some v match v with None -> raise Impossible | Some v -> `Some v
...@@ -88,6 +94,28 @@ let propagate_value d g = ...@@ -88,6 +94,28 @@ let propagate_value d g =
in in
Interp.TwoWatchLiteral.create d f g Interp.TwoWatchLiteral.create d f g
let sequence_of_cases d ty tyargs all_cases cases =
let open Interp.SeqLim in
let* case = of_seq d cases in
let { Expr.Ty.cstr; _ } = all_cases.(case) in
let subst =
List.fold2_exn
~f:(fun acc v ty -> Expr.Ty.Var.M.add v ty acc)
~init:Expr.Ty.Var.M.empty cstr.ty.fun_vars tyargs
in
let args = List.map ~f:(Ground.Ty.convert subst) cstr.ty.fun_args in
let rec aux seq i = function
| [] -> seq
| ty :: args ->
let seq =
let+ l = seq and* a = Interp.ty d ty in
Field.M.add i a l
in
aux seq (i + 1) args
in
let+ fields = aux (of_seq d (Sequence.singleton Field.M.empty)) 0 args in
nodevalue (index { tyargs = args; adt = ty; case; fields })
let init_ty d = let init_ty d =
Interp.Register.ty d (fun d ty -> Interp.Register.ty d (fun d ty ->
match ty with match ty with
...@@ -95,36 +123,11 @@ let init_ty d = ...@@ -95,36 +123,11 @@ let init_ty d =
match Ground.Ty.definition sym with match Ground.Ty.definition sym with
| Abstract -> None | Abstract -> None
| Adt { ty; record = _; cases } -> | Adt { ty; record = _; cases } ->
Some let seq =
(let open Interp.SeqLim in Sequence.unfold ~init:0 ~f:(fun i ->
let cases = if i < Array.length cases then Some (i, i + 1) else None)
Sequence.unfold ~init:0 ~f:(fun i -> in
if i < Array.length cases then Some ((i, cases.(i)), i + 1) Some (sequence_of_cases d ty args cases seq))
else None)
in
let* case, { cstr; _ } = of_seq d cases in
let subst =
List.fold2_exn
~f:(fun acc v ty -> Expr.Ty.Var.M.add v ty acc)
~init:Expr.Ty.Var.M.empty cstr.ty.fun_vars args
in
let args =
List.map ~f:(Ground.Ty.convert subst) cstr.ty.fun_args
in
let rec aux seq i = function
| [] -> seq
| ty :: args ->
let seq =
let+ l = seq and* a = Interp.ty d ty in
Field.M.add i a l
in
aux seq (i + 1) args
in
let+ fields =
aux (of_seq d (Sequence.singleton Field.M.empty)) 0 args
in
V.nodevalue (V.index { tyargs = args; adt = ty; case; fields }))
)
| _ -> None) | _ -> None)
let th_register d = let th_register d =
......
open Colibri2_popop_lib.Popop_stdlib
module Case = DInt
module Field = DInt
type ts = {
adt : Expr.Ty.Const.t;
tyargs : Ground.Ty.t list;
case : Case.t;
fields : Value.t Field.M.t;
}
include ValueKind.Registered with type s := ts
val th_register : Egraph.t -> unit val th_register : Egraph.t -> unit
val propagate_value : Egraph.t -> Ground.t -> unit val propagate_value : Egraph.t -> Ground.t -> unit
val sequence_of_cases :
Egraph.t ->
Term.ty_const ->
Ground.All.ty list ->
Ty.adt_case array ->
int Base.Sequence.t ->
Value.t Interp.SeqLim.t
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment