From 4a6046daaa71bb2c4a8f1a32478980afa3be64f1 Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Wed, 31 May 2023 14:28:24 +0200
Subject: [PATCH] [trans] Add comments to and rename the native nn prover transformation.

---
 src/proof_strategy.ml                    |  3 +--
 src/transformations/native_nn_prover.ml  | 34 +++++++++++++++++-------
 src/transformations/native_nn_prover.mli |  2 +-
 3 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/src/proof_strategy.ml b/src/proof_strategy.ml
index 1acd424..e0fa315 100644
--- a/src/proof_strategy.ml
+++ b/src/proof_strategy.ml
@@ -53,8 +53,7 @@ let apply_classic_prover env task =
 let apply_native_nn_prover task =
   let lookup = Language.lookup_nn in
   let trans =
-    Trans.seq
-      [ Introduction.introduce_premises; Native_nn_prover.trans_nn_application ]
+    Trans.seq [ Introduction.introduce_premises; Native_nn_prover.trans ]
   in
   let tasks = Trans.apply Split_goal.split_goal_full task in
   do_apply_prover ~lookup ~trans tasks
diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index 6af0d92..a5a1d23 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -23,11 +23,14 @@
 open Why3
 open Base
 
+(* Collects in a map the input variables already declared in a task, and their
+   indices of appearance inside the respective input vectors. The collection
+   process is memoized wrt the lsymbols corresponding to input vectors. *)
 let collect_input_vars =
   let hls = Term.Hls.create 13 in
   let add index mls = function
     | { Term.t_node = Tapp (ls, []); _ } -> Term.Mls.add ls index mls
-    | t -> failwith (Fmt.str "No input variable: %a" Pretty.print_term t)
+    | t -> failwith (Fmt.str "Not an input variable: %a" Pretty.print_term t)
   in
   let rec do_collect mls (term : Term.term) =
     match term.t_node with
@@ -51,6 +54,12 @@ let collect_input_vars =
     (fun decl mls -> Decl.decl_fold do_collect mls decl)
     Term.Mls.empty
 
+(* Creates, for each neural network application to an input vector appearing
+   in a task, a list of pairs made of output variables and their respective
+   indices in the list. Such a list stands for the output vector resulting
+   from the application of a neural network to an input vector (i.e.,
+   something of the form nn@@v). The creation process is memoized wrt the
+   terms corresponding to neural network applications to input vectors. *)
 let create_output_vars =
   let rec do_create mt (term : Term.term) =
     match term.t_node with
@@ -74,6 +83,12 @@ let create_output_vars =
     (fun decl mt -> Decl.decl_fold do_create mt decl)
     Term.Mterm.empty
 
+(* Simplifies a task goal by replacing each vector selection on a neural
+   network application to an input vector (i.e., (nn@@v)[_]) with the
+   corresponding output variable. Moreover, each input variable declaration is
+   annotated with a meta that describes the respective index in the input
+   vector. Output variables are all declared, each with a meta that describes
+   the respective index in the output vector. *)
 let simplify_nn_application input_vars output_vars =
   let rec do_simplify (term : Term.term) =
     match term.t_node with
@@ -98,11 +113,9 @@ let simplify_nn_application input_vars output_vars =
     (fun task_hd task ->
       match task_hd.task_decl.td_node with
       | Decl { d_node = Dparam ls; _ } ->
-        (* Add the meta first, then the actual input variable declaration. This
-           is mandatory for allowing some printers to properly work. Such
-           printers need to collect metas and print the actual declarations if a
-           corresponding meta has been already collected. Hence metas must
-           appear before corresponding declarations. *)
+        (* Add a meta for each neural network and input variable declaration.
+           Note that each meta needs to appear before the corresponding
+           declaration in order to be leveraged by prover printers. *)
         let task =
           match (Term.Mls.find_opt ls input_vars, Language.lookup_nn ls) with
           | None, None -> task
@@ -118,10 +131,13 @@ let simplify_nn_application input_vars output_vars =
       | Decl ({ d_node = Dprop (Pgoal, _, _); _ } as decl) ->
         let decl = Decl.decl_map do_simplify decl in
         let task =
+          (* Output variables are not yet declared in the task, as they are
+             created on the fly for each (different) neural network
+             application to an input vector. We add their declarations here. *)
           Term.Mterm.fold
             (fun _t output_vars task ->
-              (* Add the meta first, then the actual output variable
-                 declaration. Same reason as for input variable declarations. *)
+              (* Again, for each output variable, add the meta first, then its
+                 actual declaration. *)
               List.fold output_vars ~init:task
                 ~f:(fun task (index, output_var) ->
                   let task =
@@ -137,7 +153,7 @@ let simplify_nn_application input_vars output_vars =
       Task.add_tdecl task task_hd.task_decl)
     None
 
-let trans_nn_application =
+let trans =
   Trans.bind collect_input_vars (fun input_vars ->
     Trans.bind create_output_vars (fun output_vars ->
       simplify_nn_application input_vars output_vars))
diff --git a/src/transformations/native_nn_prover.mli b/src/transformations/native_nn_prover.mli
index 090612f..2978f4f 100644
--- a/src/transformations/native_nn_prover.mli
+++ b/src/transformations/native_nn_prover.mli
@@ -20,4 +20,4 @@
 (*                                                                        *)
 (**************************************************************************)
 
-val trans_nn_application : Why3.Task.task Why3.Trans.trans
+val trans : Why3.Task.task Why3.Trans.trans
--
GitLab
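
Usage note: after this rename, the transformation is exposed as Native_nn_prover.trans
and composed with premise introduction as in proof_strategy.ml above. The following is
a minimal sketch of applying it to a single task outside of do_apply_prover; the
apply_native_nn_trans name and the standalone setting are assumptions, while Trans.seq,
Trans.apply, Introduction.introduce_premises and Split_goal.split_goal_full are the
Why3 entry points already used in the patch.

  open Why3

  (* Hypothetical sketch: split the goal, then run premise introduction
     followed by the native NN prover transformation on each resulting task. *)
  let apply_native_nn_trans (task : Task.task) : Task.task list =
    let trans =
      Trans.seq [ Introduction.introduce_premises; Native_nn_prover.trans ]
    in
    let tasks = Trans.apply Split_goal.split_goal_full task in
    List.map (Trans.apply trans) tasks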