(* interpretation.ml (11.13 KiB), authored by Michele Alberti *)
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(* Copyright (C) 2022 *)
(* CEA (Commissariat à l'énergie atomique et aux énergies *)
(* alternatives) *)
(* *)
(* You can redistribute it and/or modify it under the terms of the GNU *)
(* Lesser General Public License as published by the Free Software *)
(* Foundation, version 2.1. *)
(* *)
(* It is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
(* GNU Lesser General Public License for more details. *)
(* *)
(* See the GNU Lesser General Public License version 2.1 *)
(* for more details (enclosed in the file licenses/LGPLv2.1). *)
(* *)
(**************************************************************************)
module CRE = Reduction_engine (* Caisar Reduction Engine *)
open Why3
open Base
(* A dataset held in memory. The single constructor wraps a value of type
   Csv.t. The custom printer emits a fixed placeholder instead of the rows. *)
type dataset = CSV of Csv.t [@printer fun fmt _ -> Fmt.pf fmt "<csv>"]
[@@deriving show]
(* A classifier, represented by a plain string. In this file the only value
   ever built is a filename (see read_classifier). *)
type classifier = string [@@deriving show]
(* Operations that the interpretation hides behind fresh logic symbols (see
   ls_of_caisar_op). Constructors carrying Why3 terms use hand-written
   printers built on Pretty.print_term. *)
type caisar_op =
  (* A loaded dataset. *)
  | Dataset of dataset
  (* The feature cells of one dataset row, as built by the vector getter. *)
  | Data of string list
  (* A classifier, as built by read_classifier. *)
  | Classifier of classifier
  (* Application of a classifier term to an argument term. *)
  | ClassifierApp of Term.term * Term.term
      [@printer
        fun ppf (lhs, rhs) ->
          Fmt.pf ppf "%a@@%a" Pretty.print_term lhs Pretty.print_term rhs]
  (* Access into the first term at the index given by the second term. *)
  | VGet of Term.term * Term.term
      [@printer
        fun ppf (lhs, rhs) ->
          Fmt.pf ppf "%a[%a]" Pretty.print_term lhs Pretty.print_term rhs]
  (* Shape equality between two terms. *)
  | EqualShape of Term.term * Term.term
      [@printer
        fun ppf (lhs, rhs) ->
          Fmt.pf ppf "EqShape %a %a" Pretty.print_term lhs Pretty.print_term
            rhs]
  (* Index validity of one term with respect to another. *)
  | ValidIndex of Term.term * Term.term
      [@printer
        fun ppf (lhs, rhs) ->
          Fmt.pf ppf "ValidIdx %a %a" Pretty.print_term lhs Pretty.print_term
            rhs]
[@@deriving show]
(* User environment stored in the reduction engine and retrieved through
   CRE.user_env. The two tables are filled together by ls_of_caisar_op. *)
type caisar_env = {
(* Maps a generated logic symbol back to the operation it stands for. *)
caisar_op_of_ls : caisar_op Term.Hls.t;
(* Maps an operation to the logic symbol generated for it. *)
ls_of_caisar_op : (caisar_op, Term.lsymbol) Hashtbl.t;
(* Directory that classifier and dataset filenames are joined to. *)
cwd : string;
}
(* [ls_of_caisar_op engine op ty] returns the logic symbol that stands for
   [op] in the user environment of [engine].

   If [op] is already registered, its symbol is returned unchanged. Otherwise
   a fresh constant symbol (no argument, result type [ty]) is created, recorded
   in both tables of the environment and returned.

   The operation, the type and any freshly created symbol are traced on
   standard output. *)
let ls_of_caisar_op engine op ty =
  let caisar_env = CRE.user_env engine in
  Fmt.pr "ls_of_caisar_op: %a@." pp_caisar_op op;
  Option.iter ty ~f:(Fmt.pr "ty: %a@." Pretty.print_ty);
  Hashtbl.find_or_add caisar_env.ls_of_caisar_op op ~default:(fun () ->
    let id = Ident.id_fresh "caisar_op" in
    let ls = Term.create_lsymbol id [] ty in
    Fmt.pr "ls: %a@." Pretty.print_ls ls;
    (* find_or_add itself stores the returned symbol under [op], so only the
       reverse table needs an explicit insertion here. *)
    Term.Hls.add caisar_env.caisar_op_of_ls ls op;
    ls)
(* [caisar_op_of_ls engine ls] is the operation registered for the symbol [ls]
   in the user environment of [engine]. The lookup goes through Term.Hls.find,
   which presumably raises Not_found for an unregistered symbol (to be
   confirmed against the Why3 API). *)
let caisar_op_of_ls engine ls =
  let { caisar_op_of_ls = ops_by_symbol; _ } = CRE.user_env engine in
  Term.Hls.find ops_by_symbol ls
(* [term_of_caisar_op engine op ty] is the term made of the symbol associated
   with [op] (see ls_of_caisar_op) applied to no argument. *)
let term_of_caisar_op engine op ty =
  let symbol = ls_of_caisar_op engine op ty in
  Term.t_app_infer symbol []
(* [caisar_env _env cwd] builds a user environment with two empty tables and
   [cwd] as working directory. The first argument is ignored. *)
let caisar_env _env cwd =
  let ops_by_symbol = Term.Hls.create 10 in
  let symbols_by_op = Hashtbl.Poly.create () in
  { caisar_op_of_ls = ops_by_symbol; ls_of_caisar_op = symbols_by_op; cwd }
(* [print_caisar_op fmt caisar_env] prints on [fmt] every binding of the
   symbol-to-operation table of [caisar_env], each symbol followed by its
   operation. Pp.newline and Pp.comma are the two separators handed to
   Pp.print_iter2. *)
let print_caisar_op fmt caisar_env =
  let { caisar_op_of_ls = ops_by_symbol; _ } = caisar_env in
  Pp.print_iter2 Term.Hls.iter Pp.newline Pp.comma Pretty.print_ls pp_caisar_op
    fmt ops_by_symbol
(* Built-in theories handed to the reduction engine (see CRE.create in
   interpret_task). Each entry of the final list names a theory (Vector,
   Tensor, Classifier, Dataset) under the interpretation path and gives the
   OCaml implementation of some of its symbols.

   Every implementation receives the engine, the symbol being reduced (ls),
   the list of argument values (vl) and the optional result type (ty). It
   first prints a debug trace on standard output, then dispatches on the shape
   of vl and raises Invalid_argument when no pattern applies. Several branches
   deemed impossible end in assert false. *)
let builtin_caisar : caisar_env CRE.built_in_theories list =
(* Message used by every implementation when its arguments are rejected. *)
let error_message ls =
Fmt.str "Invalid arguments for '%a'" Pretty.print_ls ls
in
(* Vector *)
(* Vector access, registered below under Ident.op_get. Expects two arguments:
   a symbol application and an index that is either an integer constant or a
   variable. *)
let vget : _ CRE.builtin =
fun engine ls vl ty ->
Fmt.pr "--@.vget: ls:%a , ty:%a@." Pretty.print_ls ls
Fmt.(option ~none:nop Pretty.print_ty)
ty;
match vl with
| [
Term ({ t_node = Tapp (lsapp, _); _ } as t1);
Term ({ t_node = Tconst (ConstInt i); _ } as t2);
] -> (
Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
(* Dispatch on the operation registered for the head symbol of t1. *)
match caisar_op_of_ls engine lsapp with
| Dataset (CSV csv) ->
(* Select row number i; List.nth_exn raises if i is out of bounds. *)
let row = List.nth_exn csv (Number.to_small_integer i) in
(* The first cell is the label and the remaining cells are the features.
   A row with fewer than two cells fails the assertion. *)
let label, features =
match row with
| [] | [ _ ] -> assert false
| label :: features -> (label, features)
in
(* The result type must be a type application with exactly two arguments;
   the first one becomes the type of the features symbol. *)
let ty_features =
match ty with
| Some { ty_node = Tyapp (_, [ a; _ ]); _ } -> Some a
| _ -> assert false
in
(* The features become a Data operation symbol and the label is parsed as an
   integer constant; Int.of_string raises on a non-integer label. *)
let t_features, t_label =
( term_of_caisar_op engine (Data features) ty_features,
Term.t_int_const (BigInt.of_int (Int.of_string label)) )
in
Term (Term.t_tuple [ t_features; t_label ])
(* Indexing a classifier application is left as is: the same application of
   ls to the two terms is rebuilt. *)
| ClassifierApp (_, _) -> Term (Term.t_app_infer ls [ t1; t2 ])
| Data _ | Classifier _ | VGet (_, _) | EqualShape _ | ValidIndex _ ->
assert false)
| [
Term ({ t_node = Tapp (lsapp, _); _ } as t1);
Term ({ t_node = Tvar _; _ } as t2);
] -> (
Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
(* With a variable index, only a classifier application is accepted and the
   access is rebuilt unchanged; a dataset fails the assertion. *)
match caisar_op_of_ls engine lsapp with
| Dataset _ -> assert false
| ClassifierApp (_, _) -> Term (Term.t_app_infer ls [ t1; t2 ])
| Data _ | Classifier _ | VGet (_, _) | EqualShape _ | ValidIndex _ ->
assert false)
| _ -> invalid_arg (error_message ls)
in
(* Vector length, registered below under the name length. *)
let length : _ CRE.builtin =
fun engine ls vl ty ->
Fmt.pr "--@.length: ls:%a , ty:%a@." Pretty.print_ls ls
Fmt.(option ~none:nop Pretty.print_ty)
ty;
match vl with
(* A single constant symbol: the inner ls shadows the outer one inside this
   branch only. For a CSV dataset the result is Csv.lines of its content. *)
| [ Term { t_node = Tapp (ls, []); _ } ] -> (
match caisar_op_of_ls engine ls with
| Dataset (CSV csv) -> Int (BigInt.of_int (Csv.lines csv))
| Data _ | Classifier _ | ClassifierApp _ | VGet _ | EqualShape _
| ValidIndex _ ->
assert false)
(* NOTE(review): this branch takes two arguments and builds a VGet operation,
   which looks like a vector access rather than a length; confirm intent. *)
| [ Term t1; Term ({ t_node = Tvar _; _ } as t2) ] ->
Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
Term (term_of_caisar_op engine (VGet (t1, t2)) ty)
| _ -> invalid_arg (error_message ls)
in
(* Tensor *)
(* Not registered: the corresponding line in the Tensor entry below is
   commented out. Builds a ValidIndex operation symbol when the second
   argument is a variable. *)
let _valid_index : _ CRE.builtin =
fun engine ls vl ty ->
Fmt.pr "--@.valid_index: ls:%a , ty:%a@." Pretty.print_ls ls
Fmt.(option ~none:nop Pretty.print_ty)
ty;
match vl with
(* NOTE(review): the first alternative of this or-pattern is subsumed by the
   second one, which accepts any first term. *)
| [
Term ({ t_node = Tvar _; _ } as t1); Term ({ t_node = Tvar _; _ } as t2);
]
| [ Term t1; Term ({ t_node = Tvar _; _ } as t2) ] ->
Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
Term (term_of_caisar_op engine (ValidIndex (t1, t2)) ty)
(* Term Term.t_true *)
| _ -> invalid_arg (error_message ls)
in
(* Not registered either. With two variables the application of ls is rebuilt
   unchanged; any other pair of terms becomes an EqualShape operation
   symbol. *)
let _equal_shape : _ CRE.builtin =
fun engine ls vl ty ->
Fmt.pr "--@.equal_shape: ls:%a , ty:%a@." Pretty.print_ls ls
Fmt.(option ~none:nop Pretty.print_ty)
ty;
match vl with
| [
Term ({ t_node = Tvar _; _ } as t1); Term ({ t_node = Tvar _; _ } as t2);
] ->
Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
Term (Term.t_app_infer ls [ t1; t2 ])
| [ Term t1; Term t2 ] ->
Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
Term (term_of_caisar_op engine (EqualShape (t1, t2)) ty)
(* Term Term.t_true *)
| _ -> invalid_arg (error_message ls)
in
(* Classifier *)
(* Expects a string constant followed by the constant symbol named NNet. The
   string is joined to the working directory of the environment and wrapped
   in a Classifier operation; the file itself is not opened here. *)
let read_classifier : _ CRE.builtin =
fun engine ls vl ty ->
Fmt.pr "--@.read_classifier: ls:%a , ty:%a@." Pretty.print_ls ls
Fmt.(option ~none:nop Pretty.print_ty)
ty;
match vl with
| [
Term { t_node = Tconst (ConstStr classifier); _ };
Term { t_node = Tapp ({ ls_name = { id_string = "NNet"; _ }; _ }, []); _ };
] ->
let cwd = (CRE.user_env engine).cwd in
let caisar_op =
let filename = Caml.Filename.concat cwd classifier in
Classifier filename
in
Term (term_of_caisar_op engine caisar_op ty)
| _ -> invalid_arg (error_message ls)
in
(* Classifier application, registered below as the infix operator @@. The
   order of the cases matters: a variable or a symbol application as first
   argument leaves the application untouched, and only the remaining terms
   are turned into a ClassifierApp operation symbol. *)
let apply_classifier : _ CRE.builtin =
fun engine ls vl ty ->
Fmt.pr "--@.apply_classifier: ls:%a , ty:%a@." Pretty.print_ls ls
Fmt.(option ~none:nop Pretty.print_ty)
ty;
match vl with
| [ Term ({ t_node = Tvar _; _ } as t1); Term t2 ] ->
Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
Term (Term.t_app_infer ls [ t1; t2 ])
| [ Term ({ t_node = Tapp (_lsapp, _); _ } as t1); Term t2 ] ->
Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
Term (Term.t_app_infer ls [ t1; t2 ])
| [ Term t1; Term t2 ] ->
Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
Term (term_of_caisar_op engine (ClassifierApp (t1, t2)) ty)
| _ -> invalid_arg (error_message ls)
in
(* Dataset *)
(* Expects a string constant followed by the constant symbol named CSV. The
   string is joined to the working directory of the environment, the file is
   loaded with Csv.load and the content is wrapped in a Dataset operation. *)
let read_dataset : _ CRE.builtin =
fun engine ls vl ty ->
Fmt.pr "--@.read_dataset: ls:%a , ty:%a@." Pretty.print_ls ls
Fmt.(option ~none:nop Pretty.print_ty)
ty;
match vl with
| [
Term { t_node = Tconst (ConstStr dataset); _ };
Term { t_node = Tapp ({ ls_name = { id_string = "CSV"; _ }; _ }, []); _ };
] ->
let { cwd; _ } = CRE.user_env engine in
let caisar_op =
let filename = Caml.Filename.concat cwd dataset in
let dataset = CSV (Csv.load filename) in
Dataset dataset
in
Term (term_of_caisar_op engine caisar_op ty)
| _ -> invalid_arg (error_message ls)
in
(* Registration table. Each entry holds a path, a theory name, a list left
   empty here, and the (symbol name, None, implementation) triples. The
   Tensor entry registers nothing because both of its lines are commented
   out. *)
[
( [ "interpretation" ],
"Vector",
[],
[ (Ident.op_get "" (* ([]) *), None, vget); ("length", None, length) ] );
( [ "interpretation" ],
"Tensor",
[],
[ (* ("valid_index", None, valid_index); *)
(* ("equal_shape", None, equal_shape); *) ] );
( [ "interpretation" ],
"Classifier",
[],
[
("read_classifier", None, read_classifier);
(Ident.op_infix "@@", None, apply_classifier);
] );
( [ "interpretation" ],
"Dataset",
[],
[ ("read_dataset", None, read_dataset) ] );
]
(* [interpret_task ~cwd env task] normalizes the goal formula of [task] with
   the reduction engine and prints the outcome on standard output.

   A fresh user environment rooted at [cwd] is created for the task. The
   engine is set up with definition and built-in computation enabled and with
   the built-in theories of builtin_caisar. The goal is printed before
   normalization; afterwards the goal name, the normalized formula and the
   symbol-to-operation table are printed. Normalization runs with a limit of
   1000. *)
let interpret_task ~cwd env task =
  let known_map = Task.task_known task in
  let user_env = caisar_env env cwd in
  let engine =
    CRE.create
      {
        CRE.compute_defs = true;
        compute_builtin = true;
        compute_def_set = Term.Sls.empty;
        compute_max_quantifier_domain = Int.max_value;
      }
      env known_map user_env builtin_caisar
  in
  let goal = Task.task_goal_fmla task in
  Fmt.pr "TERM: %a@." Pretty.print_term goal;
  let normalized = CRE.normalize ~limit:1000 engine Term.Mvs.empty goal in
  let goal_name = Task.task_goal task in
  Fmt.pr "%a : %a@.%a@." Pretty.print_pr goal_name Pretty.print_term normalized
    print_caisar_op user_env
(* [interpret ?debug ?format ~loadpath file] opens [file] through
   Verification.open_file, splits every resulting theory into tasks and runs
   interpret_task on each of them.

   Relative classifier and dataset filenames are resolved against the
   directory of [file] when it is an actual file, and against the current
   working directory of the process otherwise. *)
let interpret ?debug ?format ~loadpath file =
  let cwd =
    match file with
    | Verification.File.File path -> Caml.Filename.dirname path
    | Verification.File.Stdin | Verification.File.JSON _ ->
      (* No file path is taken from these inputs. TODO: dedicated handling of
         the JSON case. *)
      Unix.getcwd ()
  in
  let env, _config, theories =
    Verification.open_file ?debug ?format ~loadpath file
  in
  let interpret_theory _name theory =
    Task.split_theory theory None None
    |> List.iter ~f:(fun task -> interpret_task ~cwd env task)
  in
  Wstdlib.Mstr.iter interpret_theory theories