Skip to content
Snippets Groups Projects
interpretation.ml 11.13 KiB
(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*  Copyright (C) 2022                                                    *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

module CRE = Reduction_engine (* Caisar Reduction Engine *)
open Why3
open Base

(** A dataset loaded by the [read_dataset] builtin; only CSV files are
    supported for now. The derived printer elides the CSV contents. *)
type dataset = CSV of Csv.t [@printer fun fmt _ -> Fmt.pf fmt "<csv>"]
[@@deriving show]

(** A classifier, represented by the path of its file. The [read_classifier]
    builtin builds it by joining the user-given name with the working
    directory. *)
type classifier = string [@@deriving show]

(** Opaque objects that the CAISAR builtins introduce into Why3 terms. Each
    value is represented inside terms by a fresh nullary logic symbol (see
    [ls_of_caisar_op]). The [@printer] attributes customize the derived
    [pp_caisar_op]. *)
type caisar_op =
  (* A dataset loaded from disk by the [read_dataset] builtin. *)
  | Dataset of dataset
  (* The feature columns of one dataset row, as produced by indexing a
     dataset with a literal ([vget]). *)
  | Data of string list
  (* A classifier file, as produced by the [read_classifier] builtin. *)
  | Classifier of classifier
  (* The application [t1 @@ t2] of a classifier to an input, kept
     symbolic. *)
  | ClassifierApp of Term.term * Term.term
      [@printer
        fun fmt (t1, t2) ->
          Fmt.pf fmt "%a@@%a" Pretty.print_term t1 Pretty.print_term t2]
  (* A symbolic vector access [t1[t2]].
     NOTE(review): the only place this file builds it is the two-argument
     branch of the [length] builtin, not [vget]. Confirm this is intended. *)
  | VGet of Term.term * Term.term
      [@printer
        fun fmt (t1, t2) ->
          Fmt.pf fmt "%a[%a]" Pretty.print_term t1 Pretty.print_term t2]
  (* A symbolic shape-equality test, built by the (currently unregistered)
     [equal_shape] builtin. *)
  | EqualShape of Term.term * Term.term
      [@printer
        fun fmt (t1, t2) ->
          Fmt.pf fmt "EqShape %a %a" Pretty.print_term t1 Pretty.print_term t2]
  (* A symbolic index-validity test, built by the (currently unregistered)
     [valid_index] builtin. *)
  | ValidIndex of Term.term * Term.term
      [@printer
        fun fmt (t1, t2) ->
          Fmt.pf fmt "ValidIdx %a %a" Pretty.print_term t1 Pretty.print_term t2]
[@@deriving show]

(** State that the CAISAR builtins share through the reduction engine's user
    environment ([CRE.user_env]). *)
type caisar_env = {
  (* Logic symbol -> the CAISAR operation it stands for. *)
  caisar_op_of_ls : caisar_op Term.Hls.t;
  (* CAISAR operation -> its logic symbol (a Base table; the constructor
     [caisar_env] creates it with polymorphic hashing/equality). *)
  ls_of_caisar_op : (caisar_op, Term.lsymbol) Hashtbl.t;
  (* Directory against which relative classifier/dataset file names are
     resolved. *)
  cwd : string;
}

(** [ls_of_caisar_op engine op ty] returns the nullary logic symbol standing
    for [op] in the CAISAR environment of [engine].

    The first time [op] is seen, a fresh symbol named ["caisar_op"] with
    result type [ty] is created and recorded in both directions. Later calls
    with an equal [op] (polymorphic equality) return the cached symbol, and
    [ty] is then ignored. Debug traces are printed to stdout. *)
let ls_of_caisar_op engine op ty =
  let caisar_env = CRE.user_env engine in
  Fmt.pr "ls_of_caisar_op: %a@." pp_caisar_op op;
  Option.iter ty ~f:(Fmt.pr "ty: %a@." Pretty.print_ty);
  Hashtbl.find_or_add caisar_env.ls_of_caisar_op op ~default:(fun () ->
    let id = Ident.id_fresh "caisar_op" in
    let ls = Term.create_lsymbol id [] ty in
    Fmt.pr "ls: %a@." Pretty.print_ls ls;
    (* [find_or_add] itself stores the result of [default] under [op], so
       only the reverse mapping must be registered here. Inserting into the
       table being queried from inside [default] was redundant and relied on
       [find_or_add]'s implementation. *)
    Term.Hls.add caisar_env.caisar_op_of_ls ls op;
    ls)

(** [caisar_op_of_ls engine ls] returns the CAISAR operation registered for
    the logic symbol [ls] in the environment of [engine].
    @raise Not_found if no operation was registered for [ls]. *)
let caisar_op_of_ls engine ls =
  let { caisar_op_of_ls = table; _ } = CRE.user_env engine in
  Term.Hls.find table ls

(** [term_of_caisar_op engine op ty] is the constant term representing [op]:
    the application, to no arguments, of the symbol returned by
    [ls_of_caisar_op]. *)
let term_of_caisar_op engine op ty =
  let symbol = ls_of_caisar_op engine op ty in
  Term.t_app_infer symbol []

(** [caisar_env _env cwd] creates an empty CAISAR environment. Relative file
    names are resolved against [cwd]. The first argument is currently
    ignored. *)
let caisar_env _env cwd =
  let by_op = Hashtbl.Poly.create () in
  let by_ls = Term.Hls.create 10 in
  { cwd; caisar_op_of_ls = by_ls; ls_of_caisar_op = by_op }

(** [print_caisar_op fmt env] prints every (logic symbol, operation) pair
    registered in [env], with pairs separated by newlines. *)
let print_caisar_op fmt { caisar_op_of_ls = table; _ } =
  Pp.print_iter2 Term.Hls.iter Pp.newline Pp.comma Pretty.print_ls
    pp_caisar_op fmt table

(** Builtin symbols that the CAISAR reduction engine evaluates natively.

    Each entry of the list names a theory of the ["interpretation"] library
    (["Vector"], ["Tensor"], ["Classifier"], ["Dataset"]) together with the
    symbols of that theory implemented in OCaml. The third component is left
    empty for every theory.

    Each implementation receives the engine, the applied logic symbol [ls],
    the argument values [vl] and the expected result type [ty]. Opaque
    objects (datasets, classifiers, ...) are handed back to the engine as
    fresh constants built by [term_of_caisar_op].

    Argument shapes that match no case raise [Invalid_argument]. Shapes that
    match but denote an unsupported operation fail with [assert false].
    Looking up a symbol that is not a CAISAR constant with [caisar_op_of_ls]
    raises [Not_found]. Every builtin prints debug traces to stdout. *)
let builtin_caisar : caisar_env CRE.built_in_theories list =
  let error_message ls =
    Fmt.str "Invalid arguments for '%a'" Pretty.print_ls ls
  in
  (* Debug trace printed on entry to each builtin; [name] identifies it. *)
  let trace_call name ls ty =
    Fmt.pr "--@.%s: ls:%a , ty:%a@." name Pretty.print_ls ls
      Fmt.(option ~none:nop Pretty.print_ty)
      ty
  in
  (* Debug trace of the two argument terms a builtin is working on. *)
  let trace_terms t1 t2 =
    Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2
  in

  (* Vector *)
  (* [v[i]]: indexing a dataset with an integer literal materializes that
     row. Indexing a classifier application is rebuilt unchanged. *)
  let vget : _ CRE.builtin =
   fun engine ls vl ty ->
    trace_call "vget" ls ty;
    match vl with
    | [
     Term ({ t_node = Tapp (lsapp, _); _ } as t1);
     Term ({ t_node = Tconst (ConstInt i); _ } as t2);
    ] -> (
      trace_terms t1 t2;
      match caisar_op_of_ls engine lsapp with
      | Dataset (CSV csv) ->
        (* Row [i] of the CSV: the first column is parsed as an integer
           label, and the remaining columns are kept as the features. *)
        let row = List.nth_exn csv (Number.to_small_integer i) in
        let label, features =
          match row with
          | [] | [ _ ] -> assert false
          | label :: features -> (label, features)
        in
        (* The expected result type must be a binary type application
           (presumably the type of the returned (features, label) tuple).
           Its first argument types the features constant. *)
        let ty_features =
          match ty with
          | Some { ty_node = Tyapp (_, [ a; _ ]); _ } -> Some a
          | _ -> assert false
        in
        let t_features, t_label =
          ( term_of_caisar_op engine (Data features) ty_features,
            Term.t_int_const (BigInt.of_int (Int.of_string label)) )
        in
        Term (Term.t_tuple [ t_features; t_label ])
      | ClassifierApp (_, _) -> Term (Term.t_app_infer ls [ t1; t2 ])
      | Data _ | Classifier _ | VGet (_, _) | EqualShape _ | ValidIndex _ ->
        assert false)
    | [
     Term ({ t_node = Tapp (lsapp, _); _ } as t1);
     Term ({ t_node = Tvar _; _ } as t2);
    ] -> (
      trace_terms t1 t2;
      (* With a symbolic index, only classifier applications are
         accepted. *)
      match caisar_op_of_ls engine lsapp with
      | Dataset _ -> assert false
      | ClassifierApp (_, _) -> Term (Term.t_app_infer ls [ t1; t2 ])
      | Data _ | Classifier _ | VGet (_, _) | EqualShape _ | ValidIndex _ ->
        assert false)
    | _ -> invalid_arg (error_message ls)
  in
  (* [length v]: the number of lines of a dataset. *)
  let length : _ CRE.builtin =
   fun engine ls vl ty ->
    trace_call "length" ls ty;
    match vl with
    | [ Term { t_node = Tapp (ls, []); _ } ] -> (
      match caisar_op_of_ls engine ls with
      | Dataset (CSV csv) -> Int (BigInt.of_int (Csv.lines csv))
      | Data _ | Classifier _ | ClassifierApp _ | VGet _ | EqualShape _
      | ValidIndex _ ->
        assert false)
    (* NOTE(review): a two-argument call builds a [VGet] constant. This looks
       copied from [vget]; confirm it is reachable and intended. *)
    | [ Term t1; Term ({ t_node = Tvar _; _ } as t2) ] ->
      trace_terms t1 t2;
      Term (term_of_caisar_op engine (VGet (t1, t2)) ty)
    | _ -> invalid_arg (error_message ls)
  in

  (* Tensor *)
  (* The two builtins below are not registered: their entries in the
     "Tensor" theory at the end are commented out. *)
  let _valid_index : _ CRE.builtin =
   fun engine ls vl ty ->
    trace_call "valid_index" ls ty;
    match vl with
    (* A [[Tvar; Tvar]] argument list is a special case of this pattern, so
       it needs no separate alternative. *)
    | [ Term t1; Term ({ t_node = Tvar _; _ } as t2) ] ->
      trace_terms t1 t2;
      Term (term_of_caisar_op engine (ValidIndex (t1, t2)) ty)
      (* Term Term.t_true *)
    | _ -> invalid_arg (error_message ls)
  in
  let _equal_shape : _ CRE.builtin =
   fun engine ls vl ty ->
    trace_call "equal_shape" ls ty;
    match vl with
    | [
     Term ({ t_node = Tvar _; _ } as t1); Term ({ t_node = Tvar _; _ } as t2);
    ] ->
      trace_terms t1 t2;
      Term (Term.t_app_infer ls [ t1; t2 ])
    | [ Term t1; Term t2 ] ->
      trace_terms t1 t2;
      Term (term_of_caisar_op engine (EqualShape (t1, t2)) ty)
      (* Term Term.t_true *)
    | _ -> invalid_arg (error_message ls)
  in

  (* Classifier *)
  (* [read_classifier name NNet]: a classifier file whose name is joined with
     the working directory. Only the [NNet] format symbol is accepted. *)
  let read_classifier : _ CRE.builtin =
   fun engine ls vl ty ->
    trace_call "read_classifier" ls ty;
    match vl with
    | [
     Term { t_node = Tconst (ConstStr classifier); _ };
     Term { t_node = Tapp ({ ls_name = { id_string = "NNet"; _ }; _ }, []); _ };
    ] ->
      let cwd = (CRE.user_env engine).cwd in
      let caisar_op =
        let filename = Caml.Filename.concat cwd classifier in
        Classifier filename
      in
      Term (term_of_caisar_op engine caisar_op ty)
    | _ -> invalid_arg (error_message ls)
  in
  (* [c @@ x]: classifier application. If [c] is a variable or an
     application, the application is rebuilt unchanged. Any other term is
     wrapped in a fresh [ClassifierApp] constant.
     NOTE(review): CAISAR constants are themselves nullary applications, so
     a [read_classifier] result takes the first branch. *)
  let apply_classifier : _ CRE.builtin =
   fun engine ls vl ty ->
    trace_call "apply_classifier" ls ty;
    match vl with
    | [ Term ({ t_node = Tvar _ | Tapp _; _ } as t1); Term t2 ] ->
      trace_terms t1 t2;
      Term (Term.t_app_infer ls [ t1; t2 ])
    | [ Term t1; Term t2 ] ->
      trace_terms t1 t2;
      Term (term_of_caisar_op engine (ClassifierApp (t1, t2)) ty)
    | _ -> invalid_arg (error_message ls)
  in

  (* Dataset *)
  (* [read_dataset name CSV]: eagerly loads the CSV file whose name is joined
     with the working directory. *)
  let read_dataset : _ CRE.builtin =
   fun engine ls vl ty ->
    trace_call "read_dataset" ls ty;
    match vl with
    | [
     Term { t_node = Tconst (ConstStr dataset); _ };
     Term { t_node = Tapp ({ ls_name = { id_string = "CSV"; _ }; _ }, []); _ };
    ] ->
      let { cwd; _ } = CRE.user_env engine in
      let caisar_op =
        let filename = Caml.Filename.concat cwd dataset in
        let dataset = CSV (Csv.load filename) in
        Dataset dataset
      in
      Term (term_of_caisar_op engine caisar_op ty)
    | _ -> invalid_arg (error_message ls)
  in
  [
    ( [ "interpretation" ],
      "Vector",
      [],
      [ (Ident.op_get "" (* ([]) *), None, vget); ("length", None, length) ] );
    ( [ "interpretation" ],
      "Tensor",
      [],
      [ (* ("valid_index", None, valid_index); *)
        (* ("equal_shape", None, equal_shape); *) ] );
    ( [ "interpretation" ],
      "Classifier",
      [],
      [
        ("read_classifier", None, read_classifier);
        (Ident.op_infix "@@", None, apply_classifier);
      ] );
    ( [ "interpretation" ],
      "Dataset",
      [],
      [ ("read_dataset", None, read_dataset) ] );
  ]

(** [interpret_task ~cwd env task] normalizes the goal of [task] with the
    CAISAR reduction engine. Definitions and builtins are unfolded, with a
    limit of 1000 steps, and relative file names are resolved against [cwd].
    It prints the goal before and after normalization, followed by the
    CAISAR operations registered along the way. *)
let interpret_task ~cwd env task =
  let known_map = Task.task_known task in
  let user_env = caisar_env env cwd in
  let engine =
    let params =
      {
        CRE.compute_defs = true;
        compute_builtin = true;
        compute_def_set = Term.Sls.empty;
        compute_max_quantifier_domain = Int.max_value;
      }
    in
    CRE.create params env known_map user_env builtin_caisar
  in
  let goal = Task.task_goal_fmla task in
  Fmt.pr "TERM: %a@." Pretty.print_term goal;
  let reduced = CRE.normalize ~limit:1000 engine Term.Mvs.empty goal in
  Fmt.pr "%a : %a@.%a@." Pretty.print_pr (Task.task_goal task)
    Pretty.print_term reduced print_caisar_op user_env

(** [interpret ?debug ?format ~loadpath file] opens [file] with
    [Verification.open_file] and runs [interpret_task] on every task
    obtained by splitting each theory it contains. Relative paths are
    resolved against the directory of [file]. When reading from stdin or
    JSON, the process's current directory is used instead. *)
let interpret ?debug ?format ~loadpath file =
  let cwd =
    match file with
    | Verification.File.File s -> Caml.Filename.dirname s
    | Verification.File.Stdin -> Unix.getcwd ()
    | Verification.File.JSON _ -> Unix.getcwd () (* todo *)
  in
  let env, _config, theories =
    Verification.open_file ?debug ?format ~loadpath file
  in
  let interpret_theory _name theory =
    List.iter (Task.split_theory theory None None) ~f:(interpret_task ~cwd env)
  in
  Wstdlib.Mstr.iter interpret_theory theories