(* -------------------------------------------------------------------------- *)
(* --- Dependencies of Logic Definitions                                  --- *)
(* -------------------------------------------------------------------------- *)

open Cil
open Cil_types
open Cil_datatype
open Clabels
open Visitor

(* -------------------------------------------------------------------------- *)
(* --- Name Utilities                                                     --- *)
(* -------------------------------------------------------------------------- *)

(* [trim name] strips leading and trailing underscores from [name].
   - An empty name, or a name made only of underscores, yields ["_"].
   - When stripping exposes a leading digit, a single ["_"] is prepended.
   - A name that neither starts nor ends with ['_'] is returned unchanged. *)
let trim name =
  let len = String.length name in
  if len = 0 then "_"
  else if name.[0] <> '_' && name.[len - 1] <> '_' then name
  else begin
    (* index of the first non-underscore character, or [len] *)
    let lo = ref 0 in
    while !lo < len && name.[!lo] = '_' do incr lo done ;
    (* index of the last non-underscore character, or [-1] *)
    let hi = ref (len - 1) in
    while !hi >= 0 && name.[!hi] = '_' do decr hi done ;
    if !lo > !hi then "_"
    else
      let core = String.sub name !lo (!hi - !lo + 1) in
      match core.[0] with
      | '0' .. '9' -> "_" ^ core
      | _ -> core
  end

(* -------------------------------------------------------------------------- *)
(* --- Definition Blocks                                                  --- *)
(* -------------------------------------------------------------------------- *)

(* A lemma, as collected from a [Dlemma] global (see [lemma_of_global]). *)
type logic_lemma = {
  lem_name : string ;
  lem_position : Filepath.position ; (* start position of the declaration *)
  lem_types : string list ;          (* type parameters of the lemma *)
  lem_labels : logic_label list ;
  lem_predicate : toplevel_predicate ;
  lem_depends : logic_lemma list ;
  (* global lemmas declared before in AST order (in reverse order) *)
  lem_attrs : attributes ;
}

(* An axiomatic block. [ax_types], [ax_logics] and [ax_lemmas] are filled
   by [axiomatic_of_global] in declaration order. [ax_reads] is extended
   while the visitor traverses the block. *)
type axiomatic = {
  ax_name : string ;
  ax_position : Filepath.position ;
  ax_property : Property.t ;
  mutable ax_types : logic_type_info list ;
  mutable ax_logics : logic_info list ;
  mutable ax_lemmas : logic_lemma list ;
  mutable ax_reads : Varinfo.Set.t ; (* C variables referenced inside the block *)
}

(* Where a logic declaration lives:
   - [Toplevel n]: outside any axiomatic, where [n] is the number of
     axiomatic blocks already closed when the declaration is visited;
   - [Axiomatic a]: inside axiomatic block [a]. *)
type logic_section =
  | Toplevel of int
  | Axiomatic of axiomatic

(* An axiomatic is "global" when it declares no logic type and no logic
   symbol, but contains at least one lemma. *)
let is_global_axiomatic ax =
  match ax.ax_types, ax.ax_logics, ax.ax_lemmas with
  | [], [], _ :: _ -> true
  | _ -> false

(* Short aliases: maps indexed by names, logic types and logic symbols,
   and sets of logic symbols. *)
module SMap = Datatype.String.Map
module TMap = Logic_type_info.Map
module LMap = Logic_info.Map
module LSet = Logic_info.Set

(* -------------------------------------------------------------------------- *)
(* --- Usage and Dependencies                                             --- *)
(* -------------------------------------------------------------------------- *)

(* One case of an inductive definition. [ind_call] maps each label used in
   a recursive call of the case to the set of formal labels of the
   inductive it instantiates (see [add_call]). *)
type inductive_case = {
  ind_logic : logic_info ;
  ind_case : string ;
  mutable ind_call : LabelSet.t LabelMap.t ;
}

(* Usage information collected over the whole AST. *)
type database = {
  mutable cases : inductive_case list LMap.t ;   (* cases of each inductive *)
  mutable clash : LSet.t SMap.t ;     (* logic symbols sharing a given name *)
  mutable names : string LMap.t ;     (* memoized names, see [compute_logicname] *)
  mutable types : logic_section TMap.t ;          (* section of each type *)
  mutable logics : logic_section LMap.t ;         (* section of each symbol *)
  mutable lemmas : (logic_lemma * logic_section) SMap.t ; (* lemmas by name *)
  mutable recursives : LSet.t ;       (* directly recursive symbols *)
  mutable axiomatics : axiomatic SMap.t ;         (* axiomatics by name *)
  mutable proofcontext : logic_lemma list ;
  (* lemmas accepted by [Logic_utils.use_predicate], most recent first *)
}

(* A fresh database with every table empty. It is a function because the
   record has mutable fields and must not be shared. *)
let empty_database () = {
  cases = LMap.empty ;
  names = LMap.empty ;
  clash = SMap.empty ;
  types = TMap.empty ;
  logics = LMap.empty ;
  lemmas = SMap.empty ;
  recursives = LSet.empty ;
  axiomatics = SMap.empty ;
  proofcontext = [] ;
}

(* Frama-C datatype for [database]. The remaining datatype operations come
   from [Datatype.Serializable_undefined]. *)
module DatabaseType = Datatype.Make
    (struct
      type t = database
      include Datatype.Serializable_undefined
      let reprs = [empty_database ()]
      let name = "Wp.LogicUsage.DatabaseType"
    end)

(* Project state holding the database. It depends on the AST and on code
   annotations, and defaults to [empty_database ()]. *)
module Database = State_builder.Ref(DatabaseType)
    (struct
      let name = "Wp.LogicUsage.Database"
      let dependencies = [Ast.self;Annotations.code_annot_state]
      let default = empty_database
    end)

(* Print a logic symbol through its logic variable. *)
let pp_logic fmt { l_var_info ; _ } = Printer.pp_logic_var fmt l_var_info

(* -------------------------------------------------------------------------- *)
(* --- Overloading                                                        --- *)
(* -------------------------------------------------------------------------- *)

(* Source name of a C variable, with surrounding underscores trimmed. *)
let basename { vorig_name ; _ } = trim vorig_name

(* Name used for logic symbol [l], memoized in [d.names].
   - If [l] does not share its source name with another registered symbol,
     the source name is used unchanged.
   - Otherwise every symbol sharing that name is named [base_k_] at once,
     where [k] (from 1) is its rank in [LSet] order.
   NOTE(review): if [l] was never registered but its name clashes with
   registered symbols, the final lookup raises [Not_found]. *)
let compute_logicname l =
  let d = Database.get () in
  try LMap.find l d.names
  with Not_found ->
    let base = l.l_var_info.lv_name in
    let over =
      try SMap.find base d.clash
      with Not_found -> LSet.empty (*TODO: Undetected usage -> overloading issue *)
    in
    match LSet.elements over with
    | [] | [_] -> d.names <- LMap.add l base d.names ; base
    | symbols ->
        List.iteri
          (fun k s ->
             let name = Printf.sprintf "%s_%d_" base (succ k) in
             d.names <- LMap.add s name d.names)
          symbols ;
        LMap.find l d.names

(* True when more than one registered logic symbol has the source name of [l]. *)
let is_overloaded l =
  let d = Database.get () in
  match SMap.find l.l_var_info.lv_name d.clash with
  | symbols -> LSet.cardinal symbols > 1
  | exception Not_found -> false

(* Print [name(t1,t2,...)], the symbol name followed by the types of its
   formal parameters. Nothing follows the name when it has no parameters. *)
let pp_profile fmt l =
  Format.fprintf fmt "%s" l.l_var_info.lv_name ;
  match l.l_profile with
  | [] -> ()
  | x::xs ->
      Format.fprintf fmt "@[<hov 1>(%a" Printer.pp_logic_type x.lv_type ;
      List.iter
        (fun y -> Format.fprintf fmt ",@,%a"
            Printer.pp_logic_type y.lv_type)
        xs ;
      Format.fprintf fmt ")@]"

(* -------------------------------------------------------------------------- *)
(* --- Utilities                                                          --- *)
(* -------------------------------------------------------------------------- *)

(* Property identifying lemma [l]. Its location is the degenerate range
   [(lem_position, lem_position)]. *)
let ip_lemma l =
  Property.ip_lemma {
    il_name = l.lem_name; il_labels = l.lem_labels;
    il_args = l.lem_types; il_loc = (l.lem_position, l.lem_position);
    il_attrs = l.lem_attrs;
    il_pred = l.lem_predicate;
  }

(* Build a [logic_lemma] from a [Dlemma] global. [context] is stored as
   [lem_depends]. Calling it on any other global is a programming error. *)
let lemma_of_global ~context = function
  | Dlemma(name,labels,types,pred,attrs,loc) ->
      {
        lem_name = name ;
        lem_position = fst loc ;
        lem_types = types ;
        lem_labels = labels ;
        lem_predicate = pred ;
        lem_depends = context ;
        lem_attrs = attrs ;
      }
  | _ -> assert false

(* Prepend global [g] to the matching list of axiomatic [a]. The lists are
   built in reverse order. Globals other than logic symbols, types and
   lemmas are ignored. *)
let populate a ~context g =
  match g with
  | Dfun_or_pred (sym, _) -> a.ax_logics <- sym :: a.ax_logics
  | Dtype (ty, _) -> a.ax_types <- ty :: a.ax_types
  | Dlemma _ -> a.ax_lemmas <- lemma_of_global ~context g :: a.ax_lemmas
  | _ -> ()

(* Property attached to an axiomatic global. A missing property is treated
   as a programming error. *)
let ip_of_axiomatic g =
  match Property.ip_of_global_annotation_single g with
  | Some property -> property
  | None -> assert false

(* Build an [axiomatic] record from a [Daxiomatic] global. Its types, logic
   symbols and lemmas are collected in declaration order, and every lemma
   gets [context] as [lem_depends]. [ax_reads] starts empty. Calling it on
   any other global is a programming error. *)
let axiomatic_of_global ~context = function
  | Daxiomatic(name,globals,_,loc) as g ->
      let a = {
        ax_name = name ;
        ax_position = fst loc ;
        ax_property = ip_of_axiomatic g ;
        ax_reads = Varinfo.Set.empty ;
        ax_types = [] ; ax_lemmas = [] ; ax_logics = [] ;
      } in
      List.iter (populate a ~context) globals ;
      (* [populate] prepends, so restore declaration order *)
      a.ax_types <- List.rev a.ax_types ;
      a.ax_logics <- List.rev a.ax_logics ;
      a.ax_lemmas <- List.rev a.ax_lemmas ;
      a
  | _ -> assert false

(* Register logic symbol [l] in [section]. It is also added to the set of
   symbols sharing its source name, which is used to detect overloading. *)
let register_logic d section l =
  let name = l.l_var_info.lv_name in
  let over =
    try LSet.add l (SMap.find name d.clash)
    with Not_found -> LSet.singleton l
  in
  d.clash <- SMap.add name over d.clash ;
  d.logics <- LMap.add l section d.logics

(* Register lemma [l] under its name, in [section]. A later lemma with the
   same name replaces the earlier binding. *)
let register_lemma d section l =
  d.lemmas <- SMap.add l.lem_name (l,section) d.lemmas

(* Register logic type [t] in [section]. *)
let register_type d section t =
  d.types <- TMap.add t section d.types

(* Register axiomatic [a] under its name. *)
let register_axiomatic d a =
  d.axiomatics <- SMap.add a.ax_name a d.axiomatics

(* Record [inds] as the cases of inductive [l] in the current database. *)
let register_cases l inds =
  let db = Database.get () in
  db.cases <- db.cases |> LMap.add l inds

(* -------------------------------------------------------------------------- *)
(* --- Adding a label called in an inductive case                         --- *)
(* -------------------------------------------------------------------------- *)

(* calls : LabelSet.t LabelMap.t
   Given an inductive phi{...A...}
   In case H{...B...}, have a call to phi{...B...}
   Then: ( A \in calls[B] ).
*)

(* Record in [calls] that formal label [l_a] of the inductive is
   instantiated by label [l_b] at a recursive call. *)
let add_call calls l_a l_b =
  let a = Clabels.of_logic l_a in
  let b = Clabels.of_logic l_b in
  let s =
    try LabelSet.add a (LabelMap.find b calls)
    with Not_found -> LabelSet.singleton a
  in
  LabelMap.add b s calls

(* -------------------------------------------------------------------------- *)
(* --- Visitor                                                            --- *)
(* -------------------------------------------------------------------------- *)

(* Visitor that fills the [Database]:
   - registers logic types, logic symbols, lemmas and axiomatics, together
     with the section they belong to;
   - records the C variables referenced inside each axiomatic ([ax_reads]);
   - detects logic symbols that call themselves ([recursives]);
   - records the labels of recursive calls in each inductive case.
   Function bodies are not visited ([vfunc] skips them). *)
class visitor =
  object(self)

    inherit Visitor.frama_c_inplace

    val database = Database.get ()
    (* logic symbol whose body is being visited, if any *)
    val mutable caller : logic_info option = None
    (* axiomatic block being visited, if any *)
    val mutable axiomatic : axiomatic option = None
    (* inductive case being visited, if any *)
    val mutable inductive : inductive_case option = None
    (* number of axiomatic blocks already closed *)
    val mutable toplevel = 0

    (* Section in which declarations are currently registered. *)
    method private section =
      match axiomatic with
      | None -> Toplevel toplevel
      | Some a -> Axiomatic a

    (* Record C variable [x] as referenced by the enclosing axiomatic, if any. *)
    method private do_var x =
      match axiomatic with
      | None -> ()
      | Some a -> a.ax_reads <- Varinfo.Set.add x a.ax_reads

    (* A logic variable may denote a logic constant: if so, treat it as a
       call without labels. *)
    method private do_lvar x =
      try self#do_call (Logic_env.find_logic_cons x) []
      with Not_found -> ()

    (* Handle a call to [l] with actual [labels]:
       - inside an inductive case, a call to the inductive itself is
         recorded in [ind_call];
       - otherwise, a call to the symbol being defined marks it recursive. *)
    method private do_call l labels =
      match inductive with
      | Some case ->
          if Logic_info.equal l case.ind_logic then
            case.ind_call <-
              List.fold_left2 add_call case.ind_call l.l_labels labels
      | None ->
          match caller with
          | None -> ()
          | Some f ->
              if Logic_info.equal f l then
                database.recursives <- LSet.add f database.recursives

    (* Visit the predicate of one inductive case of [l] and return the
       collected case description. *)
    method private do_case l (case,_labels,_types,pnamed) =
      let indcase = {
        ind_logic = l ;
        ind_case = case ;
        ind_call = LabelMap.empty ;
      } in
      inductive <- Some indcase ;
      ignore (visitFramacPredicate (self :> frama_c_visitor) pnamed) ;
      inductive <- None ;
      indcase

    (* --- LVALUES --- *)

    method! vlval = function
      | (Var x,_) -> self#do_var x ; DoChildren
      | _ -> DoChildren

    method! vterm_lval = function
      | (TVar { lv_origin=Some x } , _ ) -> self#do_var x ; DoChildren
      | (TVar x , _ ) -> self#do_lvar x ; DoChildren
      | _ -> DoChildren

    (* --- TERMS --- *)

    method! vterm_node = function
      | Tapp(l,labels,_) -> self#do_call l labels ; DoChildren
      | _ -> DoChildren

    (* --- PREDICATE --- *)

    method! vpredicate_node = function
      | Papp(l,labels,_) -> self#do_call l labels ; DoChildren
      | _ -> DoChildren

    method! vannotation global =
      match global with

      (* --- AXIOMATICS --- *)

      | Daxiomatic _ ->
          begin
            let pf = database.proofcontext in
            let ax = axiomatic_of_global ~context:pf global in
            register_axiomatic database ax ;
            axiomatic <- Some ax ;
            DoChildrenPost
              (fun g ->
                 (* Lemmas of an axiomatic that also declares types or
                    logic symbols leave the proof context at block end. *)
                 if not (is_global_axiomatic ax) then
                   database.proofcontext <- pf ;
                 axiomatic <- None ;
                 toplevel <- succ toplevel ;
                 g)
          end

      (* --- LOGIC INFO --- *)

      | Dtype_annot(l,_)
      | Dinvariant(l,_)
      | Dfun_or_pred(l,_) ->
          begin
            register_logic database self#section l ;
            match l.l_body with
            | LBnone when axiomatic = None -> SkipChildren

            | LBnone | LBreads _ | LBterm _ | LBpred _ ->
                caller <- Some l ;
                DoChildrenPost (fun g -> caller <- None ; g)

            | LBinductive cases ->
                (* cases are visited explicitly by [do_case] *)
                register_cases l (List.map (self#do_case l) cases) ;
                SkipChildren
          end

      (* --- LEMMAS --- *)

      | Dlemma _ ->
          let lem = lemma_of_global ~context:database.proofcontext global in
          register_lemma database self#section lem ;
          if Logic_utils.use_predicate lem.lem_predicate.tp_kind then
            database.proofcontext <- lem :: database.proofcontext ;
          SkipChildren

      | Dtype(t,_) ->
          register_type database self#section t ;
          SkipChildren

      (* --- OTHERS --- *)

      | Dvolatile _
      | Dmodel_annot _
      | Dextended _
        -> SkipChildren

    method! vfunc _ = SkipChildren

  end


(* Run the [visitor] over the whole AST to fill the database. *)
let compute () =
  Wp_parameters.feedback ~ontty:`Transient "Collecting axiomatic usage" ;
  Visitor.visitFramacFile (new visitor) (Ast.get ())

(* -------------------------------------------------------------------------- *)
(* --- External API                                                       --- *)
(* -------------------------------------------------------------------------- *)

(* Shadow [compute] with its [State_builder.apply_once] wrapper. The
   wrapper is tied to [Ast.self] and [Annotations.code_annot_state]. *)
let compute, _ =
  State_builder.apply_once "LogicUsage.compute"
    [ Ast.self ; Annotations.code_annot_state ]
    compute

(* True when the definition of [l] calls [l] itself. *)
let is_recursive l =
  compute () ;
  LSet.mem l (Database.get ()).recursives

(* Label map of recursive calls recorded for case [case] of inductive [l]
   (see [add_call]).
   Fails through [Wp_parameters.fatal] when [l] is not a registered
   inductive, or has no such case. *)
let get_induction_labels l case =
  compute () ;
  let d = Database.get () in
  match LMap.find l d.cases with
  | exception Not_found ->
      Wp_parameters.fatal "Non-inductive '%s'" l.l_var_info.lv_name
  | cases ->
      (match List.find (fun i -> String.equal i.ind_case case) cases with
       | ind -> ind.ind_call
       | exception Not_found ->
           Wp_parameters.fatal "No case '%s' for inductive '%s'"
             case l.l_var_info.lv_name)

(* Axiomatic named [a]. Fails through [Wp_parameters.fatal] if it is unknown. *)
let axiomatic a =
  compute () ;
  let d = Database.get () in
  match SMap.find a d.axiomatics with
  | ax -> ax
  | exception Not_found ->
      Wp_parameters.fatal "Axiomatic '%s' undefined" a

(* Section declaring logic type [t]. Fails through [Wp_parameters.fatal]
   if the type was not registered. *)
let section_of_type t =
  compute () ;
  let d = Database.get () in
  match TMap.find t d.types with
  | section -> section
  | exception Not_found ->
      Wp_parameters.fatal "Logic type '%s' undefined" t.lt_name

(* Section declaring logic symbol [l]. Fails through [Wp_parameters.fatal]
   if the symbol was not registered. *)
let section_of_logic l =
  compute () ;
  let d = Database.get () in
  match LMap.find l d.logics with
  | section -> section
  | exception Not_found ->
      Wp_parameters.fatal "Logic '%a' undefined" pp_logic l

(* Lemma named [l] together with its section. Fails through
   [Wp_parameters.fatal] if no such lemma was registered. *)
let get_lemma l =
  compute () ;
  let d = Database.get () in
  match SMap.find l d.lemmas with
  | entry -> entry
  | exception Not_found ->
      Wp_parameters.fatal "Lemma '%s' undefined" l

(* Apply [f] to every registered lemma, in [SMap] name order. *)
let iter_lemmas f =
  compute () ;
  (Database.get ()).lemmas |> SMap.iter (fun _ (lem, _section) -> f lem)

(* Fold [f] over every registered lemma, in [SMap] name order. The initial
   accumulator is the next argument. [compute] runs as soon as [f] is
   supplied. *)
let fold_lemmas f =
  compute () ;
  let { lemmas ; _ } = Database.get () in
  SMap.fold (fun _ (lem, _) acc -> f lem acc) lemmas

(* Lemma named [l] (see [get_lemma]). *)
let logic_lemma l = let lem, _section = get_lemma l in lem

(* Section of the lemma named [l] (see [get_lemma]). *)
let section_of_lemma l = let _lem, section = get_lemma l in section

(* Lemmas currently in the proof context, most recent first. *)
let proof_context () =
  (* No need for compute: if no lemma, database is empty ! *)
  (* NOTE(review): returns [] when [compute] has not run yet. *)
  let d = Database.get () in
  d.proofcontext

(* -------------------------------------------------------------------------- *)
(* --- Dump API                                                           --- *)
(* -------------------------------------------------------------------------- *)

(* Dump line for a logic type. *)
let pp_type fmt { lt_name ; _ } = Format.fprintf fmt " * type '%s'@\n" lt_name
(* Dump lines for logic symbol [l] of the given [kind]: its name, then its
   profile when overloaded and a marker when recursive. *)
let pp_sig fmt kind l =
  Format.fprintf fmt " * %s '%s'@\n" kind (compute_logicname l) ;
  if is_overloaded l then
    Format.fprintf fmt "   profile %a@\n" pp_profile l ;
  if is_recursive l then
    Format.fprintf fmt "   recursive@\n"

(* Dump a logic symbol of database [d].
   - For an inductive: its signature, then for each case the map from call
     labels to the formal labels they instantiate.
   - Otherwise: its signature, tagged "predicate" or "function".
   Only the lookup in [d.cases] is guarded. A [Not_found] raised while
   printing is no longer mistaken for "not an inductive", which used to
   print the signature a second time. *)
let pp_decl fmt d l =
  match LMap.find l d.cases with
  | cases ->
      pp_sig fmt "inductive" l ;
      List.iter
        (fun ind ->
           Format.fprintf fmt "   @[case %s:" ind.ind_case ;
           LabelMap.iter
             (fun lbl s ->
                Format.fprintf fmt "@ @[<hov 2>{%a:" Clabels.pretty lbl ;
                LabelSet.iter
                  (fun a -> Format.fprintf fmt "@ %a" Clabels.pretty a) s ;
                Format.fprintf fmt "}@]"
             ) ind.ind_call ;
           Format.fprintf fmt "@]@\n"
        ) cases
  | exception Not_found ->
      let kind =
        match l.l_type with
        | None -> "predicate"
        | Some _ -> "function"
      in
      pp_sig fmt kind l

(* Dump line for a lemma: its kind followed by its name. *)
let pp_lemma fmt lem =
  let kind = lem.lem_predicate.tp_kind in
  Format.fprintf fmt " * %a '%s'@\n" Cil_printer.pp_lemma_kind kind lem.lem_name

(* Name of logic symbol [l], disambiguated when overloaded. *)
let get_name l =
  compute () ;
  compute_logicname l

(* Print a section. Toplevel 0 is printed without its counter. *)
let pp_section fmt section =
  match section with
  | Axiomatic { ax_name ; _ } -> Format.fprintf fmt "Axiomatic '%s'" ax_name
  | Toplevel 0 -> Format.fprintf fmt "Toplevel"
  | Toplevel n -> Format.fprintf fmt "Toplevel(%d)" n

let dump () =
  compute () ;
    begin fun fmt ->
      let d = Database.get () in
        (fun _ a ->
           Format.fprintf fmt "Axiomatic %s {@\n" a.ax_name ;
           List.iter (pp_type fmt) a.ax_types ;
           List.iter (pp_decl fmt d) a.ax_logics ;
           List.iter (pp_lemma fmt) a.ax_lemmas ;
           Format.fprintf fmt "}@\n"
        ) d.axiomatics ;
        (fun t s ->
           Format.fprintf fmt " * type '%s' in %a@\n"
             t.lt_name pp_section s)
        d.types ;
        (fun l s ->
           Format.fprintf fmt " * logic '%a' in %a@\n"
             pp_logic l pp_section s)
        d.logics ;
        (fun l (lem,s) ->
           Format.fprintf fmt " * %a '%s' in %a@\n"
             Cil_printer.pp_lemma_kind lem.lem_predicate.tp_kind
             l pp_section s)
        d.lemmas ;
      Format.fprintf fmt "-------------------------------------------------@." ;