From 94016493e0d8d514b3da70e382d91713ecb692e3 Mon Sep 17 00:00:00 2001 From: Michele Alberti <michele.alberti@cea.fr> Date: Fri, 7 Apr 2023 11:31:09 +0200 Subject: [PATCH] [interpretation] Extension of current transformations. --- src/interpretation.ml | 2 + src/proof_strategy.ml | 33 ++++++++-- src/transformations/native_nn_prover.ml | 84 +++++++++++++++++++++++- src/transformations/native_nn_prover.mli | 3 +- src/transformations/utils.ml | 37 ++++++++--- src/transformations/utils.mli | 6 +- 6 files changed, 146 insertions(+), 19 deletions(-) diff --git a/src/interpretation.ml b/src/interpretation.ml index d812197..b5d97f8 100644 --- a/src/interpretation.ml +++ b/src/interpretation.ml @@ -386,6 +386,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = term (term_of_caisar_op engine caisar_op ty) | _ -> invalid_arg (error_message ls) in + [ ( [ "interpretation" ], "Vector", @@ -409,6 +410,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = [ ([ "read_classifier" ], None, read_classifier); ([ Ident.op_infix "@@" ], None, apply_classifier); + ([ Ident.op_infix "%%" ], None, apply_classifier); ] ); ( [ "interpretation" ], "Dataset", diff --git a/src/proof_strategy.ml b/src/proof_strategy.ml index 041bee1..f4bccd1 100644 --- a/src/proof_strategy.ml +++ b/src/proof_strategy.ml @@ -22,17 +22,36 @@ open Why3 -let do_apply_prover trans task = +let apply_classic_prover env task = let nb = Trans.apply Utils.count_nn_apply task in match nb with | 0 -> task - | 1 -> Trans.apply trans task + | 1 -> Trans.apply (Nn2smt.trans env) task | _ -> invalid_arg "Two or more neural network applications are not supported yet" -let apply_classic_prover env task = do_apply_prover (Nn2smt.trans env) task - let apply_native_nn_prover env task = - do_apply_prover - (Trans.seq [ Introduction.introduce_premises; Native_nn_prover.trans env ]) - task + let nb_nn_apply = Trans.apply Utils.count_nn_apply task in + let nb_nn_classifiers = Trans.apply Utils.count_nn_classifiers task in + match (nb_nn_apply, nb_nn_classifiers) with + | 0, 0 -> task + | 1, 0 -> + Trans.( + apply + (seq + [ + Introduction.introduce_premises; + Native_nn_prover.trans_nn_apply env; + ])) + task + | 0, 1 -> + Trans.( + apply + (seq + [ + Introduction.introduce_premises; + Native_nn_prover.trans_nn_classifier env; + ])) + task + | _ -> + invalid_arg "Two or more neural network applications are not supported yet" diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 3edcb00..bc705a6 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -92,5 +92,87 @@ let simplify_goal env input_variables = Task.add_decl acc decl) None -let trans env = +let trans_nn_apply env = + Trans.seq [ Trans.bind Utils.get_input_variables (simplify_goal env) ] + +(* Create logic symbols for output variables and simplify the formula. *) +let simplify_goal _env input_variables = + let rec aux hls (term : Term.term) = + match term.t_node with + | Term.Tapp + ( ls_vget, + [ + ({ + t_node = + Tapp + ( ls_apply_classifier, + [ + { t_node = Tapp (ls_nn_classifier, _); _ }; + { t_node = Tapp (ls_vector, _); _ }; + ] ); + _; + } as _t1); + ({ t_node = Tconst (ConstInt i); _ } as _t2); + ] ) + when String.equal ls_vget.ls_name.id_string (Ident.op_get "") + && String.equal ls_apply_classifier.ls_name.id_string + (Ident.op_infix "%%") -> ( + match + ( Language.lookup_nn_classifier ls_nn_classifier, + Language.lookup_vector ls_vector ) + with + | Some nn, Some _ -> + let index = Number.to_small_integer i in + let hout = + Hashtbl.update_and_return hls nn.nn_filename ~f:(fun hout -> + let ls = + let id = Ident.id_fresh "y" in + Term.create_fsymbol id [] nn.nn_ty_elt + in + match hout with + | None -> + let hout = Hashtbl.create (module Int) in + Hashtbl.add_exn hout ~key:index ~data:ls; + hout + | Some hout -> + Hashtbl.update hout index ~f:(fun lsout -> + match lsout with + | None -> + Hashtbl.add_exn hout ~key:index ~data:ls; + ls + | Some ls -> ls); + hout) + in + let ls_output = Hashtbl.find_exn hout index in + Term.fs_app ls_output [] nn.nn_ty_elt + | _ -> Term.t_map (aux hls) term) + | _ -> Term.t_map (aux hls) term + in + let htbl = Hashtbl.create (module String) in + Trans.fold + (fun task_hd acc -> + match task_hd.task_decl.td_node with + | Use _ | Clone _ | Meta _ -> Task.add_tdecl acc task_hd.task_decl + | Decl { d_node = Dparam ls; _ } -> ( + let task = Task.add_tdecl acc task_hd.task_decl in + match Term.Mls.find_opt ls input_variables with + | None -> task + | Some pos -> Task.add_meta task Utils.meta_input [ MAls ls; MAint pos ] + ) + | Decl decl -> + let decl = Decl.decl_map (fun term -> aux htbl term) decl in + let acc = + Hashtbl.fold htbl ~init:acc ~f:(fun ~key ~data acc -> + let acc = Task.add_meta acc Utils.meta_nn_filename [ MAstr key ] in + Hashtbl.fold data ~init:acc ~f:(fun ~key ~data acc -> + let acc = + let decl = Decl.create_param_decl data in + Task.add_decl acc decl + in + Task.add_meta acc Utils.meta_output [ MAls data; MAint key ])) + in + Task.add_decl acc decl) + None + +let trans_nn_classifier env = Trans.seq [ Trans.bind Utils.get_input_variables (simplify_goal env) ] diff --git a/src/transformations/native_nn_prover.mli b/src/transformations/native_nn_prover.mli index 936ff49..84ea944 100644 --- a/src/transformations/native_nn_prover.mli +++ b/src/transformations/native_nn_prover.mli @@ -22,4 +22,5 @@ open Why3 -val trans : Env.env -> Task.task Trans.trans +val trans_nn_apply : Env.env -> Task.task Trans.trans +val trans_nn_classifier : Env.env -> Task.task Trans.trans diff --git a/src/transformations/utils.ml b/src/transformations/utils.ml index 1cfbdda..a2279ec 100644 --- a/src/transformations/utils.ml +++ b/src/transformations/utils.ml @@ -35,21 +35,40 @@ let count_nn_apply = in Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) 0 +let count_nn_classifiers = + let rec aux acc (term : Term.term) = + let acc = Term.t_fold aux acc term in + match term.t_node with + | Term.Tapp (ls, _) -> ( + match Language.lookup_nn_classifier ls with + | None -> acc + | Some _ -> acc + 1) + | _ -> acc + in + Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) 0 + let get_input_variables = + let add i acc = function + | { Term.t_node = Tapp (vs, []); _ } -> Term.Mls.add vs i acc + | arg -> + invalid_arg + (Fmt.str "No direct variable in application: %a" Pretty.print_term arg) + in let rec aux acc (term : Term.term) = match term.t_node with + | Term.Tapp + ( { ls_name; _ }, + [ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] ) + when String.equal ls_name.id_string (Ident.op_infix "%%") -> ( + match (Language.lookup_nn_classifier ls1, Language.lookup_vector ls2) with + | Some { nn_inputs; _ }, Some n -> + assert (nn_inputs = n && n = List.length args); + List.foldi ~init:acc ~f:add args + | _ -> acc) | Term.Tapp (ls, args) -> ( match Language.lookup_loaded_nets ls with | None -> acc - | Some _ -> - let add i acc = function - | { Term.t_node = Tapp (vs, []); _ } -> Term.Mls.add vs i acc - | arg -> - invalid_arg - (Fmt.str "No direct variable in application: %a" Pretty.print_term - arg) - in - List.foldi ~init:acc ~f:add args) + | Some _ -> List.foldi ~init:acc ~f:add args) | _ -> Term.t_fold aux acc term in Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) Term.Mls.empty diff --git a/src/transformations/utils.mli b/src/transformations/utils.mli index 12b8c13..8d54c04 100644 --- a/src/transformations/utils.mli +++ b/src/transformations/utils.mli @@ -25,8 +25,12 @@ open Why3 val count_nn_apply : int Trans.trans (** Count the number of applications of [nn_apply]. *) +val count_nn_classifiers : int Trans.trans +(** Count the number of applications of a NN classifier. *) + val get_input_variables : int Term.Mls.t Trans.trans -(** Retrieve the input variables appearing as arguments of [nn_apply]. *) +(** Retrieve the input variables appearing as arguments of [nn_apply] or a NN + classifier. *) val meta_input : Theory.meta (** Indicate the input position. *) -- GitLab