From e7595b2c63f72f9794c90188997940abaf9d2183 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr>
Date: Thu, 12 Jan 2023 18:29:55 +0100
Subject: [PATCH] [Definition] add limited eager folding

---
 colibri2/stdlib/context.ml                    |  22 +-
 colibri2/stdlib/context.mli                   |   3 +
 colibri2/tests/solve/all/unsat/dune.inc       |   3 +
 .../all/unsat/mjrty-Mjrty-mjrtyqtvc_3.psmt2   | 384 ++++++++++++++++++
 colibri2/theories/quantifier/definitions.ml   |  11 +-
 5 files changed, 413 insertions(+), 10 deletions(-)
 create mode 100644 colibri2/tests/solve/all/unsat/mjrty-Mjrty-mjrtyqtvc_3.psmt2

diff --git a/colibri2/stdlib/context.ml b/colibri2/stdlib/context.ml
index 951ebaf73..5fc1f7d6b 100644
--- a/colibri2/stdlib/context.ml
+++ b/colibri2/stdlib/context.ml
@@ -70,28 +70,34 @@ module Ref = struct
     match r.previous with
     | [] -> ()
     | { at } :: _ when at.alive -> ()
-    | _ ->
+    | { at = _; value } :: l ->
         let rec aux v = function
           | { at; value } :: l when not at.alive -> aux value l
           | l ->
               r.contents <- v;
               r.previous <- l
         in
-        aux r.contents r.previous
+        aux value l
+
+  let set_aux (r : _ t) v =
+    match r.previous with
+    | { at } :: _ when bp_equal at (bp r.context) -> r.contents <- v
+    | _ ->
+        r.previous <- { at = bp r.context; value = r.contents } :: r.previous;
+        r.contents <- v
 
   let set r v =
     rewind r;
-    if not (CCEqual.physical r.contents v) then
-      match r.previous with
-      | { at } :: _ when bp_equal at (bp r.context) -> r.contents <- v
-      | _ ->
-          r.previous <- { at = bp r.context; value = r.contents } :: r.previous;
-          r.contents <- v
+    if not (CCEqual.physical r.contents v) then set_aux r v
 
   let get r =
     rewind r;
     r.contents
 
+  let incr r =
+    rewind r;
+    set_aux r (1 + r.contents)
+
   let creator (h : 'a t) = h.context
   let pp pp fmt r = pp fmt (get r)
 end
diff --git a/colibri2/stdlib/context.mli b/colibri2/stdlib/context.mli
index d787183ee..480887e44 100644
--- a/colibri2/stdlib/context.mli
+++ b/colibri2/stdlib/context.mli
@@ -64,6 +64,9 @@ module Ref : sig
   val get : 'a t -> 'a
   (** Get the current value of the reference *)
 
+  val incr : int t -> unit
+  (** Increment the current value of the reference *)
+
   val creator : 'a t -> creator
   val pp : 'a Fmt.t -> 'a t Fmt.t
 end
diff --git a/colibri2/tests/solve/all/unsat/dune.inc b/colibri2/tests/solve/all/unsat/dune.inc
index 584dd91c3..257cb483b 100644
--- a/colibri2/tests/solve/all/unsat/dune.inc
+++ b/colibri2/tests/solve/all/unsat/dune.inc
@@ -20,6 +20,9 @@
 --dont-print-result %{dep:lost_in_search_union.psmt2})) (package colibri2))
 (rule (alias runtest-learning) (action (run %{bin:colibri2} --size=50M --time=60s --max-steps 3500 --check-status unsat --learning --dont-print-result %{dep:lost_in_search_union.psmt2})) (package colibri2))
 (rule (alias runtest) (action (run %{bin:colibri2} --size=50M --time=60s --max-steps 3500 --check-status unsat 
+--dont-print-result %{dep:mjrty-Mjrty-mjrtyqtvc_3.psmt2})) (package colibri2))
+(rule (alias runtest-learning) (action (run %{bin:colibri2} --size=50M --time=60s --max-steps 3500 --check-status unsat --learning --dont-print-result %{dep:mjrty-Mjrty-mjrtyqtvc_3.psmt2})) (package colibri2))
+(rule (alias runtest) (action (run %{bin:colibri2} --size=50M --time=60s --max-steps 3500 --check-status unsat 
 --dont-print-result %{dep:mul_abs.smt2})) (package colibri2))
 (rule (alias runtest-learning) (action (run %{bin:colibri2} --size=50M --time=60s --max-steps 3500 --check-status unsat --learning --dont-print-result %{dep:mul_abs.smt2})) (package colibri2))
 (rule (alias runtest) (action (run %{bin:colibri2} --size=50M --time=60s --max-steps 3500 --check-status unsat 
diff --git a/colibri2/tests/solve/all/unsat/mjrty-Mjrty-mjrtyqtvc_3.psmt2 b/colibri2/tests/solve/all/unsat/mjrty-Mjrty-mjrtyqtvc_3.psmt2
new file mode 100644
index 000000000..63bc4c0bf
--- /dev/null
+++ b/colibri2/tests/solve/all/unsat/mjrty-Mjrty-mjrtyqtvc_3.psmt2
@@ -0,0 +1,384 @@
+;; produced by local colibri2.drv ;;
+(set-logic ALL)
+(set-info :smt-lib-version 2.6)
+;;; SMT-LIB2: integer arithmetic
+(declare-sort string 0)
+
+(declare-datatypes ((tuple0 0))
+  (((Tuple0))))
+
+(declare-datatypes ((ref 1))
+  ((par (a) ((refqtmk (contents a))))))
+
+;; "prefix !"
+(define-fun prefix_ex (par (a1)
+  ((r (ref a1))) a1
+  (contents r)))
+
+(declare-sort infix_mngt 2)
+
+;; "infix @"
+(declare-fun infix_at (par (a2
+  b)
+  ((infix_mngt a2
+  b)
+  a2) b))
+
+;; "get"
+(define-fun get (par (a1
+  b1)
+  ((f (infix_mngt a1 b1)) (x a1)) b1
+  (infix_at f x)))
+
+;; "set"
+(declare-fun set (par (a1
+  b1)
+  ((infix_mngt a1
+  b1)
+  a1
+  b1) (infix_mngt a1
+  b1)))
+
+;; "set'def"
+(assert (par (a1 b1)
+  (forall ((f (infix_mngt a1 b1)) (x a1) (v b1) (y a1))
+    (= (infix_at (set f x v) y) (ite (= y x) v (infix_at f y))))))
+
+;; "mixfix []"
+(define-fun mixfix_lbrb (par (a1
+  b1)
+  ((f (infix_mngt a1 b1)) (x a1)) b1
+  (infix_at f x)))
+
+;; "mixfix [<-]"
+(define-fun mixfix_lblsmnrb (par (a1
+  b1)
+  ((f (infix_mngt a1 b1)) (x a1) (v b1)) (infix_mngt a1 b1)
+  (set f x v)))
+
+(declare-sort array 1)
+
+;; "elts"
+(declare-fun elts (par (a1)
+  ((array a1)) (infix_mngt Int
+  a1)))
+
+;; "length"
+(declare-fun length (par (a1)
+  ((array a1)) Int))
+
+;; "array'invariant"
+(assert (par (a1)
+  (forall ((self (array a1)))
+    (! (<= 0 (length self)) :pattern ((length self)) ))))
+
+;; "mixfix []"
+(define-fun mixfix_lbrb1 (par (a1)
+  ((a3 (array a1)) (i Int)) a1
+  (infix_at (elts a3) i)))
+
+;; "mixfix [<-]"
+(declare-fun mixfix_lblsmnrb1 (par (a1)
+  ((array a1)
+  Int
+  a1) (array a1)))
+
+;; "mixfix [<-]'spec"
+(assert (par (a1)
+  (forall ((a3 (array a1)) (i Int) (v a1))
+    (and
+      (= (length (mixfix_lblsmnrb1 a3 i v)) (length a3))
+      (= (elts (mixfix_lblsmnrb1 a3 i v)) (set (elts a3) i v))))))
+
+;; "make"
+(declare-fun make (par (a1)
+  (Int
+  a1) (array a1)))
+
+;; "make_spec"
+(assert (par (a1)
+  (forall ((n Int) (v a1))
+    (=>
+      (>= n 0)
+      (and
+        (forall ((i Int))
+          (=> (and (<= 0 i) (< i n)) (= (mixfix_lbrb1 (make n v) i) v)))
+        (= (length (make n v)) n))))))
+
+;; "numof"
+(declare-fun numof ((infix_mngt Int
+  Bool)
+  Int
+  Int) Int)
+
+;; "numof'def"
+(assert
+  (forall ((p (infix_mngt Int Bool)) (a3 Int) (b2 Int))
+    (ite (<= b2 a3)
+      (= (numof p a3 b2) 0)
+      (ite (= (infix_at p (- b2 1)) true)
+        (= (numof p a3 b2) (+ 1 (numof p a3 (- b2 1))))
+        (= (numof p a3 b2) (numof p a3 (- b2 1)))))))
+
+;; "Numof_bounds"
+(assert
+  (forall ((p (infix_mngt Int Bool)) (a3 Int) (b2 Int))
+    (=>
+      (< a3 b2)
+      (and (<= 0 (numof p a3 b2)) (<= (numof p a3 b2) (- b2 a3))))))
+
+;; "Numof_append"
+(assert
+  (forall ((p (infix_mngt Int Bool)) (a3 Int) (b2 Int) (c Int))
+    (=>
+      (and (<= a3 b2) (<= b2 c))
+      (= (numof p a3 c) (+ (numof p a3 b2) (numof p b2 c))))))
+
+;; "Numof_left_no_add"
+(assert
+  (forall ((p (infix_mngt Int Bool)) (a3 Int) (b2 Int))
+    (=>
+      (< a3 b2)
+      (=>
+        (not (= (infix_at p a3) true))
+        (= (numof p a3 b2) (numof p (+ a3 1) b2))))))
+
+;; "Numof_left_add"
+(assert
+  (forall ((p (infix_mngt Int Bool)) (a3 Int) (b2 Int))
+    (=>
+      (< a3 b2)
+      (=>
+        (= (infix_at p a3) true)
+        (= (numof p a3 b2) (+ 1 (numof p (+ a3 1) b2)))))))
+
+;; "Empty"
+(assert
+  (forall ((p (infix_mngt Int Bool)) (a3 Int) (b2 Int))
+    (=>
+      (forall ((n Int))
+        (=> (and (<= a3 n) (< n b2)) (not (= (infix_at p n) true))))
+      (= (numof p a3 b2) 0))))
+
+;; "Full"
+(assert
+  (forall ((p (infix_mngt Int Bool)) (a3 Int) (b2 Int))
+    (=>
+      (<= a3 b2)
+      (=>
+        (forall ((n Int))
+          (=> (and (<= a3 n) (< n b2)) (= (infix_at p n) true)))
+        (= (numof p a3 b2) (- b2 a3))))))
+
+;; "numof_increasing"
+(assert
+  (forall ((p (infix_mngt Int Bool)) (i Int) (j Int) (k Int))
+    (=> (and (<= i j) (<= j k)) (<= (numof p i j) (numof p i k)))))
+
+;; "numof_strictly_increasing"
+(assert
+  (forall ((p (infix_mngt Int Bool)) (i Int) (j Int) (k Int) (l Int))
+    (=>
+      (and (<= i j) (and (<= j k) (< k l)))
+      (=> (= (infix_at p k) true) (< (numof p i j) (numof p i l))))))
+
+;; "numof_change_any"
+(assert
+  (forall ((p1 (infix_mngt Int Bool)) (p2 (infix_mngt Int Bool)) (a3 Int) (b2 Int))
+    (=>
+      (forall ((j Int))
+        (=>
+          (and (<= a3 j) (< j b2))
+          (=> (= (infix_at p1 j) true) (= (infix_at p2 j) true))))
+      (>= (numof p2 a3 b2) (numof p1 a3 b2)))))
+
+;; "numof_change_some"
+(assert
+  (forall ((p1 (infix_mngt Int Bool)) (p2 (infix_mngt Int Bool)) (a3 Int) (b2 Int) (i Int))
+    (=>
+      (and (<= a3 i) (< i b2))
+      (=>
+        (forall ((j Int))
+          (=>
+            (and (<= a3 j) (< j b2))
+            (=> (= (infix_at p1 j) true) (= (infix_at p2 j) true))))
+        (=>
+          (not (= (infix_at p1 i) true))
+          (=> (= (infix_at p2 i) true) (> (numof p2 a3 b2) (numof p1 a3 b2))))))))
+
+;; "numof_change_equiv"
+(assert
+  (forall ((p1 (infix_mngt Int Bool)) (p2 (infix_mngt Int Bool)) (a3 Int) (b2 Int))
+    (=>
+      (forall ((j Int))
+        (=>
+          (and (<= a3 j) (< j b2))
+          (= (= (infix_at p1 j) true) (= (infix_at p2 j) true))))
+      (= (numof p2 a3 b2) (numof p1 a3 b2)))))
+
+;; "fc"
+(declare-fun fc (par (a1)
+  ((array a1)
+  a1) (infix_mngt Int
+  Bool)))
+
+;; "fc'def"
+(assert (par (a1)
+  (forall ((a3 (array a1)) (v a1) (i Int))
+    (= (= (infix_at (fc a3 v) i) true) (= (mixfix_lbrb1 a3 i) v)))))
+
+;; "numof"
+(define-fun numof1 (par (a1)
+  ((a3 (array a1)) (v a1) (l Int) (u Int)) Int
+  (numof (fc a3 v) l u)))
+
+(declare-sort candidate 0)
+
+;; "a"
+(declare-fun a3 () (array candidate))
+
+;; "Requires"
+(assert (<= 1 (length a3)))
+
+;; "n"
+(define-fun n () Int
+  (length a3))
+
+;; "length'result'unused"
+(define-fun lengthqtresultqtunused () Int
+  n)
+
+;; "mixfix []'result'unused"
+(define-fun mixfix_lbrbqtresultqtunused () candidate
+  (mixfix_lbrb1 a3 0))
+
+;; "cand'unused"
+(define-fun candqtunused () (ref candidate)
+  (refqtmk (mixfix_lbrb1 a3 0)))
+
+;; "ref'result'unused"
+(define-fun refqtresultqtunused () (ref candidate)
+  (refqtmk (mixfix_lbrb1 a3 0)))
+
+;; "k'unused"
+(define-fun kqtunused () (ref Int)
+  (refqtmk 0))
+
+;; "ref'result'unused"
+(define-fun refqtresultqtunused1 () (ref Int)
+  (refqtmk 0))
+
+;; "infix -'result'unused"
+(define-fun infix_mnqtresultqtunused () Int
+  (- n 1))
+
+;; "H"
+(assert (<= 0 (+ (- n 1) 1)))
+
+;; "k"
+(declare-fun k () Int)
+
+;; "cand"
+(declare-fun cand () candidate)
+
+;; "cand'unused"
+(define-fun candqtunused1 () (ref candidate)
+  (refqtmk cand))
+
+;; "k"
+(define-fun k1 () (ref Int)
+  (refqtmk k))
+
+;; "i"
+(declare-fun i () Int)
+
+;; "H"
+(assert (<= 0 i))
+
+;; "H"
+(assert (<= i (- n 1)))
+
+;; "H"
+(assert (<= 0 k))
+
+;; "H"
+(assert (<= k (numof1 a3 cand 0 i)))
+
+;; "LoopInvariant"
+(assert (<= (* 2 (- (numof1 a3 cand 0 i) k)) (- i k)))
+
+;; "LoopInvariant"
+(assert
+  (forall ((c candidate))
+    (=> (not (= c cand)) (<= (* 2 (numof1 a3 c 0 i)) (- i k)))))
+
+;; "prefix !'result'unused"
+(define-fun prefix_exqtresultqtunused () Int
+  k)
+
+;; "infix ='result'unused"
+(define-fun infix_eqqtresultqtunused () Bool
+  (ite (= k 0) true false))
+
+;; "H"
+(assert (= k 0))
+
+;; "mixfix []'result'unused"
+(define-fun mixfix_lbrbqtresultqtunused1 () candidate
+  (mixfix_lbrb1 a3 i))
+
+;; "cand"
+(declare-fun cand1 () candidate)
+
+;; "cand'unused"
+(define-fun candqtunused2 () (ref candidate)
+  (refqtmk cand1))
+
+;; "Ensures"
+(assert (= cand1 (mixfix_lbrb1 a3 i)))
+
+;; "k"
+(declare-fun k2 () Int)
+
+;; "k'unused"
+(define-fun kqtunused1 () (ref Int)
+  (refqtmk k2))
+
+;; "Ensures"
+(assert (= k2 1))
+
+;; "H"
+(assert (<= 0 k2))
+
+;; "H"
+(assert (<= k2 (numof1 a3 cand1 0 (+ i 1))))
+
+;; "LoopInvariant"
+(assert (<= (* 2 (- (numof1 a3 cand1 0 (+ i 1)) k2)) (- (+ i 1) k2)))
+
+;; "c"
+(declare-fun c () candidate)
+
+;; "H"
+(assert (not (= c cand1)))
+
+;; "Hinst"
+;;(assert (=> (not (= c cand)) (<= (* 2 (numof1 a3 c 0 i)) (- i k))))
+
+;;(assert (= (numof1 a3 c 0 i) (numof (fc a3 c) 0 i)))
+;;(assert (= (numof (fc a3 c) 0 i) (numof (fc a3 c) 0 i)))
+;;(assert (= (numof1 a3 c 0 i) (numof1 a3 c 0 i)))
+;;(assert (= (numof (fc a3 c) 0 (+ i 1)) (numof (fc a3 c) 0 (+ i 1))))
+
+(assert (let ((p (fc a3 c))(a3 0)(b2 (+ i 1)))
+        (ite (<= b2 a3)
+        (= (numof p a3 b2) 0)
+           (ite (= (infix_at p (- b2 1)) true)
+           (= (numof p a3 b2) (+ 1 (numof p a3 (- b2 1))))
+           (= (numof p a3 b2) (numof p a3 (- b2 1)))))))
+
+;; Goal "mjrty'vc"
+;; File "/home/bobot/Sources/why3.master/examples/mjrty.mlw", line 36, characters 6-11
+(assert (not (<= (* 2 (numof1 a3 c 0 (+ i 1))) (- (+ i 1) k2))))
+
+(check-sat)
diff --git a/colibri2/theories/quantifier/definitions.ml b/colibri2/theories/quantifier/definitions.ml
index b34f825b3..af5c73471 100644
--- a/colibri2/theories/quantifier/definitions.ml
+++ b/colibri2/theories/quantifier/definitions.ml
@@ -35,6 +35,8 @@ let criteria d (_sym : Expr.Term.Const.t) _tyl _tvl (body : Expr.term) =
   | Expr.Binder (_, _) -> false
   | Expr.Match (_, _) -> false
 
+let nb_eager_folding = 25
+
 let handler d (sym : Expr.Term.Const.t) tyl tvl body =
   if criteria d sym tyl tvl body then
     let tys, tvs =
@@ -44,6 +46,7 @@ let handler d (sym : Expr.Term.Const.t) tyl tvl body =
       List.for_all (fun v -> Expr.Ty.Var.S.mem v tys) tyl
       && List.for_all (fun v -> Expr.Term.Var.S.mem v tvs) tvl
     then
+      let nb_folding = Context.Ref.create (Egraph.context d) (-1) in
       match Pattern.of_term_exn ~subst:Ground.Subst.empty body with
       | exception Pattern.Unconvertible -> ()
       | pat ->
@@ -63,7 +66,11 @@ let handler d (sym : Expr.Term.Const.t) tyl tvl body =
               Debug.dprintf2 Common.debug "Fold definition of %a" Ground.Term.pp
                 t;
               let n = Ground.node (Ground.index t) in
-              DaemonLastEffortUncontextual.schedule_immediately d (fun d ->
-                  Egraph.register d n))
+              if Context.Ref.get nb_folding < nb_eager_folding then (
+                Context.Ref.incr nb_folding;
+                Egraph.register d n)
+              else
+                DaemonLastEffortUncontextual.schedule_immediately d (fun d ->
+                    Egraph.register d n))
 
 let th_register d = Ground.Defs.add_handler d handler
-- 
GitLab