(*************************************************************************)
(*  This file is part of Colibri2.                                       *)
(*                                                                       *)
(*  Copyright (C) 2014-2021                                              *)
(*    CEA   (Commissariat à l'énergie atomique et aux énergies           *)
(*           alternatives)                                               *)
(*                                                                       *)
(*  you can redistribute it and/or modify it under the terms of the GNU  *)
(*  Lesser General Public License as published by the Free Software      *)
(*  Foundation, version 2.1.                                             *)
(*                                                                       *)
(*  It is distributed in the hope that it will be useful,                *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of       *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the        *)
(*  GNU Lesser General Public License for more details.                  *)
(*                                                                       *)
(*  See the GNU Lesser General Public License version 2.1                *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).           *)
(*************************************************************************)

open Colibri2_popop_lib
open Colibri2_core
open Colibri2_stdlib.Std

(** Debug flag registered under the name ["LRA.polynome"]; it guards the
    [Debug.dprintf*] traces emitted by this file. *)
let debug =
  Debug.register_info_flag
    ~desc:"for the arithmetic theory of polynome"
    "LRA.polynome"

(** Domain key under which a [Polynome.t] is attached to a node. *)
let dom =
  DomKind.create_key
    (module struct
      type t = Polynome.t

      let name = "ARITH_POLY"
    end)

(** [Polynome] extended with a [ThTermKind] key, so that it can be given to
    [ThTermKind.Register]. *)
module T = struct
  include Polynome

  let key =
    ThTermKind.create_key
      (module struct
        type nonrec t = t

        let name = "SARITH_POLY"
      end)
end

(** Registration of polynomials ([T]) as a kind of theory term. *)
module ThE = ThTermKind.Register (T)

(** [node_of_polynome poly] indexes [poly] as a theory term and returns the
    node of that term. *)
let node_of_polynome poly =
  let indexed = ThE.index poly in
  ThE.node indexed

(** Table associating a node with the bag of nodes that [add_used] recorded
    as mentioning it in their polynomial. *)
let used_in_poly : Node.t Bag.t Node.HC.t =
  let pp_users = Bag.pp Pp.semi Node.pp in
  Node.HC.create pp_users "used_in_poly"

(** [set_poly d node poly] stores [poly] as the [dom] value of [node]. Then,
    depending on [Polynome.is_one_node poly]:
    - [Some alias]: [node] is merged with [alias];
    - [None]: the indexed theory term of [poly] is attached to [node]. *)
let set_poly d node poly =
  Egraph.set_dom d dom node poly;
  match Polynome.is_one_node poly with
  | Some alias -> Egraph.merge d node alias
  | None ->
    let term = ThE.thterm (ThE.index poly) in
    Egraph.set_thterm d node term

(** [add_used d cl' new_cl] records, for every key [used] of the map
    [new_cl], that [cl'] uses [used] (in [used_in_poly]). *)
let add_used d cl' new_cl =
  (* A node seen for the first time must carry a polynomial: when it has
     none, it receives the monomial [1 * used] (i.e. itself). Having a domain
     is what lets this module be notified when the node gets merged, hence no
     dedicated wait on representative changes is needed. *)
  let ensure_self_dom used =
    match Egraph.get_dom d dom used with
    | Some p -> assert (Polynome.equal (Polynome.monome Q.one used) p)
    | None -> Egraph.set_dom d dom used (Polynome.monome Q.one used)
  in
  let record used _ =
    let update = function
      | None ->
        ensure_self_dom used;
        Some (Bag.elt cl')
      | Some users -> Some (Bag.append users cl')
    in
    Node.HC.change update used_in_poly d used
  in
  Node.M.iter record new_cl

(** [subst_doms d cl p] substitutes [cl] by [p] in the polynomial of every
    node recorded in [used_in_poly] as using [cl], then records the nodes
    used by [p] and sets [p] as the polynomial of [cl] itself. *)
let subst_doms d cl (p:Polynome.t) =
  let users =
    try Node.HC.find used_in_poly d cl with Not_found -> Bag.empty
  in
  let rewrite_user user =
    match Egraph.get_dom d dom user with
    | Some q ->
      (* nodes of [p] that [user]'s polynomial did not mention before *)
      let fresh = Node.M.set_diff p.poly q.poly in
      let substituted, _ = Polynome.subst q cl p in
      add_used d user fresh;
      set_poly d user substituted
    | None ->
      (* absurd: a node cannot be a user and have no polynomial *)
      assert false
  in
  Bag.iter rewrite_user users;
  add_used d cl p.poly;
  set_poly d cl p

(** Domain implementation for polynomials: how two polynomial domains are
    compared and merged, and how a new equality [cl = p] is solved. *)
module Th = struct
  include Polynome

  (** [merged v1 v2] is [true] when both domains are absent, or both are
      present and equal. *)
  let merged v1 v2 =
    match v1,v2 with
    | None, None -> true
    | Some v', Some v -> equal v' v
    | None, Some _ | Some _, None -> false

  (** [norm_dom cl po] is the polynomial in [po] when present, otherwise the
      monomial [1 * cl]. *)
  let norm_dom cl = function
    | None -> monome Q.one cl
    | Some p -> p

  (** [add_itself d cl norm] records the nodes used by [norm] for [cl] and
      sets [norm] as the domain of [cl]. *)
  let add_itself d cl norm =
    add_used d cl norm.poly;
    Egraph.set_dom d dom cl norm

  (** [merge d (p1o,cl1) (p2o,cl2) inv] solves the equality between the
      polynomials of the two classes. [inv] selects which argument is treated
      as the representative side. Calls [Egraph.contradiction] when the
      equality reduces to [0 = c] with [c] a non-zero constant. *)
  let merge d ((p1o,cl1) as a1) ((p2o,cl2) as a2) inv =
    assert (not (Egraph.is_equal d cl1 cl2));
    assert (not (Option.is_none p1o && Option.is_none p2o));
    let (pother, other), (prepr, repr) = if inv then a2,a1 else a1,a2 in
    let other = Egraph.find d other in
    let repr = Egraph.find d repr in
    let p1 = norm_dom other pother in
    let p2 = norm_dom repr prepr in
    let diff = sub p1 p2 in
    (* 0 = other - repr = p1 - p2 = diff *)
    Debug.dprintf2 debug "[Arith] @[solve 0=%a@]" pp diff;
    begin match Polynome.extract diff with
      | Zero ->
        (* no new equality: both sides are already equal *)
        begin match pother, prepr with
          | Some _, Some _ | None, None ->
            (* absurd: no need of merge *)
            assert false
          | Some p, None ->
            (* p = repr *)
            add_itself d repr p
          | None, Some p ->
            (* p = other *)
            add_itself d other p
        end
      | Cst c ->
        (* 0 = cst <> 0 *)
        Debug.dprintf6 Egraph.print_contradiction
          "[LRA/Poly] Found 0 = %a when merging %a and %a"
          Q.pp c
          Node.pp cl1 Node.pp cl2;
        Egraph.contradiction d
      | Var(q,x,p') ->
        (* diff = qx + p' *)
        assert ( not (Q.equal Q.zero q) );
        Debug.dprintf2 debug "[Arith] @[pivot %a@]" Node.pp x;
        let add_if_default n norm = function
          | Some _ -> ()
          | None ->
            add_itself d n norm
        in
        add_if_default other p1 pother;
        add_if_default repr p2 prepr;
        (* x = - p' / q *)
        subst_doms d x (Polynome.mult_cst (Q.div Q.one (Q.neg q)) p')
    end;
    assert (Option.compare Polynome.compare
              (Egraph.get_dom d dom repr)
              (Egraph.get_dom d dom other) = 0)

  (** [solve_one d cl p1] asserts [cl = p1]. When [cl] has no polynomial yet,
      [p1] is substituted for [cl]; otherwise the difference with the
      existing polynomial is solved. Calls [Egraph.contradiction] when that
      difference is a non-zero constant. *)
  let solve_one d cl p1 =
    match Egraph.get_dom d dom cl with
    | None ->
      subst_doms d cl p1
    | Some p2 ->
      let diff = Polynome.sub p1 p2 in
      (* 0 = p1 - p2 = diff *)
      Debug.dprintf8 debug "[Arith] @[solve in init %a 0=(%a)-(%a)=%a@]"
        Node.pp cl Polynome.pp p1 Polynome.pp p2 Polynome.pp diff;
      begin match Polynome.extract diff with
        | Zero -> ()
        | Cst c ->
          (* 0 = cst <> 0 *)
          Debug.dprintf4 Egraph.print_contradiction
            "[LRA/Poly] Found 0 = %a when updating %a"
            Q.pp c
            Node.pp cl;
          Egraph.contradiction d
        | Var(q,x,p') ->
          (* diff = qx + p' *)
          assert ( not (Q.equal Q.zero q) );
          Debug.dprintf2 debug "[Arith] @[pivot %a@]" Node.pp x;
          (* x = - p' / q *)
          subst_doms d x (Polynome.mult_cst (Q.div Q.one (Q.neg q)) p')
      end

  (* NOTE(review): exposes the domain key so that packing this module for
     [Egraph.register_dom] ties it to [dom]; confirm the member name against
     the domain signature expected by [Egraph]. *)
  let key = dom
end

(* Register the polynomial domain implementation with the e-graph. *)
let () = Egraph.register_dom (module Th)

(** [norm d p] rebuilds [p] starting from its constant part: each node is
    replaced by the result of [Egraph.find_def] on it, and by that node's
    polynomial domain when it has one. *)
let norm d (p:Polynome.t) =
  let init = Polynome.cst p.cst in
  let step acc node coef =
    let repr = Egraph.find_def d node in
    match Egraph.get_dom d dom repr with
    | Some dom_poly -> Polynome.x_p_cy acc coef dom_poly
    | None -> Polynome.add acc (Polynome.monome coef repr)
  in
  Polynome.fold step init p

(** [assume_poly_equality d n p] asserts [n = p]: [n] is replaced by the
    result of [Egraph.find_def], [p] is normalized with [norm], and the
    resulting equality is handed to [Th.solve_one]. *)
let assume_poly_equality d n (p:Polynome.t) =
  (* Debug.dprintf4 debug "assume1: %a = %a" Node.pp n Polynome.pp p; *)
  let repr = Egraph.find_def d n in
  let normalized = norm d p in
  (* Debug.dprintf4 debug "assume2: %a = %a" Node.pp repr Polynome.pp normalized; *)
  Th.solve_one d repr normalized

(** {2 Initialization} *)

(** [converter d f] handles ground terms built with [Expr.Add], [Expr.Sub]
    or [Expr.Minus]: the arguments are registered in the e-graph and the node
    of [f] is asserted equal to the corresponding linear polynomial. Other
    terms are ignored. *)
let converter d (f:Ground.t) =
  let res = Ground.node f in
  (* Register every argument (in list order), then assert
     [res = sum of coef * node] with a zero constant. *)
  let equate terms =
    List.iter (fun (n, _) -> Egraph.register d n) terms;
    assume_poly_equality d res (Polynome.of_list Q.zero terms)
  in
  match Ground.sem f with
  | { app = {builtin = Expr.Add}; tyargs = []; args; _ } ->
    let a,b = IArray.extract2_exn args in
    equate [a,Q.one; b,Q.one]
  | { app = {builtin = Expr.Sub}; tyargs = []; args; _ } ->
    let a,b = IArray.extract2_exn args in
    equate [a,Q.one; b,Q.minus_one]
  | { app = {builtin = Expr.Minus}; tyargs = []; args; _ } ->
    let a = IArray.extract1_exn args in
    equate [a,Q.minus_one]
  | _ -> ()

(** [init env] installs the daemon ["RealValueToDomPoly"], which asserts that
    a node carrying a [RealValue] equals the constant polynomial of that
    value, and registers [converter] for ground terms. *)
let init env =
  let on_real_value d value =
    let cst = RealValue.value value in
    let node = RealValue.node value in
    assume_poly_equality d node (Polynome.of_list cst [])
  in
  Demon.Fast.register_init_daemon_value
    ~name:"RealValueToDomPoly"
    (module RealValue)
    on_real_value
    env;
  Ground.register_converter env converter