From da2b83b84937a8ab8fee65c9d78234673310df03 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr>
Date: Mon, 18 Jan 2021 15:20:59 +0100
Subject: [PATCH] [Product] Fix comparison with zero of multiplication

---
 src_colibri2/stdlib/std.ml                    |  2 +
 src_colibri2/stdlib/std.mli                   |  2 +
 .../tests/solve/smt_nra/unsat/dune.inc        |  2 +
 .../solve/smt_nra/unsat/mul_pos_zero_le.smt2  | 13 +++
 src_colibri2/theories/LRA/dom_polynome.ml     | 15 ++-
 src_colibri2/theories/LRA/dom_polynome.mli    |  4 +
 src_colibri2/theories/LRA/dom_product.ml      | 97 +++++++++++++------
 src_colibri2/theories/LRA/fourier.ml          | 22 +++--
 src_colibri2/theories/LRA/product.ml          | 14 ++-
 src_common/q.mlw                              | 13 +--
 10 files changed, 123 insertions(+), 61 deletions(-)
 create mode 100644 src_colibri2/tests/solve/smt_nra/unsat/mul_pos_zero_le.smt2

diff --git a/src_colibri2/stdlib/std.ml b/src_colibri2/stdlib/std.ml
index a488e3b04..5304d0901 100644
--- a/src_colibri2/stdlib/std.ml
+++ b/src_colibri2/stdlib/std.ml
@@ -105,4 +105,6 @@ module Q = struct
   let floor x = Q.of_bigint (Colibrics_lib.QUtils.floor x)
   let ceil x = Q.of_bigint (Colibrics_lib.QUtils.ceil x)
 
+  let none_zero c = if Q.equal Q.zero c then None else Some c
+
 end
diff --git a/src_colibri2/stdlib/std.mli b/src_colibri2/stdlib/std.mli
index f59932580..c559656cc 100644
--- a/src_colibri2/stdlib/std.mli
+++ b/src_colibri2/stdlib/std.mli
@@ -60,4 +60,6 @@ module Q : sig
   val floor: t -> t
   val ceil : t -> t
   val is_integer : t -> bool
+  val none_zero: t -> t option
+  (** return None if the input is zero otherwise Some of the value *)
 end
diff --git a/src_colibri2/tests/solve/smt_nra/unsat/dune.inc b/src_colibri2/tests/solve/smt_nra/unsat/dune.inc
index c97ea8f02..d4a1c682d 100644
--- a/src_colibri2/tests/solve/smt_nra/unsat/dune.inc
+++ b/src_colibri2/tests/solve/smt_nra/unsat/dune.inc
@@ -13,3 +13,5 @@
 (rule (alias runtest) (action (diff oracle mul_pos.smt2.res)))
 (rule (action (with-stdout-to mul_pos_lt.smt2.res (run %{bin:colibri2} --max-step 1300 %{dep:mul_pos_lt.smt2}))))
 (rule (alias runtest) (action (diff oracle mul_pos_lt.smt2.res)))
+(rule (action (with-stdout-to mul_pos_zero_le.smt2.res (run %{bin:colibri2} --max-step 1300 %{dep:mul_pos_zero_le.smt2}))))
+(rule (alias runtest) (action (diff oracle mul_pos_zero_le.smt2.res)))
diff --git a/src_colibri2/tests/solve/smt_nra/unsat/mul_pos_zero_le.smt2 b/src_colibri2/tests/solve/smt_nra/unsat/mul_pos_zero_le.smt2
new file mode 100644
index 000000000..d0ab3f7a1
--- /dev/null
+++ b/src_colibri2/tests/solve/smt_nra/unsat/mul_pos_zero_le.smt2
@@ -0,0 +1,13 @@
+(set-logic ALL)
+
+(declare-const a Real)
+(declare-const b Real)
+(declare-const c Real)
+
+(assert (< 0 b))
+
+(assert (<= 0.0 (- (* a b) (* c b))))
+
+(assert (not (<= 0.0 (- a c))))
+
+(check-sat)
diff --git a/src_colibri2/theories/LRA/dom_polynome.ml b/src_colibri2/theories/LRA/dom_polynome.ml
index b0b77f54d..ca059d159 100644
--- a/src_colibri2/theories/LRA/dom_polynome.ml
+++ b/src_colibri2/theories/LRA/dom_polynome.ml
@@ -35,6 +35,8 @@ end
 
 module ThE = ThTermKind.Register(T)
 
+let node_of_polynome t = ThE.node (ThE.index t)
+
 let used_in_poly : Node.t Bag.t Node.HC.t = Node.HC.create (Bag.pp Pp.semi Node.pp) "used_in_poly"
 
 let set_poly d cl p =
@@ -169,16 +171,19 @@ end
 
 let () = Egraph.register_dom(module Th)
 
-let assume_poly_equality d n (p:Polynome.t) =
-  (* Debug.dprintf4 debug "assume1: %a = %a" Node.pp n Polynome.pp p; *)
-  let n = Egraph.find d n in
+let norm d (p:Polynome.t) =
   let add acc cl c =
-    let cl = Egraph.find d cl in
+    let cl = Egraph.find_def d cl in
     match Egraph.get_dom d dom cl with
     | None -> Polynome.add acc (Polynome.monome c cl)
     | Some p -> Polynome.x_p_cy acc c p
   in
-  let p = Polynome.fold add (Polynome.cst p.cst) p in
+  Polynome.fold add (Polynome.cst p.cst) p
+
+let assume_poly_equality d n (p:Polynome.t) =
+  (* Debug.dprintf4 debug "assume1: %a = %a" Node.pp n Polynome.pp p; *)
+  let n = Egraph.find_def d n in
+  let p = norm d p in
   (* Debug.dprintf4 debug "assume2: %a = %a" Node.pp n Polynome.pp p; *)
   Th.solve_one d n p
 
diff --git a/src_colibri2/theories/LRA/dom_polynome.mli b/src_colibri2/theories/LRA/dom_polynome.mli
index 32c8b9a1c..4c3518172 100644
--- a/src_colibri2/theories/LRA/dom_polynome.mli
+++ b/src_colibri2/theories/LRA/dom_polynome.mli
@@ -23,3 +23,7 @@ val assume_poly_equality: Egraph.t -> Node.t -> Polynome.t -> unit
 val dom: Polynome.t DomKind.t
 
 val init: Egraph.t -> unit
+
+val node_of_polynome: Polynome.t -> Node.t
+
+val norm: Egraph.t -> Polynome.t -> Polynome.t
diff --git a/src_colibri2/theories/LRA/dom_product.ml b/src_colibri2/theories/LRA/dom_product.ml
index 85394ea1b..e162fde03 100644
--- a/src_colibri2/theories/LRA/dom_product.ml
+++ b/src_colibri2/theories/LRA/dom_product.ml
@@ -188,33 +188,36 @@ module Th = struct
         add_used_product d cl eq.Product.poly;
         set_poly d cl { repr = eq; eqs = Product.S.empty } [eq]
       | Some p ->
-        match solve d
-                (Lists.product
-                   (part d (p.repr::(Product.S.elements p.eqs)))
-                   (part d [eq])
-                )
-        with
-        | `Solved ->
-          (** The domains have been subsituted, and possibly recursively *)
-          let eq = norm_product d eq in
-          merge_one_new_eq d cl eq
-        | `Not_solved ->
-          (** nothing to solve *)
-          let repr = p.repr in
-          let eqs = p.eqs
-                    |> Product.S.add eq
-                    |> Product.S.remove repr
-          in
-          let p = { repr; eqs } in
-          add_used_product d cl eq.poly;
-          set_poly d cl p [eq]
+        if not (Product.S.mem eq p.eqs) && not (Product.equal eq p.repr) then
+          match solve d
+                  (Lists.product
+                     (part d (p.repr::(Product.S.elements p.eqs)))
+                     (part d [eq])
+                  )
+          with
+          | `Solved ->
+            (** The domains have been subsituted, and possibly recursively *)
+            let eq = norm_product d eq in
+            merge_one_new_eq d cl eq
+          | `Not_solved ->
+            (** nothing to solve *)
+            let repr = p.repr in
+            let eqs = p.eqs
+                      |> Product.S.add eq
+                      |> Product.S.remove repr
+            in
+            let p = { repr; eqs } in
+            add_used_product d cl eq.poly;
+            set_poly d cl p [eq]
 
     and subst d cl eq =
       Debug.dprintf4 debug "[Product] subst %a with %a" Node.pp cl Product.pp eq;
       let po = Egraph.get_dom d dom cl in
       match po with
       | None ->
-        assert false (* used in another equation *)
+        let p = { repr = eq; eqs = Product.S.empty } in
+        add_used_product d cl eq.poly;
+        set_poly d cl p [eq]
       | Some p ->
         assert (Product.equal p.repr (Product.monome cl Q.one));
         subst_doms d cl eq;
@@ -321,14 +324,35 @@ end
 
 let () = Egraph.Wait.register_dem (module ChangePos)
 
-let init_den,attach_den =
-  Demon.Fast.register_change_domain_daemon
-    ~name:"Dom_product.den"
-    Dom_interval.dom
-    (fun d _ (res,num,den) ->
-       if Dom_interval.is_not_zero d den then
-         assume_poly_equality d num (Product.of_list [res,Q.one;den,Q.one])
-    )
+let factorize res a coef b d _ =
+  match Egraph.get_dom d dom a, Egraph.get_dom d dom b with
+  | Some pa, Some pb ->
+    let common = Node.M.inter (fun _ a b ->
+        Q.none_zero (Q.min (Q.floor a) (Q.floor b)))
+        pa.repr.poly pb.repr.poly in
+    if not (Node.M.is_empty common) then
+      let (cst,l) = List.fold_left (fun (cst,acc) (p,v) ->
+          let p = Product.of_map (Node.M.diff (fun _ a b -> Q.none_zero (Q.sub a b)) p.Product.poly common) in
+          match Product.classify p with
+          | ONE ->
+            (Q.add v cst,acc)
+          | NODE n ->
+            (cst,(n,v)::acc)
+          | PRODUCT ->
+            let n = node_of_product p in
+            Egraph.register d n;
+            (Q.add Q.one cst,(n,v)::acc)
+        ) (Q.zero,[]) [pa.repr,Q.one;pb.repr,coef] in
+      let p = Polynome.of_list cst l in
+      let n = Dom_polynome.node_of_polynome p in
+      Egraph.register d n;
+      Dom_polynome.assume_poly_equality d n p;
+      assume_poly_equality d res (Product.of_map (Node.M.add n Q.one common))
+  | _ -> ()
+
+
+let init_dem,wait =
+  Demon.Fast.register_simply "GotDomInterval"
 
 (** {2 Initialization} *)
 let converter d (f:Ground.t) =
@@ -340,7 +364,17 @@ let converter d (f:Ground.t) =
     assume_poly_equality d res (Product.of_list [a,Q.one;b,Q.one])
   | { app = {builtin = Expr.Div}; tyargs = []; args = [num;den]; _ } ->
     reg num; reg den;
-    attach_den d den (res,num,den)
+    wait.for_dom d Dom_interval.dom den
+      (fun d _ ->
+         if Dom_interval.is_not_zero d den then
+           assume_poly_equality d num (Product.of_list [res,Q.one;den,Q.one])
+      )
+  | { app = {builtin = Expr.Add}; tyargs = []; args = [a;b]; _ } ->
+    reg a; reg b;
+    List.iter (fun x -> wait.for_dom d dom x (factorize res a Q.one b)) [a;b]
+  | { app = {builtin = Expr.Sub}; tyargs = []; args = [a;b]; _ } ->
+    reg a; reg b;
+    List.iter (fun x -> wait.for_dom d dom x (factorize res a Q.minus_one b)) [a;b]
   | _ -> ()
 
 let init env =
@@ -375,6 +409,5 @@ let init env =
            Product.S.iter aux p.eqs
       ) env;
     Ground.register_converter env converter;
-    init_den env;
+    init_dem env;
     Egraph.attach_any_dom env Dom_interval.dom ChangePos.key ()
-
diff --git a/src_colibri2/theories/LRA/fourier.ml b/src_colibri2/theories/LRA/fourier.ml
index f969d3f70..03fbeeadb 100644
--- a/src_colibri2/theories/LRA/fourier.ml
+++ b/src_colibri2/theories/LRA/fourier.ml
@@ -28,7 +28,8 @@ let mk_eq d bound truth a b =
     | Some p -> p in
   let bound = if truth then bound else match bound with | Interval.Bound.Large -> Strict | Strict -> Large in
   let a,b = if truth then a,b else b,a in
-  Polynome.sub (!a) (!b), bound
+  Polynome.sub (!a) (!b), bound,
+  Polynome.sub (Polynome.monome Q.one a) (Polynome.monome Q.one b)
 
 let divide d (p:Polynome.t) =
   try
@@ -38,11 +39,16 @@ let divide d (p:Polynome.t) =
     end;
     if Q.sign p.cst <> 0 then raise Exit;
     let l = Node.M.bindings p.poly in
-    let l = List.map (fun (e,q) ->
-        Opt.get_exn Exit (Dom_product.get_repr d e),q) l in
+    let l = List.fold_left (fun acc (e,q) ->
+        match Dom_product.get_repr d e with
+        | None when Egraph.is_equal d RealValue.zero e -> acc
+        | None -> raise Exit
+        | Some p -> (p,q)::acc) [] l in
     Debug.dprintf4 debug "@[eq:%a@ %a@]"
       Polynome.pp p Fmt.(list ~sep:(any "+") (using Pair.swap (pair Q.pp Product.pp))) l;
-    let hd,_ = List.hd l in
+    match l with
+    | [] -> raise Exit
+    | (hd,_)::_ ->
     let common = List.fold_left  (fun acc (p,_) ->
         Node.M.inter (fun _ a b -> if Q.equal a b then Some a else None) acc p.Product.poly)
         hd.Product.poly l in
@@ -60,6 +66,7 @@ let divide d (p:Polynome.t) =
       Polynome.pp p (Node.M.pp Q.pp) common;
     let (cst,l) = List.fold_left (fun (cst,acc) (p,v) ->
         let p = Product.of_map (Node.M.set_diff p.Product.poly common) in
+
         let v = if pos then v else Q.neg v in
         match Product.classify p with
         | ONE ->
@@ -80,7 +87,7 @@ let make_equations d (eqs,vars) g =
   | None ->
     (eqs,vars)
   | Some truth ->
-    let p,bound =
+    let p,bound,p_non_norm =
       match Ground.sem g with
       | { app = {builtin = Expr.Lt}; tyargs = []; args = [a;b]; _ } ->
         mk_eq d Strict truth a b
@@ -92,13 +99,14 @@ let make_equations d (eqs,vars) g =
         mk_eq d Strict truth b a
       | _ -> assert false
     in
-    Debug.dprintf3 debug "[Fourier] %b %a" truth Polynome.pp p;
+    Debug.dprintf5 debug "[Fourier] %b %a(%a)" truth Polynome.pp p Polynome.pp p_non_norm;
     let eqs, vars =
       (add_eq d eqs p bound (Ground.S.singleton g),
        Node.M.union_merge (fun _ _ _ -> Some ()) vars p.Polynome.poly)
     in
-    match divide d p with
+    match divide d p_non_norm with
     | Some (p',_,_) ->
+      let p' = Dom_polynome.norm d p' in
       (add_eq d eqs p' bound Ground.S.empty,
        Node.M.union_merge (fun _ _ _ -> Some ()) vars p'.Polynome.poly)
     | None -> eqs, vars
diff --git a/src_colibri2/theories/LRA/product.ml b/src_colibri2/theories/LRA/product.ml
index 94fd0a455..40da25d94 100644
--- a/src_colibri2/theories/LRA/product.ml
+++ b/src_colibri2/theories/LRA/product.ml
@@ -109,8 +109,6 @@ let is_one_node p = (** cst = 0 and one empty monome *)
  *   assert (not (Q.equal c Q.zero));
  *   { p1 with cst = Q.mul p1.cst c } *)
 
-let none_zero c = if Q.equal Q.zero c then None else Some c
-
 let power_cst c p1 =
   if Q.equal Q.zero c then one
   else if Q.equal Q.one c then p1
@@ -123,7 +121,7 @@ let power_cst c p1 =
     that verifies [op 0 p = p] and [op p 0 = p] *)
 let mul p1 p2 =
   let poly_add m1 m2 =
-    Node.M.union (fun _ c1 c2 -> none_zero (Q.add c1 c2)) m1 m2
+    Node.M.union (fun _ c1 c2 -> Q.none_zero (Q.add c1 c2)) m1 m2
   in
   {(* cst = Q.mul p1.cst p2.cst; *) poly = poly_add p1.poly p2.poly}
 
@@ -132,7 +130,7 @@ let div p1 p2 =
     Node.M.union_merge (fun _ c1 c2 ->
       match c1 with
       | None -> Some (Q.neg c2)
-      | Some c1 -> none_zero (Q.sub c1 c2))
+      | Some c1 -> Q.none_zero (Q.sub c1 c2))
       m1 m2 in
   {(* cst = Q.div p1.cst p2.cst; *) poly = poly_sub p1.poly p2.poly}
 
@@ -143,7 +141,7 @@ let x_m_yc p1 p2 c =
     Node.M.union_merge (fun _ c1 c2 ->
       match c1 with
       | None -> Some (Q.mul c c2)
-      | Some c1 -> none_zero (f c1 c2))
+      | Some c1 -> Q.none_zero (f c1 c2))
       m1 m2 in
   {(* cst = Q.mul p1.cst (Q. p2.cst; *) poly = poly p1.poly p2.poly}
 
@@ -165,7 +163,7 @@ let xc_m_yc p1 c1 p2 c2 =
           | None, Some e2 -> Some (Q.mul c2 e2)
           | Some e1, None -> Some (Q.mul c1 e1)
           | Some e1, Some e2 ->
-            none_zero (f e1 e2))
+            Q.none_zero (f e1 e2))
         m1 m2 in
     {(* cst = f p1.cst p2.cst; *) poly = poly p1.poly p2.poly}
 
@@ -176,7 +174,7 @@ let subst_node p x y =
   | Some q ->
     let poly = Node.M.change (function
         | None -> qo
-        | Some q' -> none_zero (Q.add q q')
+        | Some q' -> Q.none_zero (Q.add q q')
       ) y poly in
     {poly}, q
 
@@ -192,7 +190,7 @@ let iter f p = Node.M.iter f p.poly
 let of_list (* cst  *)l =
   let fold acc (node,q) = Node.M.change (function
       | None -> Some q
-      | Some q' -> none_zero (Q.add q q')) node acc in
+      | Some q' -> Q.none_zero (Q.add q q')) node acc in
   {(* cst; *)poly= List.fold_left fold Node.M.empty l}
 
 let of_map m =
diff --git a/src_common/q.mlw b/src_common/q.mlw
index df4d1504c..986842aee 100644
--- a/src_common/q.mlw
+++ b/src_common/q.mlw
@@ -361,12 +361,11 @@ module Q
 
      let floor (a:t) : int
        ensures { result = floor a.real } =
-        assert { from_int (floor a.real) <=. a.real <. from_int (floor a.real) +. 1. };
-        assert { 0. <=. a.real -. from_int (floor a.real) <. 1. };
+(*        assert { from_int (floor a.real) <=. a.real <. from_int (floor a.real) +. 1. }; *)
+(*        assert { 0. <=. a.real -. from_int (floor a.real) <. 1. }; *)
         let r = Z.fdiv a.num a.den in
-        assert { Int.(0 <= a.num - r*a.den < a.den) };
-        assert { 0. <=. (from_int a.num) -. (from_int r)*.(from_int a.den) <. (from_int a.den) };
-        assert { ((from_int a.num) -. (from_int r)*.(from_int a.den))/.(from_int a.den) <. (from_int a.den)/.(from_int a.den) };
+(*        assert { Int.(0 <= a.num - r*a.den < a.den) }; *)
+(*        assert { 0. <=. (from_int a.num) -. (from_int r)*.(from_int a.den) <. (from_int a.den) }; *)
         assert { 0. <=. a.real -. from_int r <. 1. };
         r
 
@@ -377,10 +376,6 @@ module Q
         assert { 0. <=. from_int (ceil a.real) -. a.real  <. 1. };
         let r = Z.cdiv a.num a.den in
         assert { Int.(0 <= r*a.den - a.num < a.den) };
-        assert { 0. <=. (from_int r)*.(from_int a.den) -. (from_int a.num) <. (from_int a.den) };
-        assert { (from_int r) -. (from_int a.num)/.(from_int a.den) <. 1. };
-        assert { (from_int r) -. (from_int a.num)/.(from_int a.den) <. (from_int a.den)/.(from_int a.den) };
-        assert { ((from_int r)*.(from_int a.den) -. (from_int a.num))/.(from_int a.den) <. (from_int a.den)/.(from_int a.den) };
         assert { 0. <=. from_int r -. a.real <. 1. };
         r
 end
-- 
GitLab