diff --git a/src/proof_strategy.ml b/src/proof_strategy.ml index 1acd424808fd1da9efbd4eb22708b341d87a19d1..e0fa3156946e7f63f4c76878db89e6f34a66660a 100644 --- a/src/proof_strategy.ml +++ b/src/proof_strategy.ml @@ -53,8 +53,7 @@ let apply_classic_prover env task = let apply_native_nn_prover task = let lookup = Language.lookup_nn in let trans = - Trans.seq - [ Introduction.introduce_premises; Native_nn_prover.trans_nn_application ] + Trans.seq [ Introduction.introduce_premises; Native_nn_prover.trans ] in let tasks = Trans.apply Split_goal.split_goal_full task in do_apply_prover ~lookup ~trans tasks diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 6af0d9279e060fef414f8e3038c76f93854e6ae6..a5a1d23f6ae2c0d6ff638b9a853f9100a716c101 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -23,11 +23,14 @@ open Why3 open Base +(* Collects in a map the input variables, already declared in a task, and their + indices of appearance inside respective input vectors. Such collecting + process is memoized wrt lsymbols corresponding to input vectors. *) let collect_input_vars = let hls = Term.Hls.create 13 in let add index mls = function | { Term.t_node = Tapp (ls, []); _ } -> Term.Mls.add ls index mls - | t -> failwith (Fmt.str "No input variable: %a" Pretty.print_term t) + | t -> failwith (Fmt.str "Not an input variable: %a" Pretty.print_term t) in let rec do_collect mls (term : Term.term) = match term.t_node with @@ -51,6 +54,12 @@ let collect_input_vars = (fun decl mls -> Decl.decl_fold do_collect mls decl) Term.Mls.empty +(* Creates a list of pairs made of output variables and respective indices in + the list, for each neural network application to an input vector appearing in + a task. Such a list stands for the resulting output vector of a neural + network application to an input vector (ie, something of the form: nn@@v). + The creation process is memoized wrt terms corresponding to neural network + applications to input vectors. *) let create_output_vars = let rec do_create mt (term : Term.term) = match term.t_node with @@ -74,6 +83,12 @@ let create_output_vars = (fun decl mt -> Decl.decl_fold do_create mt decl) Term.Mterm.empty +(* Simplifies a task goal exhibiting a vector selection on a neural network + application to an input vector (ie, (nn@@v)[_]) by the corresponding output + variable. Morevoer, each input variable declaration is annotated with a meta + that describes the respective index in the input vector. Ouput variables are + all declared, each with a meta that describes the respective index in the + output vector. *) let simplify_nn_application input_vars output_vars = let rec do_simplify (term : Term.term) = match term.t_node with @@ -98,11 +113,9 @@ let simplify_nn_application input_vars output_vars = (fun task_hd task -> match task_hd.task_decl.td_node with | Decl { d_node = Dparam ls; _ } -> - (* Add the meta first, then the actual input variable declaration. This - is mandatory for allowing some printers to properly work. Such - printers need to collect metas and print the actual declarations if a - corresponding meta has been already collected. Hence metas must - appear before corresponding declarations. *) + (* Add meta for neural network and input variable declarations. Note + that each meta needs to appear before the corresponding declaration + in order to be leveraged by prover printers. *) let task = match (Term.Mls.find_opt ls input_vars, Language.lookup_nn ls) with | None, None -> task @@ -118,10 +131,13 @@ let simplify_nn_application input_vars output_vars = | Decl ({ d_node = Dprop (Pgoal, _, _); _ } as decl) -> let decl = Decl.decl_map do_simplify decl in let task = + (* Output variables are not declared yet in the task as they are + created on the fly for each (different) neural network application + on an input vector. We add here their declarations in the task. *) Term.Mterm.fold (fun _t output_vars task -> - (* Add the meta first, then the actual output variable - declaration. Same reason as for input variable declarations. *) + (* Again, for each output variable, add the meta first, then its + actual declaration. *) List.fold output_vars ~init:task ~f:(fun task (index, output_var) -> let task = @@ -137,7 +153,7 @@ let simplify_nn_application input_vars output_vars = Task.add_tdecl task task_hd.task_decl) None -let trans_nn_application = +let trans = Trans.bind collect_input_vars (fun input_vars -> Trans.bind create_output_vars (fun output_vars -> simplify_nn_application input_vars output_vars)) diff --git a/src/transformations/native_nn_prover.mli b/src/transformations/native_nn_prover.mli index 090612fbdd8423b21677a1268c08ce7398cff807..2978f4fd9e12d191801991e334cd121fd8c69530 100644 --- a/src/transformations/native_nn_prover.mli +++ b/src/transformations/native_nn_prover.mli @@ -20,4 +20,4 @@ (* *) (**************************************************************************) -val trans_nn_application : Why3.Task.task Why3.Trans.trans +val trans : Why3.Task.task Why3.Trans.trans