(* Brings [or_bottom], [>>-] and [>>-:] into scope: they are used below
   but not defined in this file. *)
open Bottom.Type

(** A value which may also be [`Top] or [`Bottom]. *)
type 'a or_top_bottom = 'a Bottom.Top.or_top_bottom

(** [t >>>-: f] applies [f] to the [`Value] case of [t] and propagates
    [`Top] and [`Bottom] unchanged. *)
let (>>>-:) t f = match t with
  | `Top -> `Top
  | `Bottom  -> `Bottom
  | `Value v -> `Value (f v)

(** Callstacks recorded by the analysis. *)
module Callstack = Value_types.Callstack

type callstack = Callstack.t

(** Data attached to each callstack it was computed in. *)
type 'a by_callstack = (Callstack.t * 'a) list

(** Program points at which results can be requested. *)
type control_point =
  | Initial (** the global initial state *)
  | Final (** the end of the entry point function *)
  | Start of Cil_types.kernel_function (** the start of a function *)
  | End of Cil_types.kernel_function (** the return of a function *)
  | Before of Cil_types.stmt (** just before a statement *)
  | After of Cil_types.stmt (** just after a statement *)

(** A query on the analysis results:
    - [control_point]: where to look;
    - [selector]: which callstacks to consider ([None] means all of them);
    - [filter]: predicates that every kept callstack must satisfy. *)
type request = {
  control_point : control_point;
  selector : callstack list option;
  filter: (callstack -> bool) list;
}

(* Phantom tags distinguishing evaluations of values from evaluations of
   addresses in the [evaluation] GADTs of this file. Declaring them here
   lets the type checker treat them as distinct types. *)
type value
type address

(** Why a result could not be produced. *)
type error = Bottom | Top | DisabledDomain
type 'a result = ('a,error) Result.t

(** Human-readable description of an [error]. *)
let string_of_error error =
  match error with
  | Bottom -> "The computed state is bottom"
  | Top -> "The computed state is Top"
  | DisabledDomain -> "The required domain is disabled"

(** Prints the description of [error] on [fmt]. *)
let pretty_error fmt error =
  error |> string_of_error |> Format.pp_print_string fmt

(** [pretty_result f fmt r] prints [r] with [f] when it is [Ok], and with
    [pretty_error] when it is [Error]. *)
let pretty_result f fmt r =
  let on_ok = f fmt
  and on_error = pretty_error fmt in
  match r with
  | Ok x -> on_ok x
  | Error e -> on_error e

(* Building requests *)

(** [make control_point] is the request at [control_point], over all
    callstacks and without any callstack filter. *)
let make control_point = {
  control_point;
  selector = None;
  filter = [];
}

(* Requests at the various control points. *)
let before s = Before s |> make
let after s = After s |> make
let at_start_of f = Start f |> make
let at_end_of f = End f |> make
let at_start = make Initial
let at_end = make Final

(** Same as [before] for a [kinstr]; [Kglobal] is mapped to [at_start]. *)
let before_kinstr kinstr =
  match kinstr with
  | Cil_types.Kglobal -> at_start
  | Cil_types.Kstmt s -> before s

(** Restrict [req] to the given callstacks. *)
let in_callstacks callstacks req = { req with selector = Some callstacks }
let in_callstack cs req = in_callstacks [cs] req

(** Add the predicate [f] to the callstack filters of [req]. *)
let filter_callstack f req = { req with filter = f :: req.filter }

(* Manipulating request results *)

(* Phantom tags of a response: [unrestricted_response] responses may be
   consolidated over all callstacks; [restricted_to_callstack] ones are
   always split by callstack (or Top/Bottom). *)
type restricted_to_callstack
type unrestricted_response

module Response =
struct
  type ('a, 'callstack) t =
    | Consolidated : 'a -> ('a, unrestricted_response) t
    | ByCallstack  : 'a by_callstack -> ('a, 'callstack) t
    | Top : ('a, 'callstack) t
    | Bottom : ('a, 'callstack) t

  (* A restricted response never uses [Consolidated], so it can be
     re-tagged with any phantom type by rebuilding its constructor. *)
  let coercion : ('a, restricted_to_callstack) t -> ('a, 'c) t = function
    | ByCallstack c -> ByCallstack c
    | Top -> Top
    | Bottom -> Bottom

  (* Constructors *)

  let consolidated = function
    | `Bottom -> Bottom
    | `Value state -> Consolidated state

  let singleton cs = function
    | `Bottom -> Bottom
    | `Value state -> ByCallstack [cs,state]

  (* Keeps from [table] the callstacks selected by [req.selector] (all of
     them when [None]) that satisfy every predicate of [req.filter]. *)
  let by_callstack : request ->
    [< `Bottom | `Top | `Value of 'a Value_types.Callstack.Hashtbl.t ] ->
    ('a, restricted_to_callstack) t =
    fun req table ->
    match table with
    | `Top -> Top
    | `Bottom -> Bottom
    | `Value table ->
      (* Filter *)
      let add cs state acc =
        if List.for_all (fun filter -> filter cs) req.filter
        then (cs, state) :: acc
        else acc
      in
      (* Selection *)
      let l =
        match req.selector with
        | None -> Callstack.Hashtbl.fold add table []
        | Some selector ->
          let add_selected acc cs =
            match Callstack.Hashtbl.find_opt table cs with
            | Some state -> add cs state acc
            | None -> acc
          in
          List.fold_left add_selected [] selector
      in
      ByCallstack l

  (* The callstacks of a response. *)
  let callstacks : ('a, restricted_to_callstack) t -> callstack list = function
    | Top | Bottom -> [] (* What else to do when Top is given ? *)
    | ByCallstack l -> List.map fst l

  (* Iter *)

  let iter (f  : callstack -> 'a -> unit) :
    ('a, restricted_to_callstack) t -> unit = function
    | Top | Bottom -> () (* What else to do when Top is given ? *)
    | ByCallstack l -> List.iter (fun (cs,x) -> f cs x) l

  (* Fold *)

  let fold (f  : callstack -> 'a -> 'b -> 'b) (acc : 'b) :
    ('a, restricted_to_callstack) t -> 'b = function
    | Top | Bottom -> acc (* What else to do when Top is given ? *)
    | ByCallstack l -> List.fold_left (fun acc (cs,x) -> f cs x acc) acc l

  (* Applies [f] to every data of the response, keeping its shape. *)
  let map : type c. ('a -> 'b) -> ('a, c) t -> ('b, c) t = fun f -> function
    | Consolidated v -> Consolidated (f v)
    | ByCallstack l -> ByCallstack (List.map (fun (cs,x) -> cs,f x) l)
    | Top -> Top
    | Bottom -> Bottom

  (* Maps every data then combines the results with [reduce]; an empty
     per-callstack response is treated as [`Bottom]. *)
  let map_reduce : type c. ([`Top | `Bottom] -> 'b) -> ('a -> 'b) ->
    ('b -> 'b -> 'b) -> ('a, c) t -> 'b =
    fun default map reduce -> function
      | Consolidated v -> map v
      | ByCallstack ((_,h) :: t) ->
        List.fold_left (fun acc (_,x) -> reduce acc (map x)) (map h) t
      | ByCallstack [] | Bottom -> default `Bottom
      | Top -> default `Top

  let map_join : type c.
    ('a -> 'b) -> ('b -> 'b -> 'b) -> ('a, c) t -> 'b or_top_bottom =
    fun map join ->
    let default = function
      | `Bottom -> `Bottom
      | `Top -> `Top
    and map' x =
      `Value (map x)
    in
    map_reduce default map' (Bottom.Top.join join)

  let map_join' : type c. ('a -> 'b or_top_bottom) -> ('b -> 'b -> 'b) ->
    ('a, c) t -> 'b or_top_bottom =
    fun map join ->
    let default = function
      | `Bottom -> `Bottom
      | `Top -> `Top
    and map' = (map :> 'a -> 'b or_top_bottom) in
    map_reduce default map' (Bottom.Top.join join)
end

(* Extracting states and values *)

(* Access to the results of the analyzer returned by
   [Analysis.current_analyzer] when the functor is applied. The functor is
   generative: each [Make ()] application evaluates that call again. *)
module Make () =
struct
  module A = (val Analysis.current_analyzer ())

  (* Types of the raw evaluation results returned by [A.Eval]. *)
  module EvalTypes =
  struct
    type valuation = A.Eval.Valuation.t
    type exp = (valuation * A.Val.t) Eval.evaluated
    type lval = (valuation * A.Val.t Eval.flagged_value) Eval.evaluated
    type loc = (valuation * A.Loc.location * Cil_types.typ) Eval.evaluated
  end

  (* An evaluation, tagged by what was evaluated ([value] or [address]) and
     by the restriction of its response (['c]). *)
  type ('a,'c) evaluation =
    | LValue: (EvalTypes.lval, 'c) Response.t -> (value,'c) evaluation
    | Value: (EvalTypes.exp, 'c) Response.t -> (value,'c) evaluation
    | Address: (EvalTypes.loc, 'c) Response.t * Cil_types.lval ->
        (address,'c) evaluation

  (* States for [req], split by callstack. [End kf] is answered after the
     return statement of [kf]; [Final] at the end of the entry point. *)
  let rec get_by_callstack (req : request) :
    (_, restricted_to_callstack) Response.t =
    let open Response in
    match req.control_point with
    | Before stmt ->
      A.get_stmt_state_by_callstack ~after:false stmt |> by_callstack req
    | After stmt ->
      A.get_stmt_state_by_callstack ~after:true stmt |> by_callstack req
    | Initial ->
      A.get_kinstr_state ~after:false Cil_types.Kglobal |> singleton []
    | Start kf ->
      A.get_initial_state_by_callstack kf |> by_callstack req
    | End kf ->
      let stmt = Kernel_function.find_return kf in
      { req with control_point=After stmt } |> get_by_callstack
    | Final ->
      let main, _lib_entry = Globals.entry_point () in
      { req with control_point=End main } |> get_by_callstack

  (* States for [req]. Consolidated states are used for statements when
     [req] has neither selector nor filter; otherwise the per-callstack
     response is used. *)
  let rec get (req : request) : (_, unrestricted_response) Response.t =
    if req.filter <> [] || Option.is_some req.selector then
      Response.coercion @@ get_by_callstack req
    else
      let open Response in
      match req.control_point with
      | Before stmt ->
        A.get_stmt_state ~after:false stmt |> consolidated
      | After stmt ->
        A.get_stmt_state ~after:true stmt |> consolidated
      | End kf ->
        let stmt = Kernel_function.find_return kf in
        { req with control_point=After stmt } |> get
      | Final ->
        let main, _lib_entry = Globals.entry_point () in
        { req with control_point=End main } |> get
      | Initial | Start _ ->
        Response.coercion @@ get_by_callstack req

  let convert : 'a or_top_bottom -> 'a result = function
    | `Top -> Result.error Top
    | `Bottom -> Result.error Bottom
    | `Value v -> Result.ok v

  let callstacks req =
    get_by_callstack req |> Response.callstacks

  (* Calls [f] on each callstack of [req] with [req] restricted to it. *)
  let iter_callstacks f req =
    let f' cs _res =
      f cs (in_callstack cs req)
    in
    get_by_callstack req |> Response.iter f'

  let fold_callstacks f acc req =
    let f' cs _res acc =
      f cs (in_callstack cs req) acc
    in
    get_by_callstack req |> Response.fold f' acc

  let by_callstack req =
    let f cs _res acc =
      (cs, in_callstack cs req) :: acc
    in
    get_by_callstack req |> Response.fold f []

  (* False only when the response for [req] is [Bottom]. *)
  let is_reachable req =
    match get req with
    | Response.Bottom -> false
    | _ -> true

  (* Expressions in the equality class of [exp] at [req], according to the
     equality domain; [Error Top] when the class is trivial. *)
  let equality_class exp req =
    let open Equality in
    match A.Dom.get Equality_domain.key with
    | None ->
      Result.error DisabledDomain
    | Some extract ->
      let hce = Hcexprs.HCE.of_exp exp in
      let extract' state =
        let equalities = Equality_domain.project (extract state) in
        NonTrivial (Set.find hce equalities)
      and reduce e1 e2 =
        match e1, e2 with
        | Trivial, _ | _, Trivial -> Trivial
        | NonTrivial e1, NonTrivial e2 -> Equality.inter e1 e2
      in
      let r = match Response.map_join extract' reduce (get req) with
        | (`Top | `Bottom) as r -> r
        | `Value Trivial -> `Top
        | `Value (NonTrivial e) ->
          let l = Equality.elements e in
          `Value (List.map Hcexprs.HCE.to_exp l)
      in
      convert r

  (* Join over all callstacks of the cvalue component of the states. *)
  let as_cvalue_model req =
    match A.Dom.get Cvalue_domain.State.key with
    | None ->
      Result.error DisabledDomain
    | Some extract ->
      let extract' state =
        fst (extract state)
      in
      convert (Response.map_join extract' Cvalue.Model.join (get req))

  (* Evaluation *)

  let eval_lval lval req =
    let eval state = A.Eval.copy_lvalue state lval in
    LValue (Response.map eval (get req))

  let eval_exp exp req =
    let eval state = A.Eval.evaluate state exp in
    Value (Response.map eval (get req))

  let eval_address lval req =
    let eval state = A.Eval.lvaluate ~for_writing:false state lval in
    Address (Response.map eval (get req), lval)

  (* Functions possibly called through [exp], sorted and without
     duplicates. *)
  let eval_callee exp req =
    let join = (@)
    and extract state =
      let r,_alarms = A.Eval.eval_function_exp exp state in
      r >>>-: List.map fst
    in
    get req |> Response.map_join' extract join |> convert |>
    Result.map (List.sort_uniq Kernel_function.compare)

  (* Drops valuations and alarms, keeping only the evaluated values. *)
  let extract_value :
    type c. (value, c) evaluation -> (A.Val.t or_bottom, c) Response.t =
    function
    | LValue r ->
      let extract (x, _alarms) = x >>- (fun (_valuation,fv) -> fv.Eval.v) in
      Response.map extract r
    | Value r ->
      let extract (x, _alarms) = x >>-: (fun (_valuation,v) -> v) in
      Response.map extract r

  let as_cvalue res =
    match A.Val.get Main_values.CVal.key with
    | None ->
      Result.error DisabledDomain
    | Some get ->
      let join = Main_values.CVal.join in
      let extract value = value >>>-: get in
      extract_value res |> Response.map_join' extract join |> convert

  (* Drops valuations, types and alarms, keeping only the locations, and
     returns the evaluated lvalue alongside. *)
  let extract_loc :
    type c. (address, c) evaluation ->
    (A.Loc.location or_bottom, c) Response.t * Cil_types.lval =
    function
    | Address (r, lval) ->
      let extract (x, _alarms) = x >>-: (fun (_valuation,loc,_typ) -> loc) in
      Response.map extract r, lval

  let as_location res =
    match A.Loc.get Main_locations.PLoc.key with
    | None ->
      Result.error DisabledDomain
    | Some get ->
      let join loc1 loc2 =
        let open Locations in
        let size = loc1.size
        and loc = Location_Bits.join loc1.loc loc2.loc in
        assert (Int_Base.equal loc2.size size);
        make_loc loc size
      and extract loc =
        loc  >>>-: get >>>-: Precise_locs.imprecise_location
      in
      extract_loc res |> fst |> Response.map_join' extract join |> convert

  let as_zone ~access res =
    let response_loc, lv = extract_loc res in
    let is_const_lv = Value_util.is_const_write_invalid (Cil.typeOfLval lv) in
    (* No write effect if [lv] is const *)
    if access=Locations.Write && is_const_lv
    then Result.ok Locations.Zone.bottom
    else
      match A.Loc.get Main_locations.PLoc.key with
      | None ->
        Result.error DisabledDomain
      | Some get ->
        let join = Locations.Zone.join
        and extract loc =
          loc  >>>-: get >>>-: Precise_locs.enumerate_valid_bits access
        in
        response_loc |> Response.map_join' extract join |> convert

  let is_initialized : type c. (value,c) evaluation -> bool = function
    | LValue r ->
      let join = (&&)
      and extract (x, _alarms) =
        x >>>-: (fun (_valuation,fv) -> fv.Eval.initialized)
      in
      begin match Response.map_join' extract join r with
        | `Bottom | `Top -> false
        | `Value v -> v
      end
    | Value _ -> true (* computed values are always initialized *)

  (* Alarms whose status is not [True] in the union of all alarm sets. *)
  let alarms : type a c. (a,c) evaluation -> Alarms.t list =
    fun res ->
    let extract (_,v) = `Value v in
    let r = match res with
      | LValue r ->
        Response.map_join' extract Alarmset.union r
      | Value r ->
        Response.map_join' extract Alarmset.union r
      | Address (r, _lval) ->
        Response.map_join' extract Alarmset.union r
    in
    match r with
    | `Bottom | `Top -> []
    | `Value alarmset ->
      let open Alarmset in
      let l = ref [] in
      let add alarm = function
        | True -> ()
        | False | Unknown -> l := alarm :: !l
      in
      Alarmset.iter add alarmset;
      !l

  let is_bottom : type a c. (a,c) evaluation -> bool =
    fun res ->
    let extract (x,_) = x >>>-: fun _ -> () in
    let join () () = () in
    let r = match res with
      | LValue r ->
        Response.map_join' extract join r
      | Value r ->
        Response.map_join' extract join r
      | Address (r, _lval) ->
        Response.map_join' extract join r
    in
    match r with
    | `Bottom -> true
    | `Top | `Value () -> false

  (* Dependencies, computed on the cvalue model; empty zone on error. *)

  let lval_deps lval req =
    let compute_deps cvalue =
      Register.eval_deps_lval (cvalue, Locals_scoping.bottom ()) lval
    in
    req |> as_cvalue_model |>
    Result.fold ~error:(fun _ -> Locations.Zone.bottom) ~ok:compute_deps

  let expr_deps exp req =
    let compute_deps cvalue =
      Register.eval_deps (cvalue, Locals_scoping.bottom ()) exp
    in
    req |> as_cvalue_model |>
    Result.fold ~error:(fun _ -> Locations.Zone.bottom) ~ok:compute_deps

  let address_deps lval req =
    let compute_deps cvalue =
      Register.eval_deps_addr (cvalue, Locals_scoping.bottom ()) lval
    in
    req |> as_cvalue_model |>
    Result.fold ~error:(fun _ -> Locations.Zone.bottom) ~ok:compute_deps
end

(* Working with callstacks *)

(** [callstacks req] lists the callstacks selected by [req]. *)
let callstacks req =
  let module Results = Make () in
  req |> Results.callstacks

(** [iter_callstacks f req] calls [f cs req'] for every callstack [cs] of
    [req], where [req'] is [req] restricted to [cs]. *)
let iter_callstacks f req =
  let module Results = Make () in
  req |> Results.iter_callstacks f

(** Folding counterpart of [iter_callstacks]. *)
let fold_callstacks f acc req =
  let module Results = Make () in
  req |> Results.fold_callstacks f acc

(** Pairs each callstack of [req] with [req] restricted to it. *)
let by_callstack req =
  let module Results = Make () in
  req |> Results.by_callstack

(* State requests *)

(** Expressions in the same equality class as [exp] at [req]. *)
let equality_class exp req =
  let module Results = Make () in
  req |> Results.equality_class exp

(** The cvalue model of the states at [req]. *)
let as_cvalue_model req =
  let module Results = Make () in
  req |> Results.as_cvalue_model
(* Dependencies *)

(** Memory zone read to evaluate the expression [exp] at [req]. *)
let expr_deps exp req =
  let module Results = Make () in
  req |> Results.expr_deps exp

(** Memory zone read to evaluate the lvalue [lval] at [req]. *)
let lval_deps lval req =
  let module Results = Make () in
  req |> Results.lval_deps lval

(** Memory zone read to evaluate the address of [lval] at [req]. *)
let address_deps lval req =
  let module Results = Make () in
  req |> Results.address_deps lval

(** An address evaluation, packed with the analyzer instance that produced
    it and the restriction of its response. *)
module type Lvaluation =
sig
  include module type of (Make ())
  type restriction
  val v : (address,restriction) evaluation
end

(** A value evaluation, packed with the analyzer instance that produced it
    and the restriction of its response. *)
module type Evaluation =
sig
  include module type of (Make ())
  type restriction
  val v : (value,restriction) evaluation
end

(** Public evaluation results, indexed by what was evaluated. *)
type 'a evaluation =
  | Value: (module Evaluation) -> value evaluation
  | Address: (module Lvaluation) -> address evaluation
(* Returns the lvalue and expression evaluators of a fresh [Make ()]
   instance; each result is packed as an [Evaluation] whose [restriction]
   records whether the response is consolidated or split by callstack. *)
let build_eval_lval_and_exp () =
  let module M = Make () in
  let open Response in
  (* The [as eval] aliases are typed from their patterns, which lets [v]
     take the restriction chosen in each branch. *)
  let build = function
    | M.LValue (Consolidated _)
    | M.Value (Consolidated _) as eval ->
      (module struct
        include M
        type restriction = unrestricted_response
        let v = eval
      end : Evaluation)
    | M.LValue (ByCallstack _ | Top | Bottom)
    | M.Value (ByCallstack _ | Top | Bottom) as eval ->
      (module struct
        include M
        type restriction = restricted_to_callstack
        let v = eval
      end : Evaluation)
  in
  let eval_lval lval req = build @@ M.eval_lval lval req in
  let eval_exp exp req = build @@ M.eval_exp exp req in
  eval_lval, eval_exp

(** Evaluation of an lvalue, a variable or an expression at [req]. *)
let eval_lval lval req = Value ((fst @@ build_eval_lval_and_exp ()) lval req)
let eval_var vi req = eval_lval (Cil.var vi) req
let eval_exp exp req = Value ((snd @@ build_eval_lval_and_exp ()) exp req)
(** Evaluation of the address of [lval] at [req]. *)
let eval_address lval req =
  let module M = Make () in
  let open Response in
  match M.eval_address lval req with
  | M.Address (Consolidated _, _) as lval ->
    Address (module struct
      include M
      type restriction = unrestricted_response
      let v = lval
    end : Lvaluation)
  | M.Address ((ByCallstack _ | Top | Bottom), _) as lval ->
    Address (module struct
      include M
      type restriction = restricted_to_callstack
      let v = lval
    end : Lvaluation)

(** Functions possibly called through [exp] at [req].
    @raise Invalid_argument if [exp] is not an lvalue without offset. *)
let eval_callee exp req =
  (* Check the validity of exp *)
  begin match exp with
    | Cil_types.({ enode = Lval (_, NoOffset); _ }) -> ()
    | _ ->
      invalid_arg "The callee must be an lvalue with no offset"
  end;
  let module M = Make () in
  M.eval_callee exp req

(** Functions possibly called by the call statement [stmt], evaluated just
    before [stmt]; the empty list when the evaluation fails.
    @raise Invalid_argument if [stmt] is not a call. *)
let callee stmt =
  let callee_exp =
    match stmt.Cil_types.skind with
    | Instr (Call (_lval, callee_exp, _args, _loc)) ->
      callee_exp
    | Instr (Local_init (_vi, ConsInit (f, _, _), _loc)) ->
      Cil.evar f
    | _ ->
      invalid_arg "Can only evaluate the callee on a statement which is a Call"
  in
  before stmt |> eval_callee callee_exp |> Result.value ~default:[]

(* Value conversion *)

(** The value of an evaluation, as a cvalue. *)
let as_cvalue (Value evaluation) =
  let module E = (val evaluation : Evaluation) in
  E.as_cvalue E.v

(** The value as an [Ival.t]; [Error Top] when the cvalue cannot be
    projected ([Cvalue.V.Not_based_on_null]). *)
let as_ival evaluation =
  try Result.map Cvalue.V.project_ival (as_cvalue evaluation)
  with Cvalue.V.Not_based_on_null ->
    Result.error Top

(** The value as an [Fval.t]; [Error Top] when the ival is not a float. *)
let as_fval evaluation =
  let f ival =
    if Ival.is_float ival
    then Result.ok (Ival.project_float ival)
    else Result.error Top
  in
  Result.bind (as_ival evaluation) f

(** The value as a single float; [Error Top] when it is not a singleton. *)
let as_float evaluation =
  try
    as_fval evaluation |>
    Result.map Fval.project_float |>
    Result.map Fval.F.to_float
  with Fval.Not_Singleton_Float ->
    Result.error Top

(** The value as a single integer; [Error Top] when it is not a singleton. *)
let as_integer evaluation =
  try Result.map Ival.project_int (as_ival evaluation)
  with Ival.Not_Singleton_Int ->
    Result.error Top

(** The value as an OCaml [int]; [Error Top] when it does not fit. *)
let as_int evaluation =
  try Result.map Integer.to_int_exn (as_integer evaluation)
  with Z.Overflow ->
    Result.error Top

(** The location of an address evaluation. *)
let as_location (Address lvaluation) =
  let module E = (val lvaluation : Lvaluation) in
  E.as_location E.v

(** The memory zone of an address evaluation for the given [access]
    (reads by default). *)
let as_zone_result ?(access=Locations.Read) (Address lvaluation) =
  let module E = (val lvaluation : Lvaluation) in
  E.as_zone ~access E.v

(** Same as [as_zone_result], mapping errors to the empty zone (bottom)
    or to the top zone. *)
let as_zone ?access address =
  match as_zone_result ?access address with
  | Ok zone -> zone
  | Error Bottom -> Locations.Zone.bottom
  | Error (Top | DisabledDomain) -> Locations.Zone.top

(* Evaluation properties *)

(** Whether the evaluated value is initialized. *)
let is_initialized (Value evaluation) =
  let module E = (val evaluation : Evaluation) in
  E.is_initialized E.v

(** Alarms raised by the evaluation whose status is not [True]. *)
let alarms : type a. a evaluation -> Alarms.t list = function
  | Value evaluation ->
    let module E = (val evaluation : Evaluation) in
    E.alarms E.v
  | Address lvaluation ->
    let module L = (val lvaluation : Lvaluation) in
    L.alarms L.v

(** Whether [rq] selects no callstack at all. *)
let is_empty rq =
  let module E = Make () in
  E.callstacks rq = []

(** Whether the evaluation result is bottom. *)
let is_bottom : type a. a evaluation -> bool = function
  | Value evaluation ->
    let module E = (val evaluation : Evaluation) in
    E.is_bottom E.v
  | Address lvaluation ->
    let module L = (val lvaluation : Lvaluation) in
    L.is_bottom L.v

(** False only when the state at the start of [kf] is bottom. *)
let is_called kf =
  let module Results = Make () in
  Results.is_reachable @@ at_start_of kf

(** False only when the state just before [stmt] is bottom. *)
let is_reachable stmt =
  let module Results = Make () in
  Results.is_reachable @@ before stmt

(** Same as [is_reachable], for a [kinstr]. *)
let is_reachable_kinstr kinstr =
  let module Results = Make () in
  Results.is_reachable @@ before_kinstr kinstr

(* Callers / callsites *)

(** Functions calling [kf], read from the callstacks at the start of [kf];
    sorted, without duplicates. *)
let callers kf =
  let f = function
    | [] | [_] -> None
    | _ :: (caller,_) :: _-> Some caller
  in
  at_start_of kf |> callstacks |>
  List.filter_map f |> List.sort_uniq Kernel_function.compare

(* Sorts statements and removes duplicates. *)
let uniq_sites = List.sort_uniq Cil_datatype.Stmt.compare

(** For each caller of [kf], the call statements to [kf] found in the
    callstacks at the start of [kf] (sorted, without duplicates). *)
let callsites kf =
  let module Map = Kernel_function.Map in
  let f acc = function
    | [] | (_,Cil_types.Kglobal) :: _ -> acc
    | [(_,Kstmt _)] -> assert false (* End of callstacks should have no callsite *)
    | (_kf,Kstmt stmt) :: (caller,_) :: _ -> (* kf = _kf *)
      Map.update caller
        (fun old -> Some (stmt :: Option.value ~default:[] old)) acc
  in
  at_start_of kf |> callstacks |>
  List.fold_left f Map.empty |> Map.to_seq |> List.of_seq |>
  List.map (fun (caller, sites) -> caller, uniq_sites sites)

(* Result conversion *)

(** [default d r] is the content of [r] when it is [Ok], and [d] otherwise. *)
let default default_value = function
  | Ok v -> v
  | Error _ -> default_value