From ac7d9e18ce4f86a37e2996a7eb744848f29bfc12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr> Date: Wed, 30 Nov 2022 11:23:10 +0100 Subject: [PATCH] Move the native_nn_prover transformation outside drivers. Make it into a new strategy module that should prepare the task before sending it to the provers. --- config/drivers/marabou.drv | 1 - config/drivers/pyrat.drv | 1 - src/main.ml | 1 - src/proof_strategy.ml | 35 ++++++++++++++++++++++++ src/proof_strategy.mli | 26 ++++++++++++++++++ src/transformations/native_nn_prover.ml | 7 +---- src/transformations/native_nn_prover.mli | 5 ++-- src/transformations/utils.ml | 30 ++++++++++++++++++-- src/transformations/utils.mli | 21 +++++++++++++- src/verification.ml | 5 ++-- 10 files changed, 116 insertions(+), 16 deletions(-) create mode 100644 src/proof_strategy.ml create mode 100644 src/proof_strategy.mli diff --git a/config/drivers/marabou.drv b/config/drivers/marabou.drv index dc29a063..25eb4ea3 100644 --- a/config/drivers/marabou.drv +++ b/config/drivers/marabou.drv @@ -34,7 +34,6 @@ transformation "inline_trivial" transformation "introduce_premises" transformation "eliminate_builtin" transformation "simplify_formula" -transformation "native_nn_prover" transformation "vars_on_lhs" theory BuiltIn diff --git a/config/drivers/pyrat.drv b/config/drivers/pyrat.drv index d9cc1c9e..26b23930 100644 --- a/config/drivers/pyrat.drv +++ b/config/drivers/pyrat.drv @@ -34,7 +34,6 @@ transformation "inline_trivial" transformation "introduce_premises" transformation "eliminate_builtin" transformation "simplify_formula" -transformation "native_nn_prover" theory BuiltIn syntax type int "int" diff --git a/src/main.ml b/src/main.ml index c59b9740..e5f656dd 100644 --- a/src/main.ml +++ b/src/main.ml @@ -26,7 +26,6 @@ open Cmdliner let caisar = "caisar" let () = - Native_nn_prover.init (); Nn2smt.init (); Vars_on_lhs.init () diff --git a/src/proof_strategy.ml b/src/proof_strategy.ml new file mode 100644 index 00000000..fc1556a2 --- /dev/null +++ b/src/proof_strategy.ml @@ -0,0 +1,35 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +open Why3 + +let apply_native_nn_prover env task = + let nb = Trans.apply Utils.count_nn_apply task in + match nb with + | 0 -> task + | 1 -> + Trans.( + apply + (seq [ Introduction.introduce_premises; Native_nn_prover.trans env ])) + task + | _ -> + invalid_arg "Two or more neural network applications are not supported yet" diff --git a/src/proof_strategy.mli b/src/proof_strategy.mli new file mode 100644 index 00000000..d8a52fde --- /dev/null +++ b/src/proof_strategy.mli @@ -0,0 +1,26 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +open Why3 + +val apply_native_nn_prover : Env.env -> Task.task -> Task.task +(** Detect and execute applications of neural networks. *) diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index a7353ebc..3edcb00a 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -92,10 +92,5 @@ let simplify_goal env input_variables = Task.add_decl acc decl) None -let native_nn_prover env = +let trans env = Trans.seq [ Trans.bind Utils.get_input_variables (simplify_goal env) ] - -let init () = - Trans.register_env_transform - ~desc:"Transformation for provers that support loading neural networks." - "native_nn_prover" native_nn_prover diff --git a/src/transformations/native_nn_prover.mli b/src/transformations/native_nn_prover.mli index 694097b3..936ff492 100644 --- a/src/transformations/native_nn_prover.mli +++ b/src/transformations/native_nn_prover.mli @@ -20,5 +20,6 @@ (* *) (**************************************************************************) -val init : unit -> unit -(** Register the transformation. *) +open Why3 + +val trans : Env.env -> Task.task Trans.trans diff --git a/src/transformations/utils.ml b/src/transformations/utils.ml index efe26076..1cfbdda2 100644 --- a/src/transformations/utils.ml +++ b/src/transformations/utils.ml @@ -2,13 +2,39 @@ (* *) (* This file is part of CAISAR. *) (* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) (**************************************************************************) open Why3 open Base -(* Retrieve the (input) variables appearing as arguments of a logic symbol - nn_apply. *) +let count_nn_apply = + let rec aux acc (term : Term.term) = + let acc = Term.t_fold aux acc term in + match term.t_node with + | Term.Tapp (ls, _) -> ( + match Language.lookup_loaded_nets ls with + | None -> acc + | Some _ -> acc + 1) + | _ -> acc + in + Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) 0 + let get_input_variables = let rec aux acc (term : Term.term) = match term.t_node with diff --git a/src/transformations/utils.mli b/src/transformations/utils.mli index ecfa4a03..12b8c134 100644 --- a/src/transformations/utils.mli +++ b/src/transformations/utils.mli @@ -2,12 +2,31 @@ (* *) (* This file is part of CAISAR. *) (* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) (**************************************************************************) open Why3 -open Base + +val count_nn_apply : int Trans.trans +(** Count the number of applications of [nn_apply]. *) val get_input_variables : int Term.Mls.t Trans.trans +(** Retrieve the input variables appearing as arguments of [nn_apply]. *) val meta_input : Theory.meta (** Indicate the input position. *) diff --git a/src/verification.ml b/src/verification.ml index fa183c52..bad75796 100644 --- a/src/verification.ml +++ b/src/verification.ml @@ -254,7 +254,8 @@ let answer_on_dataset limit config env config_prover driver dataset_csv task = let prover_answer = combine_prover_answers answers in (prover_answer, None) -let answer_generic limit config prover config_prover driver task = +let answer_generic limit env config prover config_prover driver task = + let task = Proof_strategy.apply_native_nn_prover env task in let task = Driver.prepare_task driver task in let nn_file = match Task.on_meta_excl Utils.meta_nn_filename task with @@ -286,7 +287,7 @@ let call_prover ~limit config env prover config_prover driver dataset_csv task = let dataset_csv = Option.value_exn dataset_csv in answer_on_dataset limit config env config_prover driver dataset_csv task | Marabou | Pyrat | CVC5 -> - answer_generic limit config prover config_prover driver task + answer_generic limit env config prover config_prover driver task in Logs.app (fun m -> m "@[Goal %a:@ %a%a@]" Pretty.print_pr (Task.task_goal task) -- GitLab