Skip to content
Snippets Groups Projects
Commit 94016493 authored by Michele Alberti's avatar Michele Alberti
Browse files

[interpretation] Extension of current transformations.

parent 4e558db3
No related branches found
No related tags found
No related merge requests found
...@@ -386,6 +386,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = ...@@ -386,6 +386,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
term (term_of_caisar_op engine caisar_op ty) term (term_of_caisar_op engine caisar_op ty)
| _ -> invalid_arg (error_message ls) | _ -> invalid_arg (error_message ls)
in in
[ [
( [ "interpretation" ], ( [ "interpretation" ],
"Vector", "Vector",
...@@ -409,6 +410,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = ...@@ -409,6 +410,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
[ [
([ "read_classifier" ], None, read_classifier); ([ "read_classifier" ], None, read_classifier);
([ Ident.op_infix "@@" ], None, apply_classifier); ([ Ident.op_infix "@@" ], None, apply_classifier);
([ Ident.op_infix "%%" ], None, apply_classifier);
] ); ] );
( [ "interpretation" ], ( [ "interpretation" ],
"Dataset", "Dataset",
......
...@@ -22,17 +22,36 @@ ...@@ -22,17 +22,36 @@
open Why3 open Why3
let do_apply_prover trans task = let apply_classic_prover env task =
let nb = Trans.apply Utils.count_nn_apply task in let nb = Trans.apply Utils.count_nn_apply task in
match nb with match nb with
| 0 -> task | 0 -> task
| 1 -> Trans.apply trans task | 1 -> Trans.apply (Nn2smt.trans env) task
| _ -> | _ ->
invalid_arg "Two or more neural network applications are not supported yet" invalid_arg "Two or more neural network applications are not supported yet"
let apply_classic_prover env task = do_apply_prover (Nn2smt.trans env) task
let apply_native_nn_prover env task = let apply_native_nn_prover env task =
do_apply_prover let nb_nn_apply = Trans.apply Utils.count_nn_apply task in
(Trans.seq [ Introduction.introduce_premises; Native_nn_prover.trans env ]) let nb_nn_classifiers = Trans.apply Utils.count_nn_classifiers task in
task match (nb_nn_apply, nb_nn_classifiers) with
| 0, 0 -> task
| 1, 0 ->
Trans.(
apply
(seq
[
Introduction.introduce_premises;
Native_nn_prover.trans_nn_apply env;
]))
task
| 0, 1 ->
Trans.(
apply
(seq
[
Introduction.introduce_premises;
Native_nn_prover.trans_nn_classifier env;
]))
task
| _ ->
invalid_arg "Two or more neural network applications are not supported yet"
...@@ -92,5 +92,87 @@ let simplify_goal env input_variables = ...@@ -92,5 +92,87 @@ let simplify_goal env input_variables =
Task.add_decl acc decl) Task.add_decl acc decl)
None None
let trans env = let trans_nn_apply env =
Trans.seq [ Trans.bind Utils.get_input_variables (simplify_goal env) ]
(* Create logic symbols for output variables and simplify the formula. *)
let simplify_goal _env input_variables =
let rec aux hls (term : Term.term) =
match term.t_node with
| Term.Tapp
( ls_vget,
[
({
t_node =
Tapp
( ls_apply_classifier,
[
{ t_node = Tapp (ls_nn_classifier, _); _ };
{ t_node = Tapp (ls_vector, _); _ };
] );
_;
} as _t1);
({ t_node = Tconst (ConstInt i); _ } as _t2);
] )
when String.equal ls_vget.ls_name.id_string (Ident.op_get "")
&& String.equal ls_apply_classifier.ls_name.id_string
(Ident.op_infix "%%") -> (
match
( Language.lookup_nn_classifier ls_nn_classifier,
Language.lookup_vector ls_vector )
with
| Some nn, Some _ ->
let index = Number.to_small_integer i in
let hout =
Hashtbl.update_and_return hls nn.nn_filename ~f:(fun hout ->
let ls =
let id = Ident.id_fresh "y" in
Term.create_fsymbol id [] nn.nn_ty_elt
in
match hout with
| None ->
let hout = Hashtbl.create (module Int) in
Hashtbl.add_exn hout ~key:index ~data:ls;
hout
| Some hout ->
Hashtbl.update hout index ~f:(fun lsout ->
match lsout with
| None ->
Hashtbl.add_exn hout ~key:index ~data:ls;
ls
| Some ls -> ls);
hout)
in
let ls_output = Hashtbl.find_exn hout index in
Term.fs_app ls_output [] nn.nn_ty_elt
| _ -> Term.t_map (aux hls) term)
| _ -> Term.t_map (aux hls) term
in
let htbl = Hashtbl.create (module String) in
Trans.fold
(fun task_hd acc ->
match task_hd.task_decl.td_node with
| Use _ | Clone _ | Meta _ -> Task.add_tdecl acc task_hd.task_decl
| Decl { d_node = Dparam ls; _ } -> (
let task = Task.add_tdecl acc task_hd.task_decl in
match Term.Mls.find_opt ls input_variables with
| None -> task
| Some pos -> Task.add_meta task Utils.meta_input [ MAls ls; MAint pos ]
)
| Decl decl ->
let decl = Decl.decl_map (fun term -> aux htbl term) decl in
let acc =
Hashtbl.fold htbl ~init:acc ~f:(fun ~key ~data acc ->
let acc = Task.add_meta acc Utils.meta_nn_filename [ MAstr key ] in
Hashtbl.fold data ~init:acc ~f:(fun ~key ~data acc ->
let acc =
let decl = Decl.create_param_decl data in
Task.add_decl acc decl
in
Task.add_meta acc Utils.meta_output [ MAls data; MAint key ]))
in
Task.add_decl acc decl)
None
let trans_nn_classifier env =
Trans.seq [ Trans.bind Utils.get_input_variables (simplify_goal env) ] Trans.seq [ Trans.bind Utils.get_input_variables (simplify_goal env) ]
...@@ -22,4 +22,5 @@ ...@@ -22,4 +22,5 @@
open Why3 open Why3
val trans : Env.env -> Task.task Trans.trans val trans_nn_apply : Env.env -> Task.task Trans.trans
val trans_nn_classifier : Env.env -> Task.task Trans.trans
...@@ -35,21 +35,40 @@ let count_nn_apply = ...@@ -35,21 +35,40 @@ let count_nn_apply =
in in
Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) 0 Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) 0
let count_nn_classifiers =
let rec aux acc (term : Term.term) =
let acc = Term.t_fold aux acc term in
match term.t_node with
| Term.Tapp (ls, _) -> (
match Language.lookup_nn_classifier ls with
| None -> acc
| Some _ -> acc + 1)
| _ -> acc
in
Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) 0
let get_input_variables = let get_input_variables =
let add i acc = function
| { Term.t_node = Tapp (vs, []); _ } -> Term.Mls.add vs i acc
| arg ->
invalid_arg
(Fmt.str "No direct variable in application: %a" Pretty.print_term arg)
in
let rec aux acc (term : Term.term) = let rec aux acc (term : Term.term) =
match term.t_node with match term.t_node with
| Term.Tapp
( { ls_name; _ },
[ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] )
when String.equal ls_name.id_string (Ident.op_infix "%%") -> (
match (Language.lookup_nn_classifier ls1, Language.lookup_vector ls2) with
| Some { nn_inputs; _ }, Some n ->
assert (nn_inputs = n && n = List.length args);
List.foldi ~init:acc ~f:add args
| _ -> acc)
| Term.Tapp (ls, args) -> ( | Term.Tapp (ls, args) -> (
match Language.lookup_loaded_nets ls with match Language.lookup_loaded_nets ls with
| None -> acc | None -> acc
| Some _ -> | Some _ -> List.foldi ~init:acc ~f:add args)
let add i acc = function
| { Term.t_node = Tapp (vs, []); _ } -> Term.Mls.add vs i acc
| arg ->
invalid_arg
(Fmt.str "No direct variable in application: %a" Pretty.print_term
arg)
in
List.foldi ~init:acc ~f:add args)
| _ -> Term.t_fold aux acc term | _ -> Term.t_fold aux acc term
in in
Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) Term.Mls.empty Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) Term.Mls.empty
......
...@@ -25,8 +25,12 @@ open Why3 ...@@ -25,8 +25,12 @@ open Why3
val count_nn_apply : int Trans.trans val count_nn_apply : int Trans.trans
(** Count the number of applications of [nn_apply]. *) (** Count the number of applications of [nn_apply]. *)
val count_nn_classifiers : int Trans.trans
(** Count the number of applications of a NN classifier. *)
val get_input_variables : int Term.Mls.t Trans.trans val get_input_variables : int Term.Mls.t Trans.trans
(** Retrieve the input variables appearing as arguments of [nn_apply]. *) (** Retrieve the input variables appearing as arguments of [nn_apply] or a NN
classifier. *)
val meta_input : Theory.meta val meta_input : Theory.meta
(** Indicate the input position. *) (** Indicate the input position. *)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment