(*************************************************************************)
(*  This file is part of Colibri2.                                       *)
(*                                                                       *)
(*  Copyright (C) 2014-2021                                              *)
(*    CEA   (Commissariat à l'énergie atomique et aux énergies           *)
(*           alternatives)                                               *)
(*                                                                       *)
(*  you can redistribute it and/or modify it under the terms of the GNU  *)
(*  Lesser General Public License as published by the Free Software      *)
(*  Foundation, version 2.1.                                             *)
(*                                                                       *)
(*  It is distributed in the hope that it will be useful,                *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of       *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the        *)
(*  GNU Lesser General Public License for more details.                  *)
(*                                                                       *)
(*  See the GNU Lesser General Public License version 2.1                *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).           *)
(*************************************************************************)

(** Debug flag guarding the trace messages printed by the pivoting code of
    this file. *)
let debug =
  Debug.register_info_flag "LRA.pivot"
    ~desc:"for the normalization by pivoting"

(** Result of [P.solve] as consumed by the [WithUnsolved] functor below. *)
type 'a solve_with_unsolved =
  | AlreadyEqual  (** Nothing is done by the functor for this pair. *)
  | Contradiction  (** The functor calls [Egraph.contradiction]. *)
  | Unsolved
      (** Nothing is done for this pair; the expressions stay side by side in
          the domain. *)
  | Subst of 'a Node.M.t
      (** Map from nodes to expressions: one binding is applied as a
          substitution, the remaining ones are added as new equalities. *)

module WithUnsolved (P : sig
  (** Expressions over nodes stored in the domain (representative and pending
      equalities). *)
  type t

  (** Name used for the domain kind, for the [used_in_poly] table and in the
      debug traces. *)
  val name : string

  (* Provides, among others, [pp], [equal] and the set module [S] used below. *)
  include Colibri2_popop_lib.Popop_stdlib.Datatype with type t := t

  (** Expression used as the default representative of a node that has no
      domain yet. *)
  val of_one_node : _ Egraph.t -> Node.t -> t

  (** Presumably [Some n] exactly when the expression is the one built by
      [of_one_node] for [n] — TODO confirm; the functor uses it to check that
      a representative is the node itself. *)
  val is_one_node : t -> Node.t option

  (** [subst q n p] is called to replace [n] by [p] in [q]; [None] is treated
      by the functor as "[q] is unchanged". *)
  val subst : t -> Node.t -> t -> t option

  (** [normalize p ~f] is called with an [f] that maps a node to the current
      representative of its class (or to [of_one_node] when it has none). *)
  val normalize : t -> f:(Node.t -> t) -> t

  (** Data attached to each node of an expression; ignored by the functor. *)
  type data

  (** Nodes occurring in an expression; only the keys are used here. *)
  val nodes : t -> data Node.M.t

  (** Information computed for an expression and handed to [solve]. *)
  type info

  val info : _ Egraph.t -> t -> info

  (** Called once by [init]. The callback receives a node (presumably one
      whose information changed — TODO confirm) and returns the event used to
      re-solve the domains that use it. *)
  val attach_info_change :
    Egraph.wt -> (Egraph.rt -> Node.t -> Events.enqueue) -> unit

  (** Tries to find a substitution from two expressions of the same domain;
      the functor calls it with both argument orders. *)
  val solve : info -> info -> t solve_with_unsolved

  (** Hook called after the domain of a node has been set, once per changed
      expression; it is called with [old_ = new_] for a newly added
      expression. *)
  val set : Egraph.wt -> Node.t -> old_:t -> new_:t -> unit
end) : sig
  (** [assume_equality d n p] normalizes [p] and adds it as an expression equal
      to the class of [n], pivoting when [P.solve] finds a substitution. *)
  val assume_equality : Egraph.wt -> Node.t -> P.t -> unit

  (** Installs the information-change hook through [P.attach_info_change]. *)
  val init : Egraph.wt -> unit

  (** Representative expression of the domain attached to the node, if any. *)
  val get_repr : _ Egraph.t -> Node.t -> P.t option

  (** Applies [f] to the representative and then to every pending equality of
      the domain of the node; does nothing when the node has no domain. *)
  val iter_eqs : _ Egraph.t -> Node.t -> f:(P.t -> unit) -> unit

  (** Same function as [attach_eqs_change]: the callback is attached to the
      whole domain, not only to its representative. *)
  val attach_repr_change :
    _ Egraph.t -> ?node:Node.t -> (Egraph.wt -> Node.t -> unit) -> unit

  (** Attaches a daemon callback to the domain of [node], or to the domain of
      any node when [node] is omitted. *)
  val attach_eqs_change :
    _ Egraph.t -> ?node:Node.t -> (Egraph.wt -> Node.t -> unit) -> unit

  (** [reshape d cl ~f] rewrites with [f] every expression of every domain
      recorded as using [cl] ([None] keeps the expression) and re-runs the
      solver on the rewritten domains. *)
  val reshape : Egraph.wt -> Node.t -> f:(P.t -> P.t option) -> unit
end = struct
  (** Value of the domain: a representative expression plus the expressions
      known to be equal to it that could not be used for a pivot. *)
  type t = {
    repr : P.t;  (** Expression returned by [get_repr]. *)
    eqs : P.S.t;
        (** When a pivot is not possible because the equality can be null, the
            other products wait here on the side; they are also normalized. *)
  }
  [@@deriving ord, eq]

  (** Prints the representative, a comma, then the set of pending equalities. *)
  let pp fmt { repr; eqs } = Fmt.pf fmt "%a,%a" P.pp repr P.S.pp eqs

  (** Domain kind holding values of type [t], named after [P.name]. *)
  let dom =
    let module Named = struct
      type nonrec t = t

      let name = P.name
    end in
    Dom.Kind.create (module Named)

  (** For a node, the set of nodes whose domain has been recorded as
      mentioning it (see [add_used_product]). *)
  let used_in_poly : Node.S.t Node.HC.t =
    let table_name = P.name ^ "_used_in_poly" in
    Node.HC.create Node.S.pp table_name

  (** [set_poly d cl p chg] stores [p] as the domain of [cl], then reports each
      [(old, new)] pair of [chg] to [P.set]. *)
  let set_poly d cl p chg =
    Egraph.set_dom d dom cl p;
    let notify (old_, new_) = P.set d cl ~old_ ~new_ in
    List.iter notify chg

  (** [add_used_product d cl' new_cls] records, for every key [used] of
      [new_cls], that the domain of [cl'] mentions [used]. *)
  let add_used_product d cl' new_cls =
    let register used =
      let update = function
        | Some users -> Some (Node.S.add cl' users)
        | None -> (
            match Egraph.get_dom d dom used with
            | Some p ->
                (* A node with a domain but no entry in the table is not
                   expected here: this branch always fails. *)
                assert (
                  Option.equal Node.equal (P.is_one_node p.repr) (Some used));
                assert false
            | None ->
                (* If a used node has no polynomial associated, we set it to
                   itself. This allows us to be warned when this node is
                   merged. It is the reason why this module doesn't
                   specifically wait for representative changes. *)
                Egraph.set_dom d dom used
                  { repr = P.of_one_node d used; eqs = P.S.empty };
                Some (Node.S.of_list [ cl'; used ]))
      in
      Node.HC.change update used_in_poly d used
    in
    Node.M.iter (fun used _ -> register used) new_cls

  (** Records [cl'] as a user of every node occurring in the representative and
      in the pending equalities of the given domain value. *)
  let add_used_t d cl' { repr; eqs } =
    let register p = add_used_product d cl' (P.nodes p) in
    register repr;
    P.S.iter register eqs

  (** Normalizes [p] with [P.normalize], mapping each node to the
      representative stored in the domain of its class, or to
      [P.of_one_node] when the class has no domain. *)
  let norm_product d p =
    let current_repr cl =
      let cl = Egraph.find d cl in
      match Egraph.get_dom d dom cl with
      | Some { repr; _ } -> repr
      | None -> P.of_one_node d cl
    in
    P.normalize p ~f:current_repr

  (** Returns the given domain value, or the default one ([cl] alone as
      representative, no pending equality) when there is none. *)
  let norm_dom d cl po =
    match po with
    | Some p -> p
    | None -> { repr = P.of_one_node d cl; eqs = P.S.empty }

  module Th = struct
    (** Two optional domain values are merged when they are physically the
        same, both absent, or both present and equal. *)
    let merged v1 v2 =
      if Base.phys_equal v1 v2 then true
      else
        match (v1, v2) with
        | Some a, Some b -> equal a b
        | None, None -> true
        | Some _, None | None, Some _ -> false

    (** Registers [cl] as a user of all the nodes of [norm], then stores [norm]
        as the domain of [cl]. *)
    let add_itself d cl ({ repr; eqs } as norm) =
      add_used_product d cl (P.nodes repr);
      P.S.iter (fun eq -> add_used_product d cl (P.nodes eq)) eqs;
      Egraph.set_dom d dom cl norm

    (** Merge entry point of the domain: the domain values and the inversion
        flag passed in are ignored, the current domains of the two classes are
        read again by [merge_aux]. *)
    let rec merge d (_, n1) (_, n2) _inv =
      let r1 = Egraph.find d n1 in
      let r2 = Egraph.find d n2 in
      assert (not (Egraph.is_equal d r1 r2));
      merge_aux d r1 r2

    (** Makes the domains of [cl1] and [cl2] identical. When only one has a
        domain it is copied to the other; otherwise every pair of expressions
        is given to [solve], and the function restarts as long as a
        substitution is found. *)
    and merge_aux d cl1 cl2 =
      let dom1 = Egraph.get_dom d dom cl1 in
      let dom2 = Egraph.get_dom d dom cl2 in
      assert (Option.is_some dom1 || Option.is_some dom2);
      match (dom1, dom2) with
      | None, None -> assert false (* absurd: no need to merge *)
      | Some p, None ->
          assert (Option.is_none (Node.HC.find_opt used_in_poly d cl2));
          add_itself d cl2 p
      | None, Some p ->
          assert (Option.is_none (Node.HC.find_opt used_in_poly d cl1));
          add_itself d cl1 p
      | Some p1, Some p2 -> (
          let infos1 = part d (p1.repr :: P.S.elements p1.eqs) in
          let infos2 = part d (p2.repr :: P.S.elements p2.eqs) in
          match solve d (Base.List.cartesian_product infos1 infos2) with
          | `Solved ->
              (* The domains have been substituted, possibly recursively:
                 start again from the new domains. *)
              merge_aux d cl1 cl2
          | `Not_solved ->
              (* Nothing to solve: keep every expression, preferring a
                 representative that is a single node. *)
              let repr =
                match (P.is_one_node p1.repr, P.is_one_node p2.repr) with
                | None, Some _ -> p2.repr
                | Some n1, Some n2 ->
                    assert (Node.equal n1 n2);
                    p1.repr
                | (None | Some _), None -> p1.repr (* arbitrary if both None *)
              in
              let all = P.S.add p1.repr p1.eqs in
              let all = P.S.union p2.eqs all in
              let all = P.S.add p2.repr all in
              let eqs = P.S.remove repr all in
              let merged_dom = { repr; eqs } in
              Egraph.set_dom d dom cl1 merged_dom;
              Egraph.set_dom d dom cl2 merged_dom)

    (** [merge_one_new_eq d cl eq] adds the normalized [eq] to the domain of
        [cl]. When [cl] has no domain and does not occur in [eq], [eq] simply
        becomes its representative. Otherwise [eq] is solved against the
        existing expressions, restarting after each substitution, and is kept
        as a pending equality when nothing can be solved. *)
    and merge_one_new_eq d cl eq =
      let eq = norm_product d eq in
      let po = Egraph.get_dom d dom cl in
      let unconstrained =
        Option.is_none po && not (Node.M.mem cl (P.nodes eq))
      in
      if unconstrained then (
        add_used_product d cl (P.nodes eq);
        set_poly d cl { repr = eq; eqs = P.S.empty } [ (eq, eq) ])
      else
        let p = norm_dom d cl po in
        let already_known = P.S.mem eq p.eqs || P.equal eq p.repr in
        if not already_known then (
          let known = part d (p.repr :: P.S.elements p.eqs) in
          let fresh = part d [ eq ] in
          match solve d (Base.List.cartesian_product known fresh) with
          | `Solved ->
              (* The domains have been substituted, possibly recursively *)
              merge_one_new_eq d cl eq
          | `Not_solved ->
              (* nothing to solve: keep [eq] on the side *)
              let eqs = P.S.remove p.repr (P.S.add eq p.eqs) in
              add_used_product d cl (P.nodes eq);
              set_poly d cl { p with eqs } [ (eq, eq) ])

    (** [subst d cl eq] makes [eq] the value of [cl]: directly when [cl] has no
        domain, otherwise by substituting [cl] with [eq] in every domain that
        uses it (the representative of [cl] must then be [cl] itself). *)
    and subst d cl eq =
      Debug.dprintf5 debug "[Pivot:%s] subst %a with %a" P.name Node.pp cl P.pp
        eq;
      match Egraph.get_dom d dom cl with
      | Some { repr; _ } ->
          assert (Option.equal Node.equal (P.is_one_node repr) (Some cl));
          subst_doms d cl eq
      | None ->
          add_used_product d cl (P.nodes eq);
          set_poly d cl { repr = eq; eqs = P.S.empty } [ (eq, eq) ]

    (** [subst_doms d cl p] applies [P.subst _ cl p] to every expression of
        every domain recorded as using [cl], registers the nodes of [p] that
        are new for these domains, and finally rechecks the domains in which
        at least one expression changed. *)
    and subst_doms d cl p =
      let users =
        match Node.HC.find_opt used_in_poly d cl with
        | Some s -> s
        | None -> Node.S.empty
      in
      let touched = Node.H.create 10 in
      let substitute_in cl' =
        match Egraph.get_dom d dom cl' with
        | None -> assert false (* absurd: can't be used and absent *)
        | Some q ->
            (* Accumulates the newly used nodes, the rewritten expressions and
               the (old, new) pairs to report. *)
            let step (fresh, seen, changes) (prod : P.t) =
              let fresh =
                Node.M.set_union fresh
                  (Node.M.set_diff (P.nodes p) (P.nodes prod))
              in
              match P.subst prod cl p with
              | Some prod' ->
                  Node.H.replace touched cl' ();
                  (fresh, P.S.add prod' seen, (prod, prod') :: changes)
              | None -> (fresh, P.S.add prod seen, changes)
            in
            let fresh, seen, changes =
              step (Node.M.empty, P.S.empty, []) q.repr
            in
            (* [seen] holds only the rewritten representative at this point *)
            let repr = P.S.choose seen in
            let fresh, seen, changes =
              P.S.fold_left step (fresh, seen, changes) q.eqs
            in
            add_used_product d cl' fresh;
            set_poly d cl' { repr; eqs = P.S.remove repr seen } changes
      in
      Node.S.iter substitute_in users;
      Node.H.iter (recheck d) touched

    (** Computes [P.info] for each expression of the list. *)
    and part d l = List.map (P.info d) l

    (** [solve d l] gives each pair of [l] to [P.solve], in both orders, and
        stops at the first substitution found: one of its bindings is applied
        with [subst], the others are added with [merge_one_new_eq]. Returns
        [`Solved] in that case and [`Not_solved] when no pair yields a
        substitution. A [Contradiction] answer calls [Egraph.contradiction]. *)
    and solve d l =
      let exception Solved of P.t Node.M.t in
      let try_pivot i1 i2 =
        match P.solve i1 i2 with
        | Subst m -> raise (Solved m)
        | Contradiction -> Egraph.contradiction d
        | AlreadyEqual | Unsolved -> ()
      in
      let check_pair (a, b) =
        try_pivot a b;
        try_pivot b a
      in
      try
        List.iter check_pair l;
        `Not_solved
      with Solved m ->
        let n, p = Node.M.choose m in
        subst d n p;
        Node.M.iter (fun n p -> merge_one_new_eq d n p) (Node.M.remove n m);
        `Solved

    (** [recheck d n ()] runs [solve] on every pair of expressions of the
        domain of [n] and starts again as long as a substitution is found.
        The unit argument lets it be used directly with [Node.H.iter]. *)
    and recheck d n () =
      match Egraph.get_dom d dom n with
      | None -> assert false (* absurd: can't be used and absent *)
      | Some p -> (
          (* The information of each expression is computed once and paired
             with itself, as done in [reshape] and [ChangeInfo.run]. *)
          let infos = part d (p.repr :: P.S.elements p.eqs) in
          match solve d (Base.List.cartesian_product infos infos) with
          | `Solved -> recheck d n ()
          | `Not_solved -> ())

    (* Items re-exported for [Dom.register (module Th)]: the domain key, the
       type of its values and their printer. *)
    let key = dom

    type nonrec t = t

    let pp = pp
  end

  (* Registers the domain when the functor is applied; presumably this is what
     makes the egraph call [Th.merge] on class merges — confirm in [Dom]. *)
  let () = Dom.register (module Th)

  (** Representative expression stored in the domain of [n], if [n] has one. *)
  let get_repr d n =
    match Egraph.get_dom d dom n with
    | Some { repr; _ } -> Some repr
    | None -> None

  (** Calls [f] on the representative, then on each pending equality of the
      domain of [n]; does nothing when [n] has no domain. *)
  let iter_eqs d n ~f =
    Egraph.get_dom d dom n
    |> Option.iter (fun { repr; eqs } ->
           f repr;
           P.S.iter f eqs)

  (** [assume_equality d n p] traces the request, then adds [p] as a new
      equality of the class found for [n]. *)
  let assume_equality d n (p : P.t) =
    Debug.dprintf5 debug "[Pivot %s] assume %a = %a" P.name Node.pp n P.pp p;
    Th.merge_one_new_eq d (Egraph.find d n) p

  (** [reshape d cl ~f] rewrites with [f] every expression of every domain
      recorded as using [cl] ([None] keeps the expression as is), stores the
      result, runs the solver once on it, and finally rechecks the domains in
      which [f] changed something. Does nothing when no domain uses [cl].

      NOTE(review): nodes introduced by [f] are not registered in
      [used_in_poly] here, and [P.set] is not called — confirm that callers
      do not rely on either. *)
  let reshape d cl ~(f : P.t -> P.t option) =
    match Node.HC.find_opt used_in_poly d cl with
    | None -> ()
    | Some users ->
        let touched = Node.H.create 10 in
        let reshape_user cl' =
          match Egraph.get_dom d dom cl' with
          | None -> assert false (* absurd: can't be used and absent *)
          | Some q ->
              let rewrite p =
                match f p with
                | Some p' ->
                    Node.H.replace touched cl' ();
                    p'
                | None -> p
              in
              (* The pending equalities are rewritten before the
                 representative. *)
              let eqs =
                P.S.fold (fun p acc -> P.S.add (rewrite p) acc) q.eqs P.S.empty
              in
              let repr = rewrite q.repr in
              let q' = { repr; eqs } in
              Egraph.set_dom d dom cl' q';
              let infos = Th.part d (q'.repr :: P.S.elements q'.eqs) in
              ignore (Th.solve d (Base.List.cartesian_product infos infos))
        in
        Node.S.iter reshape_user users;
        Node.H.iter (Th.recheck d) touched

  (** Event re-running the solver on the domains of a set of nodes; it is
      enqueued from the callback given to [P.attach_info_change]. *)
  module ChangeInfo = struct
    (** Nodes whose domain must be solved again. *)
    type runable = Node.S.t

    let print_runable = Node.S.pp

    (** Runs [Th.solve] once on all the pairs of expressions of the domain of
        each node; every node is expected to have a domain. *)
    let run d ns =
      let resolve n =
        let { repr; eqs } = Base.Option.value_exn (Egraph.get_dom d dom n) in
        let infos = Th.part d (repr :: P.S.elements eqs) in
        ignore (Th.solve d (Base.List.cartesian_product infos infos))
      in
      Node.S.iter resolve ns

    let delay = Events.Delayed_by 10

    let key =
      let module Runable = struct
        type t = Node.S.t

        let name = "Dom_product.ChangePos"
      end in
      Events.Dem.create (module Runable)

    (** Enqueues, for the node given to the callback, the set of nodes recorded
        as using it; nothing is enqueued when no domain uses it. *)
    let init d =
      P.attach_info_change d (fun d n ->
          match Node.HC.find_opt used_in_poly d n with
          | None -> Events.EnqAlready
          | Some ns -> Events.EnqRun (key, ns, None))
  end

  (* Registers the [ChangeInfo] event handler when the functor is applied. *)
  let () = Events.register (module ChangeInfo)
  (** Installs the information-change callback; see [ChangeInfo.init]. *)
  let init = ChangeInfo.init

  (** Attaches the daemon callback [f] to the domain of [node], or to the
      domain of any node when [node] is not given. *)
  let attach_eqs_change d ?node f =
    match node with
    | None -> Daemon.attach_any_dom d dom f
    | Some x -> Daemon.attach_dom d x dom f

  (** Same as [attach_eqs_change]: the representative and the pending
      equalities live in the same domain, so both are watched together. *)
  let attach_repr_change d ?node f = attach_eqs_change d ?node f
end

(** Result of [P.solve] as consumed by the [Total] functor below:
    [AlreadyEqual] adds no substitution, [Contradiction] makes the functor
    call [Egraph.contradiction], and [Subst (x, p)] makes it substitute the
    node [x] by [p] in every domain using [x]. *)
type 'a solve_total = AlreadyEqual | Contradiction | Subst of Node.t * 'a

module Total (P : sig
  (** Expressions over nodes; one such expression is the whole domain value
      of a node. *)
  type t

  (** Name used for the domain kind. *)
  val name : string

  (* Provides, among others, [equal] and [compare] used below. *)
  include Colibri2_popop_lib.Popop_stdlib.Datatype with type t := t

  (** Expression used as the default domain of a node that has none. *)
  val of_one_node : Node.t -> t

  (** Not called by the functor body. *)
  val is_one_node : t -> Node.t option

  (** [subst q n p] is called to replace [n] by [p] in [q]. *)
  val subst : t -> Node.t -> t -> t

  (** [normalize p ~f] is called with an [f] that maps a node to the domain of
      its class (or to [of_one_node] when it has none). *)
  val normalize : t -> f:(Node.t -> t) -> t

  (** Data attached to each node of an expression; ignored by the functor. *)
  type data

  (** Nodes occurring in an expression; only the keys are used here. *)
  val nodes : t -> data Node.M.t

  (** Solves the equality between two expressions. *)
  val solve : t -> t -> t solve_total

  (** Hook called after the domain of a node has been set; [old_] is [None]
      when called for the substituted node itself. *)
  val set : Egraph.wt -> Node.t -> old_:t option -> new_:t -> unit
end) : sig
  (** [assume_equality d n p] normalizes [p] and solves it against the domain
      of the class of [n], substituting when a pivot is found. *)
  val assume_equality : Egraph.wt -> Node.t -> P.t -> unit

  (** Does nothing in this functor. *)
  val init : Egraph.wt -> unit

  (** Domain value attached to the node, if any. *)
  val get_repr : _ Egraph.t -> Node.t -> P.t option

  (** Attaches a daemon callback to the domain of [node], or to the domain of
      any node when [node] is omitted. *)
  val attach_repr_change :
    _ Egraph.t -> ?node:Node.t -> (Egraph.wt -> Node.t -> unit) -> unit

  (** Same as [attach_repr_change] but through [Events], with a callback that
      returns the event to enqueue. *)
  val events_repr_change :
    _ Egraph.t ->
    ?node:Node.t ->
    (Egraph.rt -> Node.t -> Events.enqueue) ->
    unit

  (** Normalizes an expression by mapping each node to the domain of its
      class, or to [P.of_one_node] when the class has no domain. *)
  val normalize : _ Egraph.t -> P.t -> P.t
end = struct
  open Colibri2_popop_lib

  (** Domain kind holding values of type [P.t], named after [P.name]. *)
  let dom =
    let module Named = struct
      type t = P.t

      let name = P.name
    end in
    Dom.Kind.create (module Named)

  (** For a node, the bag of nodes whose domain has been recorded as
      mentioning it (see [add_used]). The table name is prefixed with
      [P.name], as done in [WithUnsolved], so that the tables of different
      instances of the functor can be told apart. *)
  let used_in_poly : Node.t Bag.t Node.HC.t =
    Node.HC.create (Bag.pp Node.pp) (P.name ^ "_used_in_poly")

  (** Stores [new_] as the domain of [cl], then reports the change to
      [P.set]. *)
  let set_poly d cl old_ new_ =
    let () = Egraph.set_dom d dom cl new_ in
    P.set d cl ~old_ ~new_

  (** [add_used d cl' new_cl] records, for every key [used] of [new_cl], that
      the domain of [cl'] mentions [used]. *)
  let add_used d cl' new_cl =
    let register used =
      let update = function
        | Some users -> Some (Bag.append users cl')
        | None ->
            (match Egraph.get_dom d dom used with
            | Some p -> assert (P.equal (P.of_one_node used) p)
            | None ->
                (* If a used node has no polynomial associated, we set it to
                   itself. This allows us to be warned when this node is
                   merged. It is the reason why this module doesn't
                   specifically wait for representative changes. *)
                Egraph.set_dom d dom used (P.of_one_node used));
            Some (Bag.elt cl')
      in
      Node.HC.change update used_in_poly d used
    in
    Node.M.iter (fun used _ -> register used) new_cl

  (** [subst_doms d cl p] replaces [cl] by [p] in the domain of every node
      recorded as using [cl], registering the nodes of [p] that are new for
      each of them, and finally sets the domain of [cl] itself to [p]. *)
  let subst_doms d cl (p : P.t) =
    let users =
      match Node.HC.find_opt used_in_poly d cl with
      | Some b -> b
      | None -> Bag.empty
    in
    let substitute_in cl' =
      match Egraph.get_dom d dom cl' with
      | None -> assert false (* absurd: can't be used and absent *)
      | Some q ->
          let fresh = Node.M.set_diff (P.nodes p) (P.nodes q) in
          let q' = P.subst q cl p in
          add_used d cl' fresh;
          set_poly d cl' (Some q) q'
    in
    Bag.iter substitute_in users;
    add_used d cl (P.nodes p);
    set_poly d cl None p

  module Th = struct
    (* Brings the items of [P] (type [t], [equal], ...) into the module given
       to [Dom.register]. *)
    include P

    (** Two optional domain values are merged when both are absent or both are
        present and equal. *)
    let merged v1 v2 =
      match (v1, v2) with
      | Some a, Some b -> equal a b
      | None, None -> true
      | Some _, None | None, Some _ -> false

    (** Returns the given domain value, or [P.of_one_node cl] when there is
        none. *)
    let norm_dom cl po =
      match po with Some p -> p | None -> P.of_one_node cl

    (** Registers [cl] as a user of the nodes of [norm], then stores [norm] as
        the domain of [cl]. *)
    let add_itself d cl norm =
      let used = P.nodes norm in
      add_used d cl used;
      Egraph.set_dom d dom cl norm

    (** Merge function of the domain. The (defaulted) domains of the two
        classes are given to [P.solve]: when they are already equal the only
        existing domain is copied to the class lacking one; when a pivot is
        found, missing domains are first set to their default and the
        substitution is applied everywhere. On exit both classes must have
        the same domain. *)
    let merge d ((p1o, cl1) as a1) ((p2o, cl2) as a2) inv =
      assert (not (Egraph.is_equal d cl1 cl2));
      assert (Option.is_some p1o || Option.is_some p2o);
      let (pother, other), (prepr, repr) =
        match inv with true -> (a2, a1) | false -> (a1, a2)
      in
      let other = Egraph.find d other in
      let repr = Egraph.find d repr in
      let p1 = norm_dom other pother in
      let p2 = norm_dom repr prepr in
      (* Stores the default domain of a class that had none. *)
      let add_if_default n norm = function
        | Some _ -> ()
        | None -> add_itself d n norm
      in
      (match P.solve p1 p2 with
      | Contradiction -> Egraph.contradiction d
      | Subst (x, p) ->
          Debug.dprintf2 debug "[Arith] @[pivot %a@]" Node.pp x;
          add_if_default other p1 pother;
          add_if_default repr p2 prepr;
          subst_doms d x p
      | AlreadyEqual -> (
          (* no new equality: the domains are already equal *)
          match (pother, prepr) with
          | Some p, None ->
              (* p = repr *)
              add_itself d repr p
          | None, Some p ->
              (* p = other *)
              add_itself d other p
          | Some _, Some _ | None, None ->
              assert false (* absurd: no need of merge *)));
      assert (
        Option.compare P.compare
          (Egraph.get_dom d dom repr)
          (Egraph.get_dom d dom other)
        = 0)

    (** [solve_one d cl p1] makes [cl] equal to [p1]. When [cl] has no domain
        and does not occur in [p1], [cl] is directly substituted by [p1];
        otherwise [p1] is solved against the (defaulted) domain of [cl]. *)
    let solve_one d cl p1 =
      let p2o = Egraph.get_dom d dom cl in
      let must_solve = Option.is_some p2o || Node.M.mem cl (P.nodes p1) in
      if not must_solve then
        (* This case avoids a call to the solver when it is not needed *)
        subst_doms d cl p1
      else
        match P.solve p1 (norm_dom cl p2o) with
        | AlreadyEqual -> ()
        | Contradiction -> Egraph.contradiction d
        | Subst (x, p) ->
            Debug.dprintf2 debug "[Arith] @[pivot %a@]" Node.pp x;
            subst_doms d x p

    (* Domain key under which this module is registered. *)
    let key = dom
  end

  (* Registers the domain when the functor is applied; presumably this is what
     makes the egraph call [Th.merge] on class merges — confirm in [Dom]. *)
  let () = Dom.register (module Th)

  (** Normalizes [p] with [P.normalize], mapping each node to the domain of the
      class given by [Egraph.find_def], or to [P.of_one_node] when that class
      has no domain. *)
  let normalize d (p : P.t) =
    let current cl =
      let cl = Egraph.find_def d cl in
      match Egraph.get_dom d dom cl with
      | Some q -> q
      | None -> P.of_one_node cl
    in
    P.normalize p ~f:current

  (** [assume_equality d n p] solves the normalized [p] against the class
      given by [Egraph.find_def] for [n]. *)
  let assume_equality d n (p : P.t) =
    let n = Egraph.find_def d n in
    Th.solve_one d n (normalize d p)

  (** Domain value attached to a node, if any. *)
  let get_repr d = Egraph.get_dom d dom

  (** Attaches the daemon callback [f] to the domain of [node], or to the
      domain of any node when [node] is not given. *)
  let attach_repr_change d ?node f =
    match node with
    | None -> Daemon.attach_any_dom d dom f
    | Some x -> Daemon.attach_dom d x dom f

  (** Attaches the event callback [f] to the domain of [node], or to the domain
      of any node when [node] is not given. *)
  let events_repr_change d ?node f =
    match node with
    | None -> Events.attach_any_dom d dom f
    | Some x -> Events.attach_dom d x dom f

  (** Nothing to initialize for this functor. *)
  let init (_ : Egraph.wt) = ()
end