diff --git a/config/drivers/marabou.drv b/config/drivers/marabou.drv index dc29a0636d0043008d82ccfd6e3b676fbc91f8b6..25eb4ea309ddb3dec339be83a97040e6e4370344 100644 --- a/config/drivers/marabou.drv +++ b/config/drivers/marabou.drv @@ -34,7 +34,6 @@ transformation "inline_trivial" transformation "introduce_premises" transformation "eliminate_builtin" transformation "simplify_formula" -transformation "native_nn_prover" transformation "vars_on_lhs" theory BuiltIn diff --git a/config/drivers/pyrat.drv b/config/drivers/pyrat.drv index d9cc1c9e43d421d4d03bcfb76e04aa68ebee8da6..26b23930d2ba08100aa90ae37573a11620f5a203 100644 --- a/config/drivers/pyrat.drv +++ b/config/drivers/pyrat.drv @@ -34,7 +34,6 @@ transformation "inline_trivial" transformation "introduce_premises" transformation "eliminate_builtin" transformation "simplify_formula" -transformation "native_nn_prover" theory BuiltIn syntax type int "int" diff --git a/src/main.ml b/src/main.ml index c59b9740df57988f8f391843a1640aab3441af4f..e5f656ddeb05c85965669c6fe44ac19740a4a9cc 100644 --- a/src/main.ml +++ b/src/main.ml @@ -26,7 +26,6 @@ open Cmdliner let caisar = "caisar" let () = - Native_nn_prover.init (); Nn2smt.init (); Vars_on_lhs.init () diff --git a/src/proof_strategy.ml b/src/proof_strategy.ml new file mode 100644 index 0000000000000000000000000000000000000000..fc1556a281d01d00a09d1608fae82c5fba0367c3 --- /dev/null +++ b/src/proof_strategy.ml @@ -0,0 +1,35 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +open Why3 + +let apply_native_nn_prover env task = + let nb = Trans.apply Utils.count_nn_apply task in + match nb with + | 0 -> task + | 1 -> + Trans.( + apply + (seq [ Introduction.introduce_premises; Native_nn_prover.trans env ])) + task + | _ -> + invalid_arg "Two or more neural network applications are not supported yet" diff --git a/src/proof_strategy.mli b/src/proof_strategy.mli new file mode 100644 index 0000000000000000000000000000000000000000..d8a52fded3dab74698e23489ebe35b7d058438d1 --- /dev/null +++ b/src/proof_strategy.mli @@ -0,0 +1,26 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +open Why3 + +val apply_native_nn_prover : Env.env -> Task.task -> Task.task +(** Detect and execute applications of neural networks. *) diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index a7353ebc2b240f579c5f812de60bd86f54a0830c..3edcb00a2093e13e255339d129d018387157ccea 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -92,10 +92,5 @@ let simplify_goal env input_variables = Task.add_decl acc decl) None -let native_nn_prover env = +let trans env = Trans.seq [ Trans.bind Utils.get_input_variables (simplify_goal env) ] - -let init () = - Trans.register_env_transform - ~desc:"Transformation for provers that support loading neural networks." - "native_nn_prover" native_nn_prover diff --git a/src/transformations/native_nn_prover.mli b/src/transformations/native_nn_prover.mli index 694097b3b0dcc9bef16af63c6f16994dd897e63e..936ff492295cfd2682bca8c47d65365add8fd67e 100644 --- a/src/transformations/native_nn_prover.mli +++ b/src/transformations/native_nn_prover.mli @@ -20,5 +20,6 @@ (* *) (**************************************************************************) -val init : unit -> unit -(** Register the transformation. *) +open Why3 + +val trans : Env.env -> Task.task Trans.trans diff --git a/src/transformations/utils.ml b/src/transformations/utils.ml index efe2607644c792272993476b3248353ba9f69583..1cfbdda2bbcec84da22d9b757b02f1d0106cb395 100644 --- a/src/transformations/utils.ml +++ b/src/transformations/utils.ml @@ -2,13 +2,39 @@ (* *) (* This file is part of CAISAR. *) (* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) (**************************************************************************) open Why3 open Base -(* Retrieve the (input) variables appearing as arguments of a logic symbol - nn_apply. *) +let count_nn_apply = + let rec aux acc (term : Term.term) = + let acc = Term.t_fold aux acc term in + match term.t_node with + | Term.Tapp (ls, _) -> ( + match Language.lookup_loaded_nets ls with + | None -> acc + | Some _ -> acc + 1) + | _ -> acc + in + Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) 0 + let get_input_variables = let rec aux acc (term : Term.term) = match term.t_node with diff --git a/src/transformations/utils.mli b/src/transformations/utils.mli index ecfa4a031606614555a34788fe57e95cc7083033..12b8c1348d0a3763d44c6ef16280c03fc8408dc9 100644 --- a/src/transformations/utils.mli +++ b/src/transformations/utils.mli @@ -2,12 +2,31 @@ (* *) (* This file is part of CAISAR. *) (* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) (**************************************************************************) open Why3 -open Base + +val count_nn_apply : int Trans.trans +(** Count the number of applications of [nn_apply]. *) val get_input_variables : int Term.Mls.t Trans.trans +(** Retrieve the input variables appearing as arguments of [nn_apply]. *) val meta_input : Theory.meta (** Indicate the input position. *) diff --git a/src/verification.ml b/src/verification.ml index fa183c52c83c0b68ba0352589c91cb6dd234692c..bad757960feebdcc54c1eabaf351e5bee0f1dfef 100644 --- a/src/verification.ml +++ b/src/verification.ml @@ -254,7 +254,8 @@ let answer_on_dataset limit config env config_prover driver dataset_csv task = let prover_answer = combine_prover_answers answers in (prover_answer, None) -let answer_generic limit config prover config_prover driver task = +let answer_generic limit env config prover config_prover driver task = + let task = Proof_strategy.apply_native_nn_prover env task in let task = Driver.prepare_task driver task in let nn_file = match Task.on_meta_excl Utils.meta_nn_filename task with @@ -286,7 +287,7 @@ let call_prover ~limit config env prover config_prover driver dataset_csv task = let dataset_csv = Option.value_exn dataset_csv in answer_on_dataset limit config env config_prover driver dataset_csv task | Marabou | Pyrat | CVC5 -> - answer_generic limit config prover config_prover driver task + answer_generic limit env config prover config_prover driver task in Logs.app (fun m -> m "@[Goal %a:@ %a%a@]" Pretty.print_pr (Task.task_goal task)