Skip to content
Snippets Groups Projects
Commit ac7d9e18 authored by François Bobot's avatar François Bobot Committed by Michele Alberti
Browse files

Move the native_nn_prover transformation outside drivers.

Make it into a new strategy module that should prepare the task before sending
it to the provers.
parent 6ea82386
No related branches found
No related tags found
No related merge requests found
...@@ -34,7 +34,6 @@ transformation "inline_trivial" ...@@ -34,7 +34,6 @@ transformation "inline_trivial"
transformation "introduce_premises" transformation "introduce_premises"
transformation "eliminate_builtin" transformation "eliminate_builtin"
transformation "simplify_formula" transformation "simplify_formula"
transformation "native_nn_prover"
transformation "vars_on_lhs" transformation "vars_on_lhs"
theory BuiltIn theory BuiltIn
......
...@@ -34,7 +34,6 @@ transformation "inline_trivial" ...@@ -34,7 +34,6 @@ transformation "inline_trivial"
transformation "introduce_premises" transformation "introduce_premises"
transformation "eliminate_builtin" transformation "eliminate_builtin"
transformation "simplify_formula" transformation "simplify_formula"
transformation "native_nn_prover"
theory BuiltIn theory BuiltIn
syntax type int "int" syntax type int "int"
......
...@@ -26,7 +26,6 @@ open Cmdliner ...@@ -26,7 +26,6 @@ open Cmdliner
let caisar = "caisar" let caisar = "caisar"
let () = let () =
Native_nn_prover.init ();
Nn2smt.init (); Nn2smt.init ();
Vars_on_lhs.init () Vars_on_lhs.init ()
......
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(* Copyright (C) 2022 *)
(* CEA (Commissariat à l'énergie atomique et aux énergies *)
(* alternatives) *)
(* *)
(* You can redistribute it and/or modify it under the terms of the GNU *)
(* Lesser General Public License as published by the Free Software *)
(* Foundation, version 2.1. *)
(* *)
(* It is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
(* GNU Lesser General Public License for more details. *)
(* *)
(* See the GNU Lesser General Public License version 2.1 *)
(* for more details (enclosed in the file licenses/LGPLv2.1). *)
(* *)
(**************************************************************************)
open Why3
let apply_native_nn_prover env task =
let nb = Trans.apply Utils.count_nn_apply task in
match nb with
| 0 -> task
| 1 ->
Trans.(
apply
(seq [ Introduction.introduce_premises; Native_nn_prover.trans env ]))
task
| _ ->
invalid_arg "Two or more neural network applications are not supported yet"
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(* Copyright (C) 2022 *)
(* CEA (Commissariat à l'énergie atomique et aux énergies *)
(* alternatives) *)
(* *)
(* You can redistribute it and/or modify it under the terms of the GNU *)
(* Lesser General Public License as published by the Free Software *)
(* Foundation, version 2.1. *)
(* *)
(* It is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
(* GNU Lesser General Public License for more details. *)
(* *)
(* See the GNU Lesser General Public License version 2.1 *)
(* for more details (enclosed in the file licenses/LGPLv2.1). *)
(* *)
(**************************************************************************)
open Why3
val apply_native_nn_prover : Env.env -> Task.task -> Task.task
(** Detect and execute applications of neural networks. *)
...@@ -92,10 +92,5 @@ let simplify_goal env input_variables = ...@@ -92,10 +92,5 @@ let simplify_goal env input_variables =
Task.add_decl acc decl) Task.add_decl acc decl)
None None
let native_nn_prover env = let trans env =
Trans.seq [ Trans.bind Utils.get_input_variables (simplify_goal env) ] Trans.seq [ Trans.bind Utils.get_input_variables (simplify_goal env) ]
let init () =
Trans.register_env_transform
~desc:"Transformation for provers that support loading neural networks."
"native_nn_prover" native_nn_prover
...@@ -20,5 +20,6 @@ ...@@ -20,5 +20,6 @@
(* *) (* *)
(**************************************************************************) (**************************************************************************)
val init : unit -> unit open Why3
(** Register the transformation. *)
val trans : Env.env -> Task.task Trans.trans
...@@ -2,13 +2,39 @@ ...@@ -2,13 +2,39 @@
(* *) (* *)
(* This file is part of CAISAR. *) (* This file is part of CAISAR. *)
(* *) (* *)
(* Copyright (C) 2022 *)
(* CEA (Commissariat à l'énergie atomique et aux énergies *)
(* alternatives) *)
(* *)
(* You can redistribute it and/or modify it under the terms of the GNU *)
(* Lesser General Public License as published by the Free Software *)
(* Foundation, version 2.1. *)
(* *)
(* It is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
(* GNU Lesser General Public License for more details. *)
(* *)
(* See the GNU Lesser General Public License version 2.1 *)
(* for more details (enclosed in the file licenses/LGPLv2.1). *)
(* *)
(**************************************************************************) (**************************************************************************)
open Why3 open Why3
open Base open Base
(* Retrieve the (input) variables appearing as arguments of a logic symbol let count_nn_apply =
nn_apply. *) let rec aux acc (term : Term.term) =
let acc = Term.t_fold aux acc term in
match term.t_node with
| Term.Tapp (ls, _) -> (
match Language.lookup_loaded_nets ls with
| None -> acc
| Some _ -> acc + 1)
| _ -> acc
in
Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) 0
let get_input_variables = let get_input_variables =
let rec aux acc (term : Term.term) = let rec aux acc (term : Term.term) =
match term.t_node with match term.t_node with
......
...@@ -2,12 +2,31 @@ ...@@ -2,12 +2,31 @@
(* *) (* *)
(* This file is part of CAISAR. *) (* This file is part of CAISAR. *)
(* *) (* *)
(* Copyright (C) 2022 *)
(* CEA (Commissariat à l'énergie atomique et aux énergies *)
(* alternatives) *)
(* *)
(* You can redistribute it and/or modify it under the terms of the GNU *)
(* Lesser General Public License as published by the Free Software *)
(* Foundation, version 2.1. *)
(* *)
(* It is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
(* GNU Lesser General Public License for more details. *)
(* *)
(* See the GNU Lesser General Public License version 2.1 *)
(* for more details (enclosed in the file licenses/LGPLv2.1). *)
(* *)
(**************************************************************************) (**************************************************************************)
open Why3 open Why3
open Base
val count_nn_apply : int Trans.trans
(** Count the number of applications of [nn_apply]. *)
val get_input_variables : int Term.Mls.t Trans.trans val get_input_variables : int Term.Mls.t Trans.trans
(** Retrieve the input variables appearing as arguments of [nn_apply]. *)
val meta_input : Theory.meta val meta_input : Theory.meta
(** Indicate the input position. *) (** Indicate the input position. *)
......
...@@ -254,7 +254,8 @@ let answer_on_dataset limit config env config_prover driver dataset_csv task = ...@@ -254,7 +254,8 @@ let answer_on_dataset limit config env config_prover driver dataset_csv task =
let prover_answer = combine_prover_answers answers in let prover_answer = combine_prover_answers answers in
(prover_answer, None) (prover_answer, None)
let answer_generic limit config prover config_prover driver task = let answer_generic limit env config prover config_prover driver task =
let task = Proof_strategy.apply_native_nn_prover env task in
let task = Driver.prepare_task driver task in let task = Driver.prepare_task driver task in
let nn_file = let nn_file =
match Task.on_meta_excl Utils.meta_nn_filename task with match Task.on_meta_excl Utils.meta_nn_filename task with
...@@ -286,7 +287,7 @@ let call_prover ~limit config env prover config_prover driver dataset_csv task = ...@@ -286,7 +287,7 @@ let call_prover ~limit config env prover config_prover driver dataset_csv task =
let dataset_csv = Option.value_exn dataset_csv in let dataset_csv = Option.value_exn dataset_csv in
answer_on_dataset limit config env config_prover driver dataset_csv task answer_on_dataset limit config env config_prover driver dataset_csv task
| Marabou | Pyrat | CVC5 -> | Marabou | Pyrat | CVC5 ->
answer_generic limit config prover config_prover driver task answer_generic limit env config prover config_prover driver task
in in
Logs.app (fun m -> Logs.app (fun m ->
m "@[Goal %a:@ %a%a@]" Pretty.print_pr (Task.task_goal task) m "@[Goal %a:@ %a%a@]" Pretty.print_pr (Task.task_goal task)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment