(*  This file is part of Colibrics.                                      *)
(*                                                                       *)
(*  Copyright (C) 2017                                                   *)
(*    CEA   (Commissariat à l'énergie atomique et aux énergies           *)
(*           alternatives)                                               *)
(*                                                                       *)
(*  you can redistribute it and/or modify it under the terms of the GNU  *)
(*  Lesser General Public License as published by the Free Software      *)
(*  Foundation, version 2.1.                                             *)
(*                                                                       *)
(*  It is distributed in the hope that it will be useful,                *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of       *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the        *)
(*  GNU Lesser General Public License for more details.                  *)
(*                                                                       *)
(*  See the GNU Lesser General Public License version 2.1                *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).           *)

open Colibri2_popop_lib
open Colibri2_core
open Colibri2_stdlib.Std

(** Debug flag guarding the trace messages of the polynomial domain.
    [Debug.register_info_flag] takes the flag name as its positional
    argument; without it [debug] would be a function and every
    [Debug.dprintf* debug] call in this file would be ill-typed. *)
let debug = Debug.register_info_flag
    ~desc:"for the arithmetic theory of polynome"
    "arith_polynome"

(** Domain key attaching a [Polynome.t] to a node, registered under the
    name ["ARITH_POLY"]. *)
let dom_poly =
  let module Dom = struct
    type t = Polynome.t
    let name = "ARITH_POLY"
  end in
  DomKind.create_key (module Dom)

(** Polynomials used as theory terms: [Polynome] together with the
    theory-term kind key named ["SARITH_POLY"]. *)
module T = struct
  include Polynome
  let key = ThTermKind.create_key (module struct type nonrec t = t let name = "SARITH_POLY" end)
end

(** [T] registered as a theory-term kind; [set_poly] uses its [index] and
    [thterm] functions. *)
module ThE = ThTermKind.Register(T)

(** Reverse index: for a node [x], the bag of nodes whose polynomial domain
    mentions [x]. [add_used] fills it and [subst_doms] reads it. *)
let used_in_poly : Node.t Bag.t Node.HC.t =
  let pp_users = Bag.pp Pp.semi Node.pp in
  Node.HC.create pp_users "used_in_poly"

(** [set_poly d cl p] sets [p] as the polynomial domain of [cl].
    - If [p] is a single node [cl'] (as reported by [Polynome.is_one_node]),
      [cl] is merged with [cl'].
    - Otherwise [cl] is given the theory term built from [p]. *)
let set_poly d cl p =
  Egraph.set_dom d dom_poly cl p;
  match Polynome.is_one_node p with
  | Some cl' -> Egraph.merge d cl cl'
  | None ->
    let term = ThE.thterm (ThE.index p) in
    Egraph.set_thterm d cl term

(** [add_used d cl' new_cl] records, in [used_in_poly], that the polynomial
    of [cl'] mentions every node [used] that is a key of [new_cl].

    The first time a node [used] is recorded:
    - if it has no polynomial domain yet, its domain is set to the monome
      [1*used];
    - if it already has a domain, that domain is asserted to be exactly
      [1*used]. *)
let add_used d cl' new_cl =
  Node.M.iter (fun used _ ->
      Node.HC.change (function
          | Some b -> Some (Bag.append b cl')
          | None ->
            let self = Polynome.monome Q.one used in
            begin match Egraph.get_dom d dom_poly used with
              | None -> Egraph.set_dom d dom_poly used self
              | Some p ->
                (* Trace under the debug flag instead of an unconditional
                   stderr print. *)
                Debug.dprintf4 debug "used=%a p=%a" Node.pp used Polynome.pp p;
                assert (Polynome.equal self p)
            end;
            Some (Bag.elt cl')
        ) used_in_poly d used
    ) new_cl

(** [subst_doms d cl p] replaces the node [cl] by the polynomial [p].

    For every node [cl'] recorded in [used_in_poly] as using [cl]:
    - [cl] is substituted by [p] in the domain of [cl'] via [Polynome.subst];
    - the nodes newly introduced by [p] are recorded as used by [cl'];
    - the result is stored with [set_poly].

    Finally, the nodes of [p] are recorded as used by [cl], and [p] is set as
    the polynomial of [cl] itself. *)
let subst_doms d cl (p:Polynome.t) =
  let users = match Node.HC.find used_in_poly d cl with
    | exception Not_found -> Bag.empty
    | b -> b
  in
  Bag.iter (fun cl' ->
      match Egraph.get_dom d dom_poly cl' with
      | None -> assert false (* absurd: can't be used and absent *)
      | Some q ->
        (* presumably the nodes of [p] that [q] did not already mention *)
        let new_cl = Node.M.set_diff p.poly q.poly in
        let q, _ = Polynome.subst q cl p in
        add_used d cl' new_cl;
        set_poly d cl' q
    ) users;
  add_used d cl p.poly;
  set_poly d cl p

(** Handler of the [dom_poly] domain. When two nodes are merged, the
    equation between their polynomials is solved and the solution is
    propagated to every domain that mentions the pivot. *)
module Th = struct
  include Polynome

  (** [merged v1 v2] is [true] when both domains are absent, or when both
      are present and [equal]. *)
  let merged v1 v2 =
    match v1,v2 with
    | None, None -> true
    | Some v', Some v -> equal v' v
    | _ -> false

  (** [norm_dom cl dom] is the polynomial standing for [cl]: [dom] when
      present, otherwise the monome [1*cl]. *)
  let norm_dom cl = function
    | None -> monome Q.one cl
    | Some p -> p

  (** [add_itself d cl norm] records the nodes of [norm] as used by [cl]
      and sets [norm] as the polynomial domain of [cl]. *)
  let add_itself d cl norm =
    add_used d cl norm.poly;
    Egraph.set_dom d dom_poly cl norm

  (** [pivot_subst d q x p'] solves [0 = q*x + p'] for the pivot [x]
      (requires [q <> 0]) and substitutes [x := (1 / -q) * p'] in every
      domain. *)
  let pivot_subst d q x p' =
    assert (not (Q.equal Q.zero q));
    Debug.dprintf2 debug "[Arith] @[pivot %a@]" Node.pp x;
    subst_doms d x (Polynome.mult_cst (Q.div Q.one (Q.neg q)) p')

  (** [merge d (p1o, cl1) (p2o, cl2) inv] is called when [cl1] and [cl2]
      are merged. Each node comes with its optional domain.
      - [inv] selects the representative: [a1] when [inv], [a2] otherwise.
      - Let [p1] and [p2] be the polynomials of the other node and of the
        representative.
      - If [p1 - p2] is zero, the node lacking a domain receives the
        other's.
      - If it is a non-zero constant, [Egraph.contradiction] is called.
      - Otherwise a pivot is eliminated from every domain. *)
  let merge d ((p1o,cl1) as a1) ((p2o,cl2) as a2) inv =
    assert (not (Egraph.is_equal d cl1 cl2));
    assert (not (CCOpt.is_none p1o && CCOpt.is_none p2o));
    let (pother, other), (prepr, repr) = if inv then a2,a1 else a1,a2 in
    let other = Egraph.find d other in
    let repr = Egraph.find d repr in
    let p1 = norm_dom other pother in
    let p2 = norm_dom repr prepr in
    let diff = sub p1 p2 in
    (* 0 = other - repr = p1 - p2 = diff *)
    Debug.dprintf2 debug "[Arith] @[solve 0=%a@]" pp diff;
    begin match Polynome.extract diff with
      | Zero ->
        (* no new equality: already equal *)
        begin match pother, prepr with
          | Some _, Some _ | None, None ->
            assert false (* absurd: no need of merge *)
          | Some p, None ->
            (* p = repr *)
            add_itself d repr p
          | None, Some p ->
            (* p = other *)
            add_itself d other p
        end
      | Cst _ ->
        (* 0 = cst <> 0 *)
        Egraph.contradiction d
      | Var (q, x, p') ->
        (* diff = qx + p'; nodes without a domain first get their own
           monome so that the substitution reaches them. *)
        let add_if_default n norm = function
          | Some _ -> ()
          | None -> add_itself d n norm
        in
        add_if_default other p1 pother;
        add_if_default repr p2 prepr;
        pivot_subst d q x p'
    end;
    (* Afterwards both representatives must carry the same domain. *)
    assert (CCOpt.equal Polynome.equal
              (Egraph.get_dom d dom_poly repr)
              (Egraph.get_dom d dom_poly other))

  (** [solve_one d cl p1] records [cl = p1].
      - If [cl] has no domain, [cl] is substituted by [p1].
      - Otherwise the difference with its current domain [p2] is solved:
        nothing is done if it is zero, a non-zero constant is a
        contradiction, and otherwise a pivot is eliminated. *)
  let solve_one d cl p1 =
    match Egraph.get_dom d dom_poly cl with
    | None ->
      subst_doms d cl p1
    | Some p2 ->
      let diff = Polynome.sub p1 p2 in
      (* 0 = p1 - p2 = diff *)
      Debug.dprintf8 debug "[Arith] @[solve in init %a 0=(%a)-(%a)=%a@]"
        Node.pp cl Polynome.pp p1 Polynome.pp p2 Polynome.pp diff;
      begin match Polynome.extract diff with
        | Zero -> ()
        | Cst _ ->
          (* 0 = cst <> 0 *)
          Egraph.contradiction d
        | Var (q, x, p') ->
          (* diff = qx + p' *)
          pivot_subst d q x p'
      end

  let key = dom_poly
end

(* Register [Th] with the e-graph as the handler of the polynomial domain
   ([Th.key] is [dom_poly]). *)
let () = Egraph.register_dom(module Th)

(** [assume_poly_equality d n p] records the equality [n = p].

    Before solving with [Th.solve_one], both sides are normalised:
    - [n] is replaced by its representative;
    - each monome [c*cl] of [p] is replaced by [c] times the polynomial
      domain of the representative of [cl], or by [c*repr] when that
      representative has no domain. *)
let assume_poly_equality d n (p:Polynome.t) =
  Debug.dprintf4 debug "assume1: %a = %a" Node.pp n Polynome.pp p;
  let n = Egraph.find d n in
  let add acc cl c =
    let cl = Egraph.find d cl in
    match Egraph.get_dom d dom_poly cl with
    | None -> Polynome.add acc (Polynome.monome c cl)
    | Some p -> Polynome.x_p_cy acc c p
  in
  (* The accumulator starts from the constant part of [p]; this assumes
     [Polynome.fold] visits only the monomes — TODO confirm. *)
  let p = Polynome.fold add (Polynome.cst p.cst) p in
  Debug.dprintf4 debug "assume2: %a = %a" Node.pp n Polynome.pp p;
  Th.solve_one d n p

(** [solve_one d cl p] records [cl = p] in the polynomial domain by
    delegating to [Th.solve_one]. *)
let solve_one d cl p = Th.solve_one d cl p