(*************************************************************************)
(*  This file is part of Colibri2.                                       *)
(*                                                                       *)
(*  Copyright (C) 2014-2021                                              *)
(*    CEA   (Commissariat à l'énergie atomique et aux énergies           *)
(*           alternatives)                                               *)
(*    OCamlPro                                                           *)
(*                                                                       *)
(*  you can redistribute it and/or modify it under the terms of the GNU  *)
(*  Lesser General Public License as published by the Free Software      *)
(*  Foundation, version 2.1.                                             *)
(*                                                                       *)
(*  It is distributed in the hope that it will be useful,                *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of       *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the        *)
(*  GNU Lesser General Public License for more details.                  *)
(*                                                                       *)
(*  See the GNU Lesser General Public License version 2.1                *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).           *)
(*************************************************************************)

open Colibri2_core
open Colibri2_popop_lib
open Common
open Colibri2_theories_quantifiers
open RWRules

(*
  Command line options:
  - no option: uses RW1 (adown), RW2 (aup), idx and extensionality
  - "--no-wegraph": don't use the weak equivalence graph
  - "--no-res-ext": don't restrict the extensionality rule using the foreign
    domain
  - "--no-res-aup": don't restrict the RW2 rule using the linearity domain
  - "--array-ext-comb": support additional combinators
    (const, map, def_ind, def_val)
  - "--array-blast-rule": use the blast rule when applicable
  - "--array-def-values": support the rules on default values
*)

(** [converter env f] is the ground-term converter of the array theory; it is
    registered with [Ground.register_converter] in [init].

    - If [f] itself has an array type, its index/value types are recorded, it
      receives an array id and it is announced with [new_array].
    - Uninterpreted symbols ([Expr.Base]): an array-typed constant gets the
      [Empty] linearity domain (unless [no_res_aup]); for an application,
      every array-typed argument is marked as foreign (unless [no_res_ext]).
    - [Select], [Store], [Array_diff], [Array_const] and [Array_map]
      applications register their arguments and set their ids and ground
      types; depending on the enabled options they also apply the [idx],
      [U𝛿], [K𝛿], [𝝐≠] or blast rules and notify the weak equivalence graph.
      [Array_map] is only handled when [extended_comb] or [default_values]
      is set.
    - Any other term is ignored. *)
let converter env (f : Ground.t) =
  let s = Ground.sem f in
  let fn = Ground.node f in
  (* When [f] has an array type: record its index/value types, give it an
     array id and declare it as a new array. *)
  let f_is_array =
    match s.ty with
    | { app = { builtin = Expr.Array; _ }; args = [ ind_gty; val_gty ]; _ } ->
        add_array_gty env fn ind_gty val_gty;
        Id.Array.set_id env fn;
        new_array env ind_gty val_gty f;
        true
    | _ -> false
  in
  match s with
  | { app = { builtin = Expr.Base; id_ty; _ }; args; tyargs; _ } ->
      if IArray.is_empty args then (
        if (not (Options.get env no_res_aup)) && f_is_array then
          (* update of the Linearity domain *)
          Linearity_dom.upd_dom env fn Empty)
      else if not (Options.get env no_res_ext) then (
        (* update of the Foreign domain: each array-typed argument of an
           uninterpreted symbol is marked as foreign, under its own type *)
        let subst, arg_tys =
          match id_ty.ty_descr with
          | Pi (tyvl, { ty_descr = Expr.Arrow (ptys, _); _ }) ->
              (* instantiate the symbol's type variables with [tyargs] *)
              ( List.fold_left2
                  (fun m k v -> Expr.Ty.Var.M.add k v m)
                  Expr.Ty.Var.M.empty tyvl tyargs,
                ptys )
          | Expr.Arrow (ptys, _) -> (Expr.Ty.Var.M.empty, ptys)
          | _ -> (Expr.Ty.Var.M.empty, [])
        in
        assert (arg_tys <> []);
        IArray.iteri
          ~f:(fun i n ->
            let gty = Ground.Ty.convert subst (List.nth arg_tys i) in
            match gty with
            | { app = { builtin = Expr.Array; _ }; _ } ->
                Id.Array.set_id env n;
                Ground.add_ty env n gty;
                Foreign_dom.set_dom env gty n IsForeign
            | _ -> ())
          args)
  | {
   app = { builtin = Expr.Select; _ };
   args;
   tyargs = [ ind_gty; val_gty ];
   _;
  } ->
      Array_value.propagate_value env f;
      let a, i = IArray.extract2_exn args in
      Egraph.register env a;
      Egraph.register env i;
      add_array_gty env a ind_gty val_gty;
      Id.Value.set_id env fn;
      Id.Array.set_id env a;
      Id.Index.set_id env i;
      Ground.add_ty env i ind_gty;
      if not (Options.get env no_wegraph) then WEGraph.new_select env fn a i;
      (* update of the Foreign domain: an array used as an index is foreign.
         The node being marked is [i], whose type is [ind_gty] (set just
         above), so that is the type passed, consistently with the foreign
         arguments of uninterpreted symbols; the type of [a] would be wrong
         here. *)
      if (not (Options.get env no_res_ext)) && ind_gty.app.builtin == Expr.Array
      then Foreign_dom.set_dom env ind_gty i IsForeign;
      if Options.get env extended_comb then (
        (* when a new read is encountered, check if map⇑ can be applied *)
        Array_dom.add_read env a i;
        (* 𝝐≠: v = a[i], i is not 𝝐 |> i ≠ 𝝐 or blast *)
        let eps_node =
          ground_apply env Builtin.array_default_index [ ind_gty; val_gty ]
            [ a ]
        in
        Egraph.register env eps_node;
        if not (Egraph.is_equal env i eps_node) then (
          (* NOTE(review): this shadows [ind_gty] (the index type of [a]) with
             the first component returned by [array_gty_args ind_gty], and the
             shadowed value is what [apply_blast_rule] receives together with
             [a]; confirm this is intended. *)
          let ind_gty, _ = array_gty_args ind_gty in
          if check_gty_num_size env ind_gty then
            (* application of the blast rule *)
            apply_blast_rule env a ind_gty val_gty
          else
            (* application of 𝝐≠ *)
            let i_eps_neq_node = Equality.disequality env [ eps_node; i ] in
            Egraph.register env i_eps_neq_node;
            Boolean.set_true env i_eps_neq_node))
  | {
   app = { builtin = Expr.Store; _ };
   args;
   tyargs = [ ind_gty; val_gty ];
   _;
  } ->
      Array_value.propagate_value env f;
      (* [a] is the store term itself: a = store(b, k, v) *)
      let a = fn in
      let b, k, v = IArray.extract3_exn args in
      Egraph.register env b;
      Egraph.register env k;
      Egraph.register env v;
      add_array_gty env a ind_gty val_gty;
      add_array_gty env b ind_gty val_gty;
      Ground.add_ty env v val_gty;
      Ground.add_ty env k ind_gty;
      Id.Array.set_id env a;
      Id.Array.set_id env b;
      Id.Index.set_id env k;
      Id.Value.set_id env v;
      (* update of the Linearity domain *)
      if not (Options.get env no_res_aup) then
        Linearity_dom.upd_dom env fn (Linear b);
      (* application of the `idx` rule: a[k] = v *)
      let rn = Equality.equality env [ mk_select env a k ind_gty val_gty; v ] in
      Egraph.register env rn;
      Boolean.set_true env rn;
      (* application of the `U𝛿` rule: default value of a = default value of b *)
      if Options.get env default_values then (
        let eq_node =
          Equality.equality env
            [
              ground_apply env Builtin.array_default_value [ ind_gty; val_gty ]
                [ a ];
              ground_apply env Builtin.array_default_value [ ind_gty; val_gty ]
                [ b ];
            ]
        in
        Egraph.register env eq_node;
        Boolean.set_true env eq_node);
      if not (Options.get env no_wegraph) then WEGraph.new_store env a b k v
  | {
   app = { builtin = Builtin.Array_diff; _ };
   args;
   tyargs = [ ind_gty; val_gty ];
   _;
  } ->
      Array_value.propagate_value env f;
      let a, b = IArray.extract2_exn args in
      Egraph.register env a;
      Egraph.register env b;
      Id.Array.set_id env a;
      Id.Array.set_id env b;
      add_array_gty env a ind_gty val_gty;
      add_array_gty env b ind_gty val_gty
  | {
   app = { builtin = Builtin.Array_const; _ };
   args;
   tyargs = [ ind_gty; val_gty ];
   _;
  } ->
      Array_value.propagate_value env f;
      let v = IArray.extract1_exn args in
      Egraph.register env v;
      (* application of the `K𝛿` rule: the default value of the constant
         array built from [v] is [v] *)
      if Options.get env default_values then (
        (* TODO: make a separate array node and set its type? *)
        let const_n = mk_array_const env v val_gty in
        (* type arguments [ind; val], as for every other application of
           [array_default_value] / [array_default_index] in this function *)
        let defv_n =
          ground_apply env Builtin.array_default_value [ ind_gty; val_gty ]
            [ const_n ]
        in
        let eq_node = Equality.equality env [ defv_n; v ] in
        Egraph.register env eq_node;
        Boolean.set_true env eq_node)
  | {
   app = { builtin = Builtin.Array_map; _ };
   args;
   tyargs = [ bi_ind_ty; bi_val_ty; a_val_ty ];
   _;
  }
    when Options.get env extended_comb || Options.get env default_values ->
      (* every argument after the first one (presumably the mapped function
         — TODO confirm) is an array mapped over *)
      (if Options.get env extended_comb then
       let f_arity = IArray.length args - 1 in
       IArray.iteri args ~f:(fun i n ->
           if i > 0 then (
             Id.Array.set_id env n;
             add_array_gty env n bi_ind_ty bi_val_ty;
             Array_dom.add_map_parent env n f
               { bi_ind_ty; bi_val_ty; a_val_ty; f_arity })));
      new_map env f
  | _ -> ()

(** [init env] sets up the array theory in [env]: value handling, the ground
    converter, the weak-equivalence-graph hooks (unless [no_wegraph]), the
    restricted extensionality hooks (unless [no_res_ext]), the combinator
    hooks (with [extended_comb]) and finally the inverted-path callbacks for
    the rewriting rules. *)
let init env =
  Array_value.init_ty env;
  Array_value.init_check env;
  Ground.register_converter env converter;
  let use_wegraph = not (Options.get env no_wegraph) in
  if use_wegraph then begin
    Id.Array.register_new_id_hook env WEGraph.new_id_hook;
    Id.Array.register_merge_hook env WEGraph.eq_arrays_norm;
    Id.Index.register_new_id_hook env WEGraph.new_index_id_hook;
    Id.Index.register_merge_hook env WEGraph.eq_indices_norm;
    Id.Value.register_new_id_hook env WEGraph.new_value_id_hook;
    Id.Value.register_merge_hook env WEGraph.eq_values_norm;
    Equality.register_hook_new_disequality env WEGraph.ineq_indices_norm
  end;
  (* extᵣ (restricted extensionality):
     - (a = b) ≡ false |> (a[k] ≠ b[k])
     - a, b, {a,b} ⊆ foreign |> (a = b) ⋁ (a[k] ≠ b[k])
     A variant restricted to WEGraph neighbours (apply_res_ext_1_2 /
     apply_res_ext_2_2) is currently disabled. *)
  let restricted_ext = not (Options.get env no_res_ext) in
  if restricted_ext then begin
    Equality.register_hook_new_disequality env apply_res_ext_1_1;
    Foreign_dom.register_hook_new_foreign_array env apply_res_ext_2_1
  end;
  (* RW2: the restricted version unless [no_res_aup] is set *)
  let up_rule =
    if Options.get env no_res_aup then (aup_pattern, aup_run)
    else (raup_pattern, raup_run)
  in
  (* extra combinator hooks; the const-read rule is registered first *)
  let comb_rules =
    if Options.get env extended_comb then begin
      Array_dom.register_hook_new_map_parent env add_array_map_hook;
      Array_dom.register_hook_new_read env add_array_read_hook;
      [ (const_read_pattern, const_read_run) ]
    end
    else []
  in
  comb_rules @ [ up_rule; (adown_pattern, adown_run) ]
  |> List.iter (fun (pattern, run) -> InvertedPath.add_callback env pattern run)

(* Hand [init] to [Init.add_default_theory]; presumably this makes the array
   theory part of the default set of theories — confirm in Init. *)
let () = init |> Init.add_default_theory