Skip to content
Snippets Groups Projects
Commit 94016493 authored by Michele Alberti's avatar Michele Alberti
Browse files

[interpretation] Extension of current transformations.

parent 4e558db3
No related branches found
No related tags found
No related merge requests found
......@@ -386,6 +386,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
term (term_of_caisar_op engine caisar_op ty)
| _ -> invalid_arg (error_message ls)
in
[
( [ "interpretation" ],
"Vector",
......@@ -409,6 +410,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
[
([ "read_classifier" ], None, read_classifier);
([ Ident.op_infix "@@" ], None, apply_classifier);
([ Ident.op_infix "%%" ], None, apply_classifier);
] );
( [ "interpretation" ],
"Dataset",
......
......@@ -22,17 +22,36 @@
open Why3
let do_apply_prover trans task =
let apply_classic_prover env task =
let nb = Trans.apply Utils.count_nn_apply task in
match nb with
| 0 -> task
| 1 -> Trans.apply trans task
| 1 -> Trans.apply (Nn2smt.trans env) task
| _ ->
invalid_arg "Two or more neural network applications are not supported yet"
let apply_classic_prover env task = do_apply_prover (Nn2smt.trans env) task
let apply_native_nn_prover env task =
do_apply_prover
(Trans.seq [ Introduction.introduce_premises; Native_nn_prover.trans env ])
task
let nb_nn_apply = Trans.apply Utils.count_nn_apply task in
let nb_nn_classifiers = Trans.apply Utils.count_nn_classifiers task in
match (nb_nn_apply, nb_nn_classifiers) with
| 0, 0 -> task
| 1, 0 ->
Trans.(
apply
(seq
[
Introduction.introduce_premises;
Native_nn_prover.trans_nn_apply env;
]))
task
| 0, 1 ->
Trans.(
apply
(seq
[
Introduction.introduce_premises;
Native_nn_prover.trans_nn_classifier env;
]))
task
| _ ->
invalid_arg "Two or more neural network applications are not supported yet"
......@@ -92,5 +92,87 @@ let simplify_goal env input_variables =
Task.add_decl acc decl)
None
let trans env =
let trans_nn_apply env =
Trans.seq [ Trans.bind Utils.get_input_variables (simplify_goal env) ]
(* Create logic symbols for output variables and simplify the formula. *)
let simplify_goal _env input_variables =
let rec aux hls (term : Term.term) =
match term.t_node with
| Term.Tapp
( ls_vget,
[
({
t_node =
Tapp
( ls_apply_classifier,
[
{ t_node = Tapp (ls_nn_classifier, _); _ };
{ t_node = Tapp (ls_vector, _); _ };
] );
_;
} as _t1);
({ t_node = Tconst (ConstInt i); _ } as _t2);
] )
when String.equal ls_vget.ls_name.id_string (Ident.op_get "")
&& String.equal ls_apply_classifier.ls_name.id_string
(Ident.op_infix "%%") -> (
match
( Language.lookup_nn_classifier ls_nn_classifier,
Language.lookup_vector ls_vector )
with
| Some nn, Some _ ->
let index = Number.to_small_integer i in
let hout =
Hashtbl.update_and_return hls nn.nn_filename ~f:(fun hout ->
let ls =
let id = Ident.id_fresh "y" in
Term.create_fsymbol id [] nn.nn_ty_elt
in
match hout with
| None ->
let hout = Hashtbl.create (module Int) in
Hashtbl.add_exn hout ~key:index ~data:ls;
hout
| Some hout ->
Hashtbl.update hout index ~f:(fun lsout ->
match lsout with
| None ->
Hashtbl.add_exn hout ~key:index ~data:ls;
ls
| Some ls -> ls);
hout)
in
let ls_output = Hashtbl.find_exn hout index in
Term.fs_app ls_output [] nn.nn_ty_elt
| _ -> Term.t_map (aux hls) term)
| _ -> Term.t_map (aux hls) term
in
let htbl = Hashtbl.create (module String) in
Trans.fold
(fun task_hd acc ->
match task_hd.task_decl.td_node with
| Use _ | Clone _ | Meta _ -> Task.add_tdecl acc task_hd.task_decl
| Decl { d_node = Dparam ls; _ } -> (
let task = Task.add_tdecl acc task_hd.task_decl in
match Term.Mls.find_opt ls input_variables with
| None -> task
| Some pos -> Task.add_meta task Utils.meta_input [ MAls ls; MAint pos ]
)
| Decl decl ->
let decl = Decl.decl_map (fun term -> aux htbl term) decl in
let acc =
Hashtbl.fold htbl ~init:acc ~f:(fun ~key ~data acc ->
let acc = Task.add_meta acc Utils.meta_nn_filename [ MAstr key ] in
Hashtbl.fold data ~init:acc ~f:(fun ~key ~data acc ->
let acc =
let decl = Decl.create_param_decl data in
Task.add_decl acc decl
in
Task.add_meta acc Utils.meta_output [ MAls data; MAint key ]))
in
Task.add_decl acc decl)
None
let trans_nn_classifier env =
Trans.seq [ Trans.bind Utils.get_input_variables (simplify_goal env) ]
......@@ -22,4 +22,5 @@
open Why3
val trans : Env.env -> Task.task Trans.trans
val trans_nn_apply : Env.env -> Task.task Trans.trans
val trans_nn_classifier : Env.env -> Task.task Trans.trans
......@@ -35,21 +35,40 @@ let count_nn_apply =
in
Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) 0
let count_nn_classifiers =
let rec aux acc (term : Term.term) =
let acc = Term.t_fold aux acc term in
match term.t_node with
| Term.Tapp (ls, _) -> (
match Language.lookup_nn_classifier ls with
| None -> acc
| Some _ -> acc + 1)
| _ -> acc
in
Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) 0
let get_input_variables =
let add i acc = function
| { Term.t_node = Tapp (vs, []); _ } -> Term.Mls.add vs i acc
| arg ->
invalid_arg
(Fmt.str "No direct variable in application: %a" Pretty.print_term arg)
in
let rec aux acc (term : Term.term) =
match term.t_node with
| Term.Tapp
( { ls_name; _ },
[ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] )
when String.equal ls_name.id_string (Ident.op_infix "%%") -> (
match (Language.lookup_nn_classifier ls1, Language.lookup_vector ls2) with
| Some { nn_inputs; _ }, Some n ->
assert (nn_inputs = n && n = List.length args);
List.foldi ~init:acc ~f:add args
| _ -> acc)
| Term.Tapp (ls, args) -> (
match Language.lookup_loaded_nets ls with
| None -> acc
| Some _ ->
let add i acc = function
| { Term.t_node = Tapp (vs, []); _ } -> Term.Mls.add vs i acc
| arg ->
invalid_arg
(Fmt.str "No direct variable in application: %a" Pretty.print_term
arg)
in
List.foldi ~init:acc ~f:add args)
| Some _ -> List.foldi ~init:acc ~f:add args)
| _ -> Term.t_fold aux acc term
in
Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) Term.Mls.empty
......
......@@ -25,8 +25,12 @@ open Why3
val count_nn_apply : int Trans.trans
(** Count the number of applications of [nn_apply]. *)
val count_nn_classifiers : int Trans.trans
(** Count the number of applications of a NN classifier. *)
val get_input_variables : int Term.Mls.t Trans.trans
(** Retrieve the input variables appearing as arguments of [nn_apply]. *)
(** Retrieve the input variables appearing as arguments of [nn_apply] or a NN
classifier. *)
val meta_input : Theory.meta
(** Indicate the input position. *)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment