(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)

(** Printing context shared by the Marabou printer functions. *)
type info = {
  (* Comparison predicates looked up in [ieee_float.Float64] by
     [print_task]; used to negate goals. *)
  le : Why3.Term.lsymbol;
  ge : Why3.Term.lsymbol;
  lt : Why3.Term.lsymbol;
  gt : Why3.Term.lsymbol;
  (* Driver syntax for symbols that have a known Marabou rendering. *)
  info_syn : Why3.Printer.syntax_map;
  (* Name ("x<i>" for inputs, "y<i>" for outputs) given to each symbol
     registered through the input/output metas. *)
  variables : string Why3.Term.Hls.t;
}

(** Number format used when printing constants for Marabou: decimal
    integers and decimal reals only (hexadecimal, octal, binary and
    fractional forms are unsupported); negative reals are printed with a
    leading minus sign. *)
let number_format =
  {
    Why3.Number.long_int_support = `Default;
    Why3.Number.negative_int_support = `Default;
    Why3.Number.dec_int_support = `Default;
    Why3.Number.hex_int_support = `Unsupported;
    Why3.Number.oct_int_support = `Unsupported;
    Why3.Number.bin_int_support = `Unsupported;
    Why3.Number.negative_real_support =
      (* [f] prints the absolute value; prefix it with a minus sign. *)
      `Custom (fun fmt f -> Fmt.pf fmt "-%t" f);
    Why3.Number.dec_real_support = `Default;
    Why3.Number.hex_real_support = `Unsupported;
    Why3.Number.frac_real_support = `Unsupported (fun _ _ -> assert false);
  }

(** [print_term info fmt t] writes the atomic term [t] to [fmt] using the
    Marabou input syntax.

    - Constants are printed with [number_format].
    - An application [ls args] is printed through the driver syntax map when
      [ls] has an entry there; otherwise [ls] must be a nullary symbol
      recorded in [info.variables], printed as its stored name.
    - Conjunctions must have been split by the caller ([assert false]
      otherwise); every other connective, binder or variable is rejected
      with [Printer.unsupportedTerm]. *)
let rec print_term info fmt t =
  let open Why3 in
  let reject reason = Printer.unsupportedTerm t reason in
  match t.Term.t_node with
  | Tconst c -> Constant.(print number_format unsupported_escape) fmt c
  | Tapp (ls, args) -> (
    match Printer.query_syntax info.info_syn ls.ls_name with
    | None -> (
      match (args, Term.Hls.find_opt info.variables ls) with
      | [], Some name -> Fmt.string fmt name
      | _ -> reject "Unknown variable(s)")
    | Some syntax -> Printer.syntax_arguments syntax (print_term info) fmt args)
  | Tbinop (Tand, _, _) -> assert false (* Should appear only at top-level. *)
  | Tbinop ((Timplies | Tiff | Tor), _, _)
  | Tnot _ | Ttrue | Tfalse | Tvar _ | Tlet _ | Tif _ | Tcase _ | Tquant _
  | Teps _ ->
    reject "Not supported by Marabou"

(** [print_top_level_term info fmt t] prints a top-level constraint to [fmt].

    Conjunctions are flattened: each conjunct is printed recursively, one
    constraint per line. Quantified formulas print nothing. A term that
    mentions a symbol with neither a driver syntax nor a registered variable
    name is skipped. For a conjunction, both sides are dropped as soon as
    one side mentions such a symbol. *)
let rec print_top_level_term info fmt t =
  let open Why3 in
  (* Don't print things we don't know: every applied symbol must either have
     a driver syntax or be a registered input/output variable. Types are
     not restricted. *)
  let t_is_known =
    Term.t_s_all
      (fun _ -> true)
      (fun ls ->
        Ident.Mid.mem ls.ls_name info.info_syn || Term.Hls.mem info.variables ls)
  in
  match t.Term.t_node with
  | Tquant _ -> ()
  | Tbinop (Tand, t1, t2) ->
    if t_is_known t1 && t_is_known t2
    then
      Fmt.pf fmt "%a%a"
        (print_top_level_term info)
        t1
        (print_top_level_term info)
        t2
  | _ -> if t_is_known t then Fmt.pf fmt "%a@." (print_term info) t

(** [negate_term info t] builds the logical negation of [t]; goals are
    negated this way before being printed.

    Supported shapes:
    - a binary comparison using [info.le], [info.ge], [info.lt] or [info.gt],
      which is replaced by the complementary comparison on the same operands;
    - a disjunction of supported shapes, which is turned into a conjunction
      of negations (De Morgan).

    Quantifiers, conjunctions and any other term are rejected with
    [Printer.unsupportedTerm].
    NOTE(review): the comparisons come from [ieee_float.Float64]; swapping a
    comparison for its complement ignores NaN operands. Confirm this is
    acceptable for the intended inputs. *)
let rec negate_term info t =
  let open Why3 in
  match t.Term.t_node with
  | Tquant _ -> Printer.unsupportedTerm t "Quantification"
  | Tbinop (Tand, _, _) -> Printer.unsupportedTerm t "Conjunction"
  | Tbinop (Tor, t1, t2) ->
    (* not (a or b) is (not a) and (not b). *)
    Term.t_and (negate_term info t1) (negate_term info t2)
  | Tapp (ls, [ t1; t2 ]) ->
    let tt = [ t1; t2 ] in
    (* Replace each comparison by its complement, keeping the operands. *)
    if Term.ls_equal ls info.le
    then Term.ps_app info.gt tt
    else if Term.ls_equal ls info.ge
    then Term.ps_app info.lt tt
    else if Term.ls_equal ls info.lt
    then Term.ps_app info.ge tt
    else if Term.ls_equal ls info.gt
    then Term.ps_app info.le tt
    else Printer.unsupportedTerm t "Cannot negate such term"
  | _ -> Printer.unsupportedTerm t "Cannot negate such term"

(** [print_decl info fmt d] prints the logical content of declaration [d].

    Axioms are printed as top-level constraints. Goals are negated with
    [negate_term] and then printed the same way. Type, data, parameter,
    logic and inductive declarations print nothing. Lemmas are not expected
    here ([assert false]). *)
let print_decl info fmt d =
  let open Why3 in
  match d.Decl.d_node with
  | Dprop (Decl.Paxiom, _, axiom) -> print_top_level_term info fmt axiom
  | Dprop (Decl.Pgoal, _, goal) ->
    let negated = negate_term info goal in
    print_top_level_term info fmt negated
  | Dprop (Decl.Plemma, _, _) -> assert false
  | Dtype _ | Ddata _ | Dparam _ | Dlogic _ | Dind _ -> ()

(** [print_tdecl info fmt task] processes the declarations of [task] in task
    order, oldest first.

    - A [Native_nn_prover.meta_input] meta registers its symbol in
      [info.variables] under the name "x<i>".
    - A [Native_nn_prover.meta_output] meta registers its symbol under the
      name "y<i>".
    - Logical declarations are printed with [print_decl].
    - [use], [clone] and all other metas are ignored. *)
let rec print_tdecl info fmt task =
  let open Why3 in
  (* Both metas are expected to carry exactly [symbol; integer]; the integer
     becomes the suffix of the variable name. *)
  let register prefix = function
    | [ Theory.MAls ls; Theory.MAint i ] ->
      Term.Hls.add info.variables ls (Fmt.str "%s%i" prefix i)
    | _ -> assert false
  in
  match task with
  | None -> ()
  | Some { Task.task_prev; task_decl; _ } -> (
    (* Handle older declarations first so the output follows task order. *)
    print_tdecl info fmt task_prev;
    match task_decl.Theory.td_node with
    | Decl d -> print_decl info fmt d
    | Meta (meta, args) when Theory.meta_equal meta Native_nn_prover.meta_input
      ->
      register "x" args
    | Meta (meta, args) when Theory.meta_equal meta Native_nn_prover.meta_output
      ->
      register "y" args
    | Use _ | Clone _ | Meta (_, _) -> ())

(** [print_task args ?old fmt task] is the Why3 printer entry point for
    Marabou.

    It looks up the [le], [lt], [ge] and [gt] predicates of the
    [ieee_float.Float64] theory (needed to negate goals), builds the printing
    context, prints the driver prelude, and then processes every declaration
    of [task] with [print_tdecl]. The [old] argument is ignored. *)
let print_task args ?old:_ fmt task =
  let open Why3 in
  let th = Env.read_theory args.Printer.env [ "ieee_float" ] "Float64" in
  let le = Theory.ns_find_ls th.th_export [ "le" ] in
  let lt = Theory.ns_find_ls th.th_export [ "lt" ] in
  let ge = Theory.ns_find_ls th.th_export [ "ge" ] in
  let gt = Theory.ns_find_ls th.th_export [ "gt" ] in
  let info =
    {
      le;
      ge;
      lt;
      gt;
      info_syn = Discriminate.get_syntax_map task;
      variables = Term.Hls.create 10;
    }
  in
  Printer.print_prelude fmt args.Printer.prelude;
  print_tdecl info fmt task

(** Registers [print_task] with Why3 as the printer named ["marabou"]. *)
let init () =
  let open Why3 in
  Printer.register_printer ~desc:"Printer for the Marabou prover." "marabou"
    print_task