From 6a72a13caa08ea99a8012407aa5716db2477478c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr> Date: Mon, 7 Jun 2021 13:41:44 +0200 Subject: [PATCH] [ADT] Use the monad --- src_colibri2/core/colibri2_core.ml | 14 +- src_colibri2/popop_lib/IArray.ml | 21 +- src_colibri2/popop_lib/IArray.mli | 4 +- src_colibri2/theories/ADT/adt.ml | 392 +++++++++------------- src_colibri2/theories/ADT/adt_value.ml | 101 ++++-- src_colibri2/theories/ADT/adt_value.mli | 30 +- src_colibri2/theories/LRA/dom_interval.ml | 3 +- 7 files changed, 278 insertions(+), 287 deletions(-) diff --git a/src_colibri2/core/colibri2_core.ml b/src_colibri2/core/colibri2_core.ml index 68462f235..021442de6 100644 --- a/src_colibri2/core/colibri2_core.ml +++ b/src_colibri2/core/colibri2_core.ml @@ -152,12 +152,12 @@ module Dom = struct val key : t Kind.t - val inter : t -> t -> t option - (** [inter d1 d2] compute the intersection of + val inter : Egraph.t -> t -> t -> t option + (** [inter d d1 d2] compute the intersection of [d1] and [d2] return [None] if it is empty *) - val is_singleton : t -> Value.t option - (** [is_singleton d] if [d] is + val is_singleton : Egraph.t -> t -> Value.t option + (** [is_singleton _ d] if [d] is restricted to a singleton return the corresponding value *) end @@ -173,7 +173,7 @@ module Dom = struct let set_dom d node v = Egraph.set_dom d L.key node v; - match L.is_singleton v with + match L.is_singleton d v with | Some cst -> Egraph.set_value d node cst | None -> () @@ -187,7 +187,7 @@ module Dom = struct assert (not (Egraph.is_equal d cl1 cl2)); match (i1, cl1, i2, cl2) with | Some i1, _, Some i2, _ -> ( - match L.inter i1 i2 with + match L.inter d i1 i2 with | None -> Colibri2_popop_lib.Debug.dprintf10 Egraph.print_contradiction "[%a] The intersection of %a and %a is empty when\n\ @@ -204,7 +204,7 @@ module Dom = struct match Egraph.get_dom d L.key n' with | None -> set_dom d n' v' | Some old -> ( - match L.inter old v' with + match L.inter d old v' with | None -> Colibri2_popop_lib.Debug.dprintf8 Egraph.print_contradiction "[%a] The intersection of %a with %a is empty when updating \ diff --git a/src_colibri2/popop_lib/IArray.ml b/src_colibri2/popop_lib/IArray.ml index b81cd0603..a745db6d1 100644 --- a/src_colibri2/popop_lib/IArray.ml +++ b/src_colibri2/popop_lib/IArray.ml @@ -39,14 +39,19 @@ let of_list_map ~f = function fill 1 l; a -let of_array = Array.copy -let of_iter l (iter : ('a -> unit) -> unit) = - if l = 0 then empty else - let res = Array.make l (Obj.magic 0 : 'a) in - let r = ref 0 in - iter (fun v -> res.(!r) <- v; incr r); - assert (!r == l); - res +let of_array a = + if Array.length a = 0 then empty + else Array.copy a + +let of_array_map ~f a = + let l = Array.length a in + if l = 0 then empty + else + let r = Array.make l (f a.(0)) in + for i=1 to (l-1) do + r.(i) <- (f a.(i)) + done; + r let to_list = Array.to_list let to_seq = Base.Array.to_sequence diff --git a/src_colibri2/popop_lib/IArray.mli b/src_colibri2/popop_lib/IArray.mli index cacc44434..5b48718fc 100644 --- a/src_colibri2/popop_lib/IArray.mli +++ b/src_colibri2/popop_lib/IArray.mli @@ -27,9 +27,7 @@ type 'a t val of_list: 'a list -> 'a t val of_list_map: f:('a -> 'b) -> 'a list -> 'b t val of_array: 'a array -> 'a t -val of_iter: int -> (('a -> unit) -> unit) -> 'a t -(** create the array using an iterator. The integer indicate the - number of iteration that will occur *) +val of_array_map: f:('a -> 'b) -> 'a array -> 'b t val empty: 'a t val is_empty: 'a t -> bool diff --git a/src_colibri2/theories/ADT/adt.ml b/src_colibri2/theories/ADT/adt.ml index 71a6e05cf..f7b48140a 100644 --- a/src_colibri2/theories/ADT/adt.ml +++ b/src_colibri2/theories/ADT/adt.ml @@ -27,12 +27,16 @@ let debug = Debug.register_info_flag "adt" ~desc:"Algebraic@ Datatype" module D = struct type t = - | Unk of Case.S.t - | One of { case : Case.t; fields : Node.t Field.M.t } + | Unk of { adt : Adt_value.MonoAdt.t; cases : Case.S.t } + | One of { + adt : Adt_value.MonoAdt.t; + case : Case.t; + fields : Node.t Field.M.t; + } [@@deriving eq] let pp fmt = function - | Unk s -> Fmt.pf fmt "{%a}" Case.S.pp s + | Unk { cases; adt = _ } -> Fmt.pf fmt "{%a}" Case.S.pp cases | One c -> Fmt.pf fmt "%i(%a)" c.case (Field.M.pp Node.pp) c.fields let key = @@ -43,82 +47,79 @@ module D = struct let name = "adt" end) - let merged d1 d2 = Option.equal equal d1 d2 - - let merge' info d d1 d2 = + let inter d d1 d2 = match (d1, d2) with - | Unk s1, Unk s2 -> + | Unk { cases = s1; adt = adt1 }, Unk { cases = s2; adt = adt2 } -> + assert (Adt_value.MonoAdt.equal adt1 adt2); let s = Case.S.inter s1 s2 in - if Case.S.is_empty s then ( - Debug.dprintf5 Egraph.print_contradiction - "[ADT] The intersection of %a and %a is empty %t" Field.S.pp s1 - Case.S.pp s2 info; - Egraph.contradiction d) - else Unk s - | One { case = c1; fields = f1 }, One { case = c2; fields = f2 } -> + if Case.S.is_empty s then None else Some (Unk { cases = s; adt = adt1 }) + | ( One { case = c1; fields = f1; adt = adt1 }, + One { case = c2; fields = f2; adt = adt2 } ) -> + assert (Adt_value.MonoAdt.equal adt1 adt2); if Case.equal c1 c2 then - One - { - case = c1; - fields = - Field.M.union - (fun _ n1 n2 -> - Egraph.merge d n1 n2; - Some n1) - f1 f2; - } - else ( - Debug.dprintf5 Egraph.print_contradiction - "[ADT] case conflict between %a and %a %t" Case.pp c1 Case.pp c2 - info; - Egraph.contradiction d) - | (One { case = c1; _ } as d1), Unk c2 - | Unk c2, (One { case = c1; _ } as d1) -> - if Case.S.mem c1 c2 then d1 - else ( - Debug.dprintf5 Egraph.print_contradiction - "[ADT] case %a is not in %a %t" Case.pp c1 Case.S.pp c2 info; - Egraph.contradiction d) - - let merge d (d1, n1) (d2, n2) _ = - let s = - match (d1, d2) with - | Some d1, Some d2 -> - merge' - (fun fmt -> - Fmt.pf fmt "when merging %a and %a" Node.pp n1 Node.pp n2) - d d1 d2 - | None, Some d1 | Some d1, None -> d1 - | None, None -> assert false - (* absurd: already merged *) - in - Egraph.set_dom d key n1 s; - Egraph.set_dom d key n2 s + Some + (One + { + adt = adt1; + case = c1; + fields = + Field.M.union + (fun _ n1 n2 -> + Egraph.merge d n1 n2; + Some n1) + f1 f2; + }) + else None + | (One { case = c1; adt = adt1; _ } as d1), Unk { adt = adt2; cases = c2 } + | Unk { cases = c2; adt = adt1 }, (One { case = c1; adt = adt2; _ } as d1) + -> + assert (Adt_value.MonoAdt.equal adt1 adt2); + if Case.S.mem c1 c2 then Some d1 else None + + let is_singleton d = function + | Unk s -> + if Case.S.is_num_elt 1 s.cases then + let case = Case.S.choose s.cases in + match Adt_value.MonoAdt.case s.adt case with + | [] -> + Some + (Adt_value.nodevalue @@ Adt_value.index + @@ Adt_value.{ adt = s.adt; case; fields = Field.M.empty }) + | _ -> None + else None + | One { case; fields; adt } -> + if + List.length (Adt_value.MonoAdt.case adt case) + = Field.M.cardinal fields + then + try + let fields = + Field.M.map + (fun f -> + match Egraph.get_value d f with + | None -> raise Exit + | Some v -> v) + fields + in + Some + (Adt_value.nodevalue @@ Adt_value.index + @@ Adt_value.{ adt; case; fields }) + with Exit -> None + else None end -let () = Dom.register (module D) - -let upd_dom d n d2 = - match Egraph.get_dom d D.key n with - | None -> Egraph.set_dom d D.key n d2 - | Some d1 -> - let d' = - D.merge' (fun fmt -> Fmt.pf fmt "when updating %a" Node.pp n) d d1 d2 - in - if not (D.equal d' d1) then Egraph.set_dom d D.key n d' +include Dom.Lattice (D) -let case_of_adt ty = - match Ground.Ty.definition ty with - | Expr.Ty.Abstract -> assert false (* absurd: must be an adt *) - | Expr.Ty.Adt { cases; _ } -> - Case.S.of_list (List.init (Array.length cases) (fun i -> i)) +let case_of_adt adt = + Case.S.of_list + (List.init (IArray.length adt.Adt_value.MonoAdt.cases) (fun i -> i)) (** Decide for destructors *) module Decide = struct - let make_decision n s d = + let make_decision n adt cases d = Colibri2_popop_lib.Debug.dprintf4 Egraph.print_decision - "[ADT] decide %a on %a" Case.S.pp s Node.pp n; - upd_dom d n (Unk s) + "[ADT] decide %a on %a" Case.S.pp cases Node.pp n; + upd_dom d n (Unk { adt; cases }) let new_decision n adt c = { @@ -128,125 +129,90 @@ module Decide = struct let decisions s = Egraph.DecTodo [ - make_decision n (Case.S.singleton c); - make_decision n (Case.S.remove c s); + make_decision n adt (Case.S.singleton c); + make_decision n adt (Case.S.remove c s); ] in match Egraph.get_dom d D.key n with | Some (One _) -> Egraph.DecNo - | Some (Unk s) when Case.S.is_num_elt 1 s -> Egraph.DecNo - | Some (Unk s) when not (Case.S.mem c s) -> Egraph.DecNo - | Some (Unk s) -> decisions s + | Some (Unk s) when Case.S.is_num_elt 1 s.cases -> Egraph.DecNo + | Some (Unk s) when not (Case.S.mem c s.cases) -> Egraph.DecNo + | Some (Unk s) -> decisions s.cases | None -> decisions (case_of_adt adt)); } end -module Monad : sig - type 'a monad = Egraph.t -> Node.t -> 'a - - val get : Node.t -> D.t option monad - - val set : Node.t -> D.t option monad -> unit monad - - val getb : Node.t -> bool option monad - - val setb : Node.t -> bool option option monad -> unit monad - - val unit : (Egraph.t -> unit) option monad -> unit monad - - val ( let+ ) : 'b option monad -> ('b -> 'c) -> 'c option monad -end = struct - type 'a monad = Egraph.t -> Node.t -> 'a - - let[@ocaml.always inline] get_dom dom n' d _ = Egraph.get_dom d dom n' - - let[@ocaml.always inline] set n' v' d n = - if Node.equal n n' then () - else Option.iter (fun v' -> upd_dom d n' v') (v' d n) - - let[@ocaml.always inline] get a = get_dom D.key a - - let[@ocaml.always inline] get_value key n' d _ = - Option.bind (Egraph.get_value d n') (Value.value key) - - let[@ocaml.always inline] set_value (type b) - (value : (module Value.S with type s = b)) n' (v' : b option option monad) - : unit monad = - fun d n -> - if Node.equal n n' then () - else - let module V = (val value) in - Option.iter - (Option.iter (fun v' -> - Egraph.set_value d n' (V.nodevalue (V.index v')))) - (v' d n) - - let getb a = get_value Boolean.dom a - - let setb a = set_value (module Boolean.BoolValue) a - - let unit (v' : (Egraph.t -> unit) option monad) : unit monad = - fun d n -> match v' d n with None -> () | Some f -> f d - - let[@ocaml.always inline] ( let+ ) a (f : 'b -> 'a) d n = Option.map f (a d n) -end - (** {2 Initialization} *) let converter d (f : Ground.t) = let r = Ground.node f in let reg n = Egraph.register d n in let open Monad in + let setb = setv Boolean.dom in + let getb = getv Boolean.dom in + let set = updd upd_dom in + let get = getd key in match Ground.sem f with - | { app = { builtin = Expr.Tester { adt; case; _ }; _ }; args; _ } -> + | { app = { builtin = Expr.Tester { adt; case; _ }; _ }; args; tyargs; _ } -> + let adt = Option.get_exn (Adt_value.MonoAdt.index adt tyargs) in let e = IArray.extract1_exn args in reg e; - Daemon.attach_value d r Boolean.BoolValue.key (fun d n _ -> - (set e - (let+ vr = getb r in - if vr then D.Unk (Case.S.singleton case) - else - let s = Case.S.remove case (case_of_adt adt) in - if Case.S.is_empty s then ( - Debug.dprintf4 Egraph.print_contradiction - "[ADT] tester %a removed the only case %a of the type" - Node.pp r Case.pp case; - Egraph.contradiction d) - else D.Unk s)) - d n); - Daemon.attach_dom d e D.key + attach d + (set e + (let+ vr = getb r in + if vr then D.Unk { adt; cases = Case.S.singleton case } + else + let cases = Case.S.remove case (case_of_adt adt) in + if Case.S.is_empty cases then ( + Debug.dprintf4 Egraph.print_contradiction + "[ADT] tester %a removed the only case %a of the type" Node.pp + r Case.pp case; + Egraph.contradiction d) + else D.Unk { adt; cases })); + attach d (setb r - (let+ ve = get e in + (let* ve = get e in match ve with - | Unk s -> if Case.S.mem case s then None else Some false + | Unk s -> if Case.S.mem case s.cases then None else Some false | One { case = case'; _ } -> Some (Case.equal case case'))); Adt_value.propagate_value d f - | { app = { builtin = Expr.Constructor { case; _ }; _ }; args; _ } -> + | { + app = { builtin = Expr.Constructor { case; adt; _ }; _ }; + args; + tyargs; + _; + } -> + let adt = Option.get_exn (Adt_value.MonoAdt.index adt tyargs) in IArray.iter ~f:reg args; let fields = IArray.foldi ~init:Field.M.empty args ~f:(fun field acc a -> Field.M.add field a acc) in - upd_dom d r (One { case; fields }); + upd_dom d r (One { adt; case; fields }); Adt_value.propagate_value d f - | { app = { builtin = Expr.Destructor { case; field; adt; _ }; _ }; args; _ } - -> + | { + app = { builtin = Expr.Destructor { case; field; adt; _ }; _ }; + args; + tyargs; + _; + } -> + let adt = Option.get_exn (Adt_value.MonoAdt.index adt tyargs) in let e = IArray.extract1_exn args in reg e; Egraph.register_decision d (Decide.new_decision e adt case); - let upd d = - upd_dom d e (One { case; fields = Field.M.singleton field r }) - in - let nothing _ = () in - Daemon.attach_dom d e D.key - (unit - (let+ ve = get e in + attach d + (set e + (let* ve = get e in match ve with - | Unk s when Case.S.is_num_elt 1 s -> - if Case.S.mem case s then upd else nothing - | Unk _ -> nothing + | Unk s when Case.S.is_num_elt 1 s.cases -> + if Case.S.mem case s.cases then + Some (D.One { case; fields = Field.M.singleton field r; adt }) + else None + | Unk _ -> None | One { case = case'; _ } -> - if Case.equal case case' then upd else nothing)); + if Case.equal case case' then + Some (D.One { case; fields = Field.M.singleton field r; adt }) + else None)); Adt_value.propagate_value d f | _ -> () @@ -254,76 +220,52 @@ let init_node d = Interp.Register.node d (fun interp_node d n -> match Egraph.get_dom d D.key n with | None -> None - | Some adt -> ( - let sty = Ground.tys d n in - assert (Ground.Ty.S.is_num_elt 1 sty); - let ty = Ground.Ty.S.choose sty in - match ty with - | { app = { builtin = Expr.Base; _ } as sym; args = tyargs } -> ( - match Ground.Ty.definition sym with - | Abstract -> assert false - | Adt { ty; record = _; cases = all_cases } -> - Debug.dprintf4 debug "[ADT] node %a: %a" Node.pp n D.pp adt; - Some - (match adt with - | D.Unk cases -> - let seq = - Base.Sequence.of_list (Case.S.elements cases) - in - Adt_value.sequence_of_cases d ty tyargs all_cases seq - | D.One { case; fields } -> - let open Interp.SeqLim in - let { Expr.Ty.cstr; _ } = all_cases.(case) in - let fun_vars, fun_args, _ = - Expr.Ty.poly_sig cstr.id_ty - in - let subst = - Base.List.fold2_exn - ~f:(fun acc v ty -> Expr.Ty.Var.M.add v ty acc) - ~init:Expr.Ty.Var.M.empty fun_vars tyargs - in - let args_ty = - Base.List.mapi - ~f:(fun i ty -> (i, Ground.Ty.convert subst ty)) - fun_args - in - let fields = - Field.M.merge - (fun _ l r -> - match (l, r) with - | Some l, _ -> Some (`Node l) - | None, Some r -> Some (`Ty r) - | None, None -> assert false) - fields (Field.M.of_list args_ty) - in - let rec aux seq = function - | [] -> seq - | (i, `Node arg) :: args -> - let seq = - let+ l = seq - and* a = - Debug.dprintf4 debug "[ADT] interp_node %a:%a" - Node.pp n Node.pp arg; - interp_node d arg - in - Field.M.add i a l - in - aux seq args - | (i, `Ty arg) :: args -> - let seq = - let+ l = seq and* a = Interp.ty d arg in - Field.M.add i a l - in - aux seq args - in - let+ fields = - aux - (of_seq d (Base.Sequence.singleton Field.M.empty)) - (Field.M.bindings fields) + | Some adt -> + Debug.dprintf4 debug "[ADT] node %a: %a" Node.pp n D.pp adt; + Some + (match adt with + | D.Unk { adt; cases } -> + let seq = Base.Sequence.of_list (Case.S.elements cases) in + Adt_value.sequence_of_cases d adt seq + | D.One { adt; case; fields } -> + let open Interp.SeqLim in + let args_ty = Adt_value.MonoAdt.case adt case in + let fields = + Field.M.merge + (fun _ l r -> + match (l, r) with + | Some l, _ -> Some (`Node l) + | None, Some r -> Some (`Ty r) + | None, None -> assert false) + fields + (Field.M.of_list (List.mapi (fun i x -> (i, x)) args_ty)) + in + let rec aux seq = function + | [] -> seq + | (i, `Node arg) :: args -> + let seq = + let+ l = seq + and* a = + Debug.dprintf4 debug "[ADT] interp_node %a:%a" Node.pp + n Node.pp arg; + interp_node d arg in - Adt_value.nodevalue - (Adt_value.index { tyargs; adt = ty; case; fields }))) - | _ -> assert false)) + Field.M.add i a l + in + aux seq args + | (i, `Ty arg) :: args -> + let seq = + let+ l = seq and* a = Interp.ty d arg in + Field.M.add i a l + in + aux seq args + in + let+ fields = + aux + (of_seq d (Base.Sequence.singleton Field.M.empty)) + (Field.M.bindings fields) + in + Adt_value.nodevalue (Adt_value.index { adt; case; fields }))) let init env : unit = Adt_value.th_register env; diff --git a/src_colibri2/theories/ADT/adt_value.ml b/src_colibri2/theories/ADT/adt_value.ml index 7ff8d2357..cc7157d8d 100644 --- a/src_colibri2/theories/ADT/adt_value.ml +++ b/src_colibri2/theories/ADT/adt_value.ml @@ -25,21 +25,65 @@ open Colibri2_popop_lib module Case = DInt module Field = DInt -type ts = { - adt : Expr.Ty.Const.t; - tyargs : Ground.Ty.t list; - case : Case.t; - fields : Value.t Field.M.t; -} +(** Monomorph version *) +module MonoAdt : sig + type mono = { adt : Expr.Ty.Const.t; tyargs : Ground.Ty.t list } + + and t = private { mono : mono; cases : Ground.Ty.t list IArray.t } + [@@deriving eq, ord, hash] + + val index : Expr.Ty.Const.t -> Ground.Ty.t list -> t option + + val case : t -> int -> Ground.Ty.t list + (** Return the type of the argument of the given cases *) +end = struct + module Mono = struct + module Key = struct + type t = { adt : Expr.Ty.Const.t; tyargs : Ground.Ty.t list } + [@@deriving eq, ord, hash, show] + end + + include Key + include MkDatatype (Key) + end + + type mono = Mono.t = { adt : Expr.Ty.Const.t; tyargs : Ground.Ty.t list } + [@@deriving eq, ord, hash] + + type t = { mono : mono; cases : Ground.Ty.t list IArray.t } + [@@deriving ord, hash] + + let equal = phys_equal + + let index = + let h = Mono.H.create 10 in + fun adt tyargs -> + Mono.H.memo + (fun mono -> + match Ground.Ty.definition adt with + | Abstract -> None + | Adt { cases; _ } -> + let map { Expr.Ty.cstr; _ } = + let fun_vars, fun_args, _ = Expr.Ty.poly_sig cstr.id_ty in + let subst = + List.fold2_exn + ~f:(fun acc v ty -> Expr.Ty.Var.M.add v ty acc) + ~init:Expr.Ty.Var.M.empty fun_vars tyargs + in + List.map ~f:(Ground.Ty.convert subst) fun_args + in + let cases = IArray.of_array_map ~f:map cases in + Some { mono; cases }) + h { adt; tyargs } + + let case t i = IArray.get t.cases i +end + +type ts = { adt : MonoAdt.t; case : Case.t; fields : Value.t Field.M.t } module T' = struct module T = struct - type t = ts = { - adt : Expr.Ty.Const.t; - tyargs : Ground.Ty.t list; - case : Case.t; - fields : Value.t Field.M.t; - } + type t = ts = { adt : MonoAdt.t; case : Case.t; fields : Value.t Field.M.t } [@@deriving eq, ord, hash] let pp fmt c = Fmt.pf fmt "ADT.%i(%a)" c.case (Field.M.pp Value.pp) c.fields @@ -68,13 +112,16 @@ let compute d g = args; tyargs; _; - } -> + } -> ( let fields = IArray.foldi ~init:Field.M.empty args ~f:(fun field acc a -> Field.M.add field (interp d a) acc) in - let v = { tyargs; adt; case; fields } in - `Some (nodevalue (index v)) + match MonoAdt.index adt tyargs with + | None -> raise Impossible + | Some adt -> + let v = { adt; case; fields } in + `Some (nodevalue (index v))) | { app = { builtin = Expr.Destructor { case; field; _ }; _ }; args; _ } -> let e = IArray.extract1_exn args in let v = coerce_nodevalue (interp d e) in @@ -106,17 +153,10 @@ let propagate_value d g = in Interp.WatchArgs.create d f g -let sequence_of_cases d ty tyargs all_cases cases = +let sequence_of_cases d t cases = let open Interp.SeqLim in let* case = of_seq d cases in - let { Expr.Ty.cstr; _ } = all_cases.(case) in - let fun_vars, fun_args, _ = Expr.Ty.poly_sig cstr.id_ty in - let subst = - List.fold2_exn - ~f:(fun acc v ty -> Expr.Ty.Var.M.add v ty acc) - ~init:Expr.Ty.Var.M.empty fun_vars tyargs - in - let args = List.map ~f:(Ground.Ty.convert subst) fun_args in + let args = MonoAdt.case t case in let rec aux seq i = function | [] -> seq | ty :: args -> @@ -127,20 +167,21 @@ let sequence_of_cases d ty tyargs all_cases cases = aux seq (i + 1) args in let+ fields = aux (of_seq d (Sequence.singleton Field.M.empty)) 0 args in - nodevalue (index { tyargs = args; adt = ty; case; fields }) + nodevalue (index { adt = t; case; fields }) let init_ty d = Interp.Register.ty d (fun d ty -> match ty with | { app = { builtin = Expr.Base; _ } as sym; args } -> ( - match Ground.Ty.definition sym with - | Abstract -> None - | Adt { ty; record = _; cases } -> + match MonoAdt.index sym args with + | None -> None + | Some adt -> let seq = Sequence.unfold ~init:0 ~f:(fun i -> - if i < Array.length cases then Some (i, i + 1) else None) + if i < IArray.length adt.cases then Some (i, i + 1) + else None) in - Some (sequence_of_cases d ty args cases seq)) + Some (sequence_of_cases d adt seq)) | _ -> None) let th_register d = diff --git a/src_colibri2/theories/ADT/adt_value.mli b/src_colibri2/theories/ADT/adt_value.mli index 7b88e25c6..b1c4fe287 100644 --- a/src_colibri2/theories/ADT/adt_value.mli +++ b/src_colibri2/theories/ADT/adt_value.mli @@ -18,16 +18,25 @@ (* for more details (enclosed in the file licenses/LGPLv2.1). *) (*************************************************************************) -open Colibri2_popop_lib.Popop_stdlib +open Colibri2_core +open Colibri2_popop_lib +open Popop_stdlib module Case = DInt module Field = DInt -type ts = { - adt : Expr.Ty.Const.t; - tyargs : Ground.Ty.t list; - case : Case.t; - fields : Value.t Field.M.t; -} +module MonoAdt : sig + type mono = { adt : Expr.Ty.Const.t; tyargs : Ground.Ty.t list } + + and t = private { mono : mono; cases : Ground.Ty.t list IArray.t } + [@@deriving eq, ord, hash] + + val index : Expr.Ty.Const.t -> Ground.Ty.t list -> t option + + val case : t -> int -> Ground.Ty.t list + (** Return the type of the argument of the given cases *) +end + +type ts = { adt : MonoAdt.t; case : Case.t; fields : Value.t Field.M.t } include Value.S with type s := ts @@ -36,9 +45,4 @@ val th_register : Egraph.t -> unit val propagate_value : Egraph.t -> Ground.t -> unit val sequence_of_cases : - Egraph.t -> - Expr.Term.ty_const -> - Ground.All.ty list -> - Expr.Ty.adt_case array -> - int Base.Sequence.t -> - Value.t Interp.SeqLim.t + Egraph.t -> MonoAdt.t -> int Base.Sequence.t -> Value.t Interp.SeqLim.t diff --git a/src_colibri2/theories/LRA/dom_interval.ml b/src_colibri2/theories/LRA/dom_interval.ml index 8b41ae7f3..a727f38c2 100644 --- a/src_colibri2/theories/LRA/dom_interval.ml +++ b/src_colibri2/theories/LRA/dom_interval.ml @@ -33,7 +33,8 @@ let dom = Dom.Kind.create (module struct type t = D.t let name = "ARITH" end) include (Dom.Lattice(struct include D let key = dom - let is_singleton d = + let inter _ d1 d2 = inter d1 d2 + let is_singleton _ d = Option.map (fun x -> RealValue.nodevalue @@ RealValue.index x) (is_singleton d) end )) -- GitLab