From 39b9a5f31c9ce8a1e1a22236413ac72eb4f9c50b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr>
Date: Tue, 5 Jan 2021 13:56:34 +0100
Subject: [PATCH] [Egraph] Fix UF heuristic

 - handle all the lasteffort of the same time at the same time
 - Handle coercion in Arith
---
 src_colibri2/core/egraph.ml                   | 14 +++---
 src_colibri2/popop_lib/TimeWheel.ml           | 19 +++++--
 src_colibri2/popop_lib/TimeWheel.mli          |  4 +-
 src_colibri2/solver/scheduler.ml              | 44 ++++++++++------
 src_colibri2/stdlib/context.mli               | 16 +-----
 src_colibri2/tests/solve/smt_lra/sat/dune.inc |  4 ++
 .../tests/solve/smt_lra/sat/to_real.smt2      |  5 ++
 .../tests/solve/smt_lra/sat/to_real2.smt2     |  8 +++
 .../tests/solve/smt_lra/unsat/dune.inc        |  4 ++
 .../tests/solve/smt_lra/unsat/to_real.smt2    |  5 ++
 .../tests/solve/smt_lra/unsat/to_real2.smt2   |  8 +++
 src_colibri2/theories/LRA/realValue.ml        |  7 +++
 .../theories/quantifier/quantifier.ml         | 50 +++++++++++++------
 13 files changed, 134 insertions(+), 54 deletions(-)
 create mode 100644 src_colibri2/tests/solve/smt_lra/sat/to_real.smt2
 create mode 100644 src_colibri2/tests/solve/smt_lra/sat/to_real2.smt2
 create mode 100644 src_colibri2/tests/solve/smt_lra/unsat/to_real.smt2
 create mode 100644 src_colibri2/tests/solve/smt_lra/unsat/to_real2.smt2

diff --git a/src_colibri2/core/egraph.ml b/src_colibri2/core/egraph.ml
index c8dd819b3..73a6dae44 100644
--- a/src_colibri2/core/egraph.ml
+++ b/src_colibri2/core/egraph.ml
@@ -24,7 +24,7 @@ open Nodes
 
 exception Contradiction
 
-let debug = Debug.register_info_flag
+let debug = Debug.register_flag
     ~desc:"for the core solver"
     "Egraph.all"
 let debug_few = Debug.register_info_flag
@@ -588,6 +588,8 @@ module Delayed = struct
     Debug.register_flag
       ~desc:"Accept to use value as representative"
       "choose_repr_no_value"
+
+  (** representative is the second returned elements *)
   let choose_repr t ((_,a) as pa) ((_,b) as pb) =
     let heuristic () =
       if Shuffle.is_shuffle () then
@@ -597,16 +599,16 @@ module Delayed = struct
         let rb = Node.M.find_def 0 b t.rang in
         if ra = rb then begin
           t.rang <- Node.M.add a (ra+1) t.rang;
-          (pa,pb)
+          (pb,pa)
         end else
-        if ra < rb then (pb,pa)
-        else (pa,pb)
+        if ra < rb then (pa,pb)
+        else (pb,pa)
     in
     if Debug.test_noflag flag_choose_repr_no_value then
       let va = Nodes.Only_for_solver.is_value a in
       let vb = Nodes.Only_for_solver.is_value b in
-      if va && not vb then (pb,pa)
-      else if va && not vb then (pa,pb)
+      if va && not vb then (pa,pb)
+      else if va && not vb then (pb,pa)
       else heuristic ()
     else
       heuristic ()
diff --git a/src_colibri2/popop_lib/TimeWheel.ml b/src_colibri2/popop_lib/TimeWheel.ml
index c4c300150..6f606a1aa 100644
--- a/src_colibri2/popop_lib/TimeWheel.ml
+++ b/src_colibri2/popop_lib/TimeWheel.ml
@@ -14,11 +14,13 @@ module type S = sig
   (** [add t v offset] add the event v at the given offset in the futur *)
 
   val next: 'a t -> 'a option
-
+  val next_at_same_time: 'a t -> 'a option
+  val find_next: 'a t -> unit
 
   val current_time: 'a t -> int
 
   val size: 'a t -> int
+  val size_at_current_time: 'a t -> int
 end
 
 module Make
@@ -149,8 +151,7 @@ module Make
       in
       aux t
 
-  let next t =
-    find_next t;
+  let next_at_same_time t =
     match Array.get t.futur 0 with
     | Nil -> None
     | Cons { v; time; next } ->
@@ -158,6 +159,18 @@ module Make
       Ref.set t.size (Ref.get t.size - 1);
       Array.set t.futur 0 next;
       Some v
+
+  let next t =
+    find_next t;
+    next_at_same_time t
+
+  let size_at_current_time t =
+    let rec aux acc = function
+      | Nil -> acc
+      | Cons {next; _ } ->
+        aux (acc+1) next
+    in
+    aux 0 (Array.get t.futur 0)
 end
 
 include (Make
diff --git a/src_colibri2/popop_lib/TimeWheel.mli b/src_colibri2/popop_lib/TimeWheel.mli
index 9ea639639..77a4af617 100644
--- a/src_colibri2/popop_lib/TimeWheel.mli
+++ b/src_colibri2/popop_lib/TimeWheel.mli
@@ -18,10 +18,12 @@ module type S = sig
   (** [add t v offset] add the event v at the given offset in the futur *)
 
   val next: 'a t -> 'a option
-
+  val next_at_same_time: 'a t -> 'a option
+  val find_next: 'a t -> unit
 
   val current_time: 'a t -> int
   val size: 'a t -> int
+  val size_at_current_time: 'a t -> int
 
 end
 
diff --git a/src_colibri2/solver/scheduler.ml b/src_colibri2/solver/scheduler.ml
index 69cb1a8d5..0ace67127 100644
--- a/src_colibri2/solver/scheduler.ml
+++ b/src_colibri2/solver/scheduler.ml
@@ -293,21 +293,35 @@ let run_one_step t d =
         conflict_analysis t
     end
   | None ->
-    match Prio.extract_min t.choices with
-    | Att.Decision (_,chogen), prio ->
-      Debug.dprintf0 debug "[Scheduler] decision";
-      try_run_dec t d prio chogen
-    | exception Leftistheap.Empty ->
-      match Context.TimeWheel.next t.lasteffort with
-      | Some att -> begin
-          Debug.incr stats_lasteffort;
-          try
-            Egraph.Backtrackable.run_daemon d att; d
-          with Egraph.Contradiction ->
-            Debug.dprintf0 debug "[Scheduler] Contradiction";
-            conflict_analysis t
-        end
-      | None -> d
+    match Context.TimeWheel.next_at_same_time t.lasteffort with
+    | Some att -> begin
+        Debug.dprintf1 debug "[Scheduler] LastEffort mode remaining: %t"
+          (fun fmt -> Fmt.int fmt (Context.TimeWheel.size_at_current_time t.lasteffort));
+        Debug.incr stats_lasteffort;
+        try
+          Egraph.Backtrackable.run_daemon d att; d
+        with Egraph.Contradiction ->
+          Debug.dprintf0 debug "[Scheduler] Contradiction";
+          conflict_analysis t
+      end
+    | None ->
+      match Prio.extract_min t.choices with
+      | Att.Decision (_,chogen), prio ->
+        Debug.dprintf0 debug "[Scheduler] decision";
+        try_run_dec t d prio chogen
+      | exception Leftistheap.Empty ->
+        Debug.dprintf1 debug "[Scheduler] LastEffort: %i" (Context.TimeWheel.size t.lasteffort);
+        match Context.TimeWheel.next t.lasteffort with
+        | Some att -> begin
+            Debug.incr stats_lasteffort;
+            try
+              Egraph.Backtrackable.run_daemon d att; d
+            with Egraph.Contradiction ->
+              Debug.dprintf0 debug "[Scheduler] Contradiction";
+              conflict_analysis t
+          end
+        | None ->
+          d
 
 let rec flush t d =
   try
diff --git a/src_colibri2/stdlib/context.mli b/src_colibri2/stdlib/context.mli
index 02dbc240e..40bf9a974 100644
--- a/src_colibri2/stdlib/context.mli
+++ b/src_colibri2/stdlib/context.mli
@@ -187,19 +187,5 @@ module Array: sig
   val set: 'a t -> int -> 'a -> unit
   val get: 'a t -> int -> 'a
 end
-module TimeWheel: sig
-  type 'a t
-
-  val create: creator -> 'a t
-
-
-  val add: 'a t -> 'a -> int -> unit
-  (** [add t v offset] add the event v at the given offset in the futur *)
 
-  val next: 'a t -> 'a option
-
-
-  val current_time: 'a t -> int
-
-  val size: 'a t -> int
-end
+module TimeWheel: Colibri2_popop_lib.TimeWheel.S with type context := creator
diff --git a/src_colibri2/tests/solve/smt_lra/sat/dune.inc b/src_colibri2/tests/solve/smt_lra/sat/dune.inc
index 72cd4ae45..e08ed1fcd 100644
--- a/src_colibri2/tests/solve/smt_lra/sat/dune.inc
+++ b/src_colibri2/tests/solve/smt_lra/sat/dune.inc
@@ -57,3 +57,7 @@
 (rule (alias runtest) (action (diff oracle solver_set_pending_merge_expsameexp.smt2.res)))
 (rule (action (with-stdout-to solver_subst_eventdom_find.smt2.res (run %{bin:colibri2} --max-step 1000 %{dep:solver_subst_eventdom_find.smt2}))))
 (rule (alias runtest) (action (diff oracle solver_subst_eventdom_find.smt2.res)))
+(rule (action (with-stdout-to to_real.smt2.res (run %{bin:colibri2} --max-step 1000 %{dep:to_real.smt2}))))
+(rule (alias runtest) (action (diff oracle to_real.smt2.res)))
+(rule (action (with-stdout-to to_real2.smt2.res (run %{bin:colibri2} --max-step 1000 %{dep:to_real2.smt2}))))
+(rule (alias runtest) (action (diff oracle to_real2.smt2.res)))
diff --git a/src_colibri2/tests/solve/smt_lra/sat/to_real.smt2 b/src_colibri2/tests/solve/smt_lra/sat/to_real.smt2
new file mode 100644
index 000000000..c50290fbb
--- /dev/null
+++ b/src_colibri2/tests/solve/smt_lra/sat/to_real.smt2
@@ -0,0 +1,5 @@
+(set-logic ALL)
+
+(assert (= (to_real 0) 0.0))
+
+(check-sat)
diff --git a/src_colibri2/tests/solve/smt_lra/sat/to_real2.smt2 b/src_colibri2/tests/solve/smt_lra/sat/to_real2.smt2
new file mode 100644
index 000000000..9a9087837
--- /dev/null
+++ b/src_colibri2/tests/solve/smt_lra/sat/to_real2.smt2
@@ -0,0 +1,8 @@
+(set-logic ALL)
+
+(declare-fun a () Int)
+(declare-fun b () Int)
+
+(assert (= (to_real (+ a b)) (+ (to_real a) (to_real b))))
+
+(check-sat)
diff --git a/src_colibri2/tests/solve/smt_lra/unsat/dune.inc b/src_colibri2/tests/solve/smt_lra/unsat/dune.inc
index 99881a34e..5969faf65 100644
--- a/src_colibri2/tests/solve/smt_lra/unsat/dune.inc
+++ b/src_colibri2/tests/solve/smt_lra/unsat/dune.inc
@@ -7,3 +7,7 @@
 (rule (alias runtest) (action (diff oracle solver_merge_itself_repr_empty.smt2.res)))
 (rule (action (with-stdout-to solver_set_sem_merge_sign.smt2.res (run %{bin:colibri2} --max-step 1000 %{dep:solver_set_sem_merge_sign.smt2}))))
 (rule (alias runtest) (action (diff oracle solver_set_sem_merge_sign.smt2.res)))
+(rule (action (with-stdout-to to_real.smt2.res (run %{bin:colibri2} --max-step 1000 %{dep:to_real.smt2}))))
+(rule (alias runtest) (action (diff oracle to_real.smt2.res)))
+(rule (action (with-stdout-to to_real2.smt2.res (run %{bin:colibri2} --max-step 1000 %{dep:to_real2.smt2}))))
+(rule (alias runtest) (action (diff oracle to_real2.smt2.res)))
diff --git a/src_colibri2/tests/solve/smt_lra/unsat/to_real.smt2 b/src_colibri2/tests/solve/smt_lra/unsat/to_real.smt2
new file mode 100644
index 000000000..2cd05c008
--- /dev/null
+++ b/src_colibri2/tests/solve/smt_lra/unsat/to_real.smt2
@@ -0,0 +1,5 @@
+(set-logic ALL)
+
+(assert (not (= (to_real 0) 0.0)))
+
+(check-sat)
diff --git a/src_colibri2/tests/solve/smt_lra/unsat/to_real2.smt2 b/src_colibri2/tests/solve/smt_lra/unsat/to_real2.smt2
new file mode 100644
index 000000000..095dd8cdd
--- /dev/null
+++ b/src_colibri2/tests/solve/smt_lra/unsat/to_real2.smt2
@@ -0,0 +1,8 @@
+(set-logic ALL)
+
+(declare-fun a () Int)
+(declare-fun b () Int)
+
+(assert (not (= (to_real (+ a b)) (+ (to_real a) (to_real b)))))
+
+(check-sat)
diff --git a/src_colibri2/theories/LRA/realValue.ml b/src_colibri2/theories/LRA/realValue.ml
index 3bc4f065b..e4c38c4ac 100644
--- a/src_colibri2/theories/LRA/realValue.ml
+++ b/src_colibri2/theories/LRA/realValue.ml
@@ -96,6 +96,13 @@ let converter d (f:Ground.t) =
     List.iter (fun n -> wait_for_this_node_get_a_value d n wait) [a;b]
   in
   match Ground.sem f with
+  | { app = {builtin = Expr.Coercion};
+      tyargs =
+        ( [{app={builtin = (Expr.Int|Expr.Rat);_};_};{app={builtin = Expr.Real;_};_}]
+        | [{app={builtin = Expr.Int;_};_};{app={builtin = Expr.Rat;_};_}]
+        );
+    args = [a] } ->
+    merge a
   | { app = {builtin = Expr.Integer s}; tyargs = []; args = []; _ } ->
     merge (cst (Q.of_string s))
   | { app = {builtin = Expr.Decimal s}; tyargs = []; args = []; _ } ->
diff --git a/src_colibri2/theories/quantifier/quantifier.ml b/src_colibri2/theories/quantifier/quantifier.ml
index 081775add..7f178cd27 100644
--- a/src_colibri2/theories/quantifier/quantifier.ml
+++ b/src_colibri2/theories/quantifier/quantifier.ml
@@ -11,6 +11,9 @@ end
 let debug =
   Debug.register_info_flag ~desc:"Handling of quantifiers" "quantifiers"
 
+let debug_full =
+  Debug.register_flag ~desc:"Handling of quantifiers full" "quantifiers.full"
+
 let nb_instantiation = Debug.register_stats_int ~name:"instantiation" ~init:0
 
 let nb_new_instantiation =
@@ -139,6 +142,7 @@ module Pattern = struct
   let rec of_term (t : Expr.Term.t) =
     match t.descr with
     | Var v -> Var v
+    | App ({ builtin = Expr.Coercion; _ }, _, [ a ]) -> of_term a
     | App (f, tys, tl) -> App (f, tys, List.map of_term tl)
     | _ -> (* absurd *) assert false
 
@@ -172,8 +176,8 @@ module Pattern = struct
       match p with
       | Var v ->
           let match_ty (acc : Ground.Subst.S.t) ty : Ground.Subst.S.t =
-            Debug.dprintf4 debug "[Quant] match_ty %a %a" Node.pp n Ground.Ty.pp
-              ty;
+            Debug.dprintf4 debug "[Quant] match_ty %a %a" Expr.Ty.pp v.ty
+              Ground.Ty.pp ty;
             Ground.Subst.S.union (match_ty substs ty v.ty) acc
           in
           let substs =
@@ -529,10 +533,6 @@ module Trigger = struct
 
   let instantiate_aux d tri subst =
     let form = Ground.ThClosedQuantifier.sem tri.form in
-    Debug.dprintf8 debug "[Quant] %a instantiation found %a, pat %a, checks:%a"
-      Ground.Subst.pp subst Expr.Term.pp form.body Pattern.pp tri.pat
-      Fmt.(list ~sep:comma Pattern.pp)
-      tri.checks;
     Debug.incr nb_instantiation;
     let n = Ground.convert ~subst form.body in
     if Debug.test_flag Debug.stats && not (Egraph.is_registered d n) then
@@ -564,6 +564,11 @@ module Trigger = struct
     let print_runable = pp_runable
 
     let run d (tri, subst) =
+      Debug.dprintf8 debug
+        "[Quant] %a delayed instantiation %a, pat %a, checks:%a" Ground.Subst.pp
+        subst Ground.ThClosedQuantifier.pp tri.form Pattern.pp tri.pat
+        Fmt.(list ~sep:comma Pattern.pp)
+        tri.checks;
       instantiate_aux d tri subst;
       None
   end
@@ -571,6 +576,18 @@ module Trigger = struct
   let () = Egraph.Wait.register_dem (module Delayed_instantiation)
 
   let instantiate d tri subst =
+    let subst =
+      {
+        subst with
+        Ground.Subst.term =
+          Expr.Term.Var.M.map (Egraph.find_def d) subst.Ground.Subst.term;
+      }
+    in
+    Debug.dprintf8 debug "[Quant] %a instantiation found %a, pat %a, checks:%a"
+      Ground.Subst.pp subst Ground.ThClosedQuantifier.pp tri.form Pattern.pp
+      tri.pat
+      Fmt.(list ~sep:comma Pattern.pp)
+      tri.checks;
     if
       tri.eager
       && List.for_all
@@ -579,7 +596,9 @@ module Trigger = struct
              with Not_found -> false)
            tri.checks
     then instantiate_aux d tri subst
-    else Egraph.register_delayed_event d Delayed_instantiation.key (tri, subst)
+    else (
+      Debug.dprintf0 debug "[Quant] Delayed";
+      Egraph.register_delayed_event d Delayed_instantiation.key (tri, subst) )
 
   let match_ d tri n =
     Debug.dprintf4 debug "[Quant] match %a %a" pp tri Node.pp n;
@@ -616,7 +635,7 @@ module InvertedPath = struct
       { matches = Pattern.M.empty; triggers = []; ups = F_Pos.M.empty }
 
   let rec exec d acc substs n ip =
-    Debug.dprintf5 debug "[Quant] Exec: %a, %a[%i]@." Node.pp n
+    Debug.dprintf5 debug_full "[Quant] Exec: %a, %a[%i]" Node.pp n
       Ground.Subst.S.pp substs
       (Ground.Subst.S.cardinal substs);
     (* pp ip; *)
@@ -632,7 +651,7 @@ module InvertedPath = struct
       let acc =
         Pattern.M.fold_left
           (fun acc p ip ->
-            Debug.dprintf2 debug "[Quant] Exec match %a@." Pattern.pp p;
+            Debug.dprintf2 debug_full "[Quant] Exec match %a" Pattern.pp p;
             let substs = Pattern.match_term d substs n p in
             exec d acc substs n ip)
           acc ip.matches
@@ -663,7 +682,7 @@ module InvertedPath = struct
               Ground.S.fold_left (match_one_app pt) acc parents
             in
             let forall_fpos acc p ptl =
-              Debug.dprintf2 debug "[Quant] Exec ups %a@." F_Pos.pp p;
+              Debug.dprintf2 debug_full "[Quant] Exec ups %a@." F_Pos.pp p;
               let parents = F_Pos.M.find_def Ground.S.empty p info.parents in
               List.fold_left (forall_triplets parents) acc ptl
             in
@@ -777,7 +796,8 @@ let find_new_event d n (info : Info.t) (info' : Info.t) =
            *     let n = Ground.node n in
            *     InvertedPath.exec d acc n ip)
            *   acc parents *)
-          Debug.dprintf4 debug "[Quant] PP %a found for %a" PP.pp pp Node.pp n;
+          Debug.dprintf4 debug_full "[Quant] PP %a found for %a" PP.pp pp
+            Node.pp n;
           InvertedPath.exec d acc Pattern.init n ip
       | _ -> acc
     in
@@ -795,7 +815,8 @@ let find_new_event d n (info : Info.t) (info' : Info.t) =
            *     let n = Ground.node n in
            *     InvertedPath.exec d acc n ip)
            *   acc parents *)
-          Debug.dprintf4 debug "[Quant] PC %a found for %a" PC.pp pc Node.pp n;
+          Debug.dprintf4 debug_full "[Quant] PC %a found for %a" PC.pp pc
+            Node.pp n;
           InvertedPath.exec d acc Pattern.init n ip
       | _ -> acc
     in
@@ -826,8 +847,8 @@ let find_new_event d n (info : Info.t) (info' : Info.t) =
               in
               if Ground.Subst.S.subset substs2 substs1 then acc
               else (
-                Debug.dprintf6 debug "[Quant] PT %a %a found for %a" F_Pos.pp
-                  f_pos Expr.Ty.pp vty Node.pp n;
+                Debug.dprintf6 debug_full "[Quant] PT %a %a found for %a"
+                  F_Pos.pp f_pos Expr.Ty.pp vty Node.pp n;
                 InvertedPath.exec d acc Pattern.init n ip ))
             acc q
       | None -> acc
@@ -897,6 +918,7 @@ let init, attach =
 let quantifier_registered d th = attach d (Ground.ThClosedQuantifier.node th) th
 
 let add_info_on_ground_terms d thg =
+  Debug.dprintf2 debug "[Quant] add_info_on_ground_terms %a" Ground.pp thg;
   let res = Ground.node thg in
   let g = Ground.sem thg in
   let merge_info n info =
-- 
GitLab