From ac7d9e18ce4f86a37e2996a7eb744848f29bfc12 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr>
Date: Wed, 30 Nov 2022 11:23:10 +0100
Subject: [PATCH] Move the native_nn_prover transformation outside drivers.

Make it into a new strategy module that should prepare the task before sending
it to the provers.
---
 config/drivers/marabou.drv               |  1 -
 config/drivers/pyrat.drv                 |  1 -
 src/main.ml                              |  1 -
 src/proof_strategy.ml                    | 35 ++++++++++++++++++++++++
 src/proof_strategy.mli                   | 26 ++++++++++++++++++
 src/transformations/native_nn_prover.ml  |  7 +----
 src/transformations/native_nn_prover.mli |  5 ++--
 src/transformations/utils.ml             | 30 ++++++++++++++++++--
 src/transformations/utils.mli            | 21 +++++++++++++-
 src/verification.ml                      |  5 ++--
 10 files changed, 116 insertions(+), 16 deletions(-)
 create mode 100644 src/proof_strategy.ml
 create mode 100644 src/proof_strategy.mli

diff --git a/config/drivers/marabou.drv b/config/drivers/marabou.drv
index dc29a063..25eb4ea3 100644
--- a/config/drivers/marabou.drv
+++ b/config/drivers/marabou.drv
@@ -34,7 +34,6 @@ transformation "inline_trivial"
 transformation "introduce_premises"
 transformation "eliminate_builtin"
 transformation "simplify_formula"
-transformation "native_nn_prover"
 transformation "vars_on_lhs"
 
 theory BuiltIn
diff --git a/config/drivers/pyrat.drv b/config/drivers/pyrat.drv
index d9cc1c9e..26b23930 100644
--- a/config/drivers/pyrat.drv
+++ b/config/drivers/pyrat.drv
@@ -34,7 +34,6 @@ transformation "inline_trivial"
 transformation "introduce_premises"
 transformation "eliminate_builtin"
 transformation "simplify_formula"
-transformation "native_nn_prover"
 
 theory BuiltIn
   syntax type int   "int"
diff --git a/src/main.ml b/src/main.ml
index c59b9740..e5f656dd 100644
--- a/src/main.ml
+++ b/src/main.ml
@@ -26,7 +26,6 @@ open Cmdliner
 let caisar = "caisar"
 
 let () =
-  Native_nn_prover.init ();
   Nn2smt.init ();
   Vars_on_lhs.init ()
 
diff --git a/src/proof_strategy.ml b/src/proof_strategy.ml
new file mode 100644
index 00000000..fc1556a2
--- /dev/null
+++ b/src/proof_strategy.ml
@@ -0,0 +1,35 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2022                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+
+open Why3
+
+let apply_native_nn_prover env task =
+  let nb = Trans.apply Utils.count_nn_apply task in
+  match nb with
+  | 0 -> task
+  | 1 ->
+    Trans.(
+      apply
+        (seq [ Introduction.introduce_premises; Native_nn_prover.trans env ]))
+      task
+  | _ ->
+    invalid_arg "Two or more neural network applications are not supported yet"
diff --git a/src/proof_strategy.mli b/src/proof_strategy.mli
new file mode 100644
index 00000000..d8a52fde
--- /dev/null
+++ b/src/proof_strategy.mli
@@ -0,0 +1,26 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2022                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+
+open Why3
+
+val apply_native_nn_prover : Env.env -> Task.task -> Task.task
+(** Detect and execute applications of neural networks. *)
diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index a7353ebc..3edcb00a 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -92,10 +92,5 @@ let simplify_goal env input_variables =
         Task.add_decl acc decl)
     None
 
-let native_nn_prover env =
+let trans env =
   Trans.seq [ Trans.bind Utils.get_input_variables (simplify_goal env) ]
-
-let init () =
-  Trans.register_env_transform
-    ~desc:"Transformation for provers that support loading neural networks."
-    "native_nn_prover" native_nn_prover
diff --git a/src/transformations/native_nn_prover.mli b/src/transformations/native_nn_prover.mli
index 694097b3..936ff492 100644
--- a/src/transformations/native_nn_prover.mli
+++ b/src/transformations/native_nn_prover.mli
@@ -20,5 +20,6 @@
 (*                                                                        *)
 (**************************************************************************)
 
-val init : unit -> unit
-(** Register the transformation. *)
+open Why3
+
+val trans : Env.env -> Task.task Trans.trans
diff --git a/src/transformations/utils.ml b/src/transformations/utils.ml
index efe26076..1cfbdda2 100644
--- a/src/transformations/utils.ml
+++ b/src/transformations/utils.ml
@@ -2,13 +2,39 @@
 (*                                                                        *)
 (*  This file is part of CAISAR.                                          *)
 (*                                                                        *)
+(*  Copyright (C) 2022                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
 (**************************************************************************)
 
 open Why3
 open Base
 
-(* Retrieve the (input) variables appearing as arguments of a logic symbol
-   nn_apply. *)
+let count_nn_apply =
+  let rec aux acc (term : Term.term) =
+    let acc = Term.t_fold aux acc term in
+    match term.t_node with
+    | Term.Tapp (ls, _) -> (
+      match Language.lookup_loaded_nets ls with
+      | None -> acc
+      | Some _ -> acc + 1)
+    | _ -> acc
+  in
+  Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) 0
+
 let get_input_variables =
   let rec aux acc (term : Term.term) =
     match term.t_node with
diff --git a/src/transformations/utils.mli b/src/transformations/utils.mli
index ecfa4a03..12b8c134 100644
--- a/src/transformations/utils.mli
+++ b/src/transformations/utils.mli
@@ -2,12 +2,31 @@
 (*                                                                        *)
 (*  This file is part of CAISAR.                                          *)
 (*                                                                        *)
+(*  Copyright (C) 2022                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
 (**************************************************************************)
 
 open Why3
-open Base
+
+val count_nn_apply : int Trans.trans
+(** Count the number of applications of [nn_apply]. *)
 
 val get_input_variables : int Term.Mls.t Trans.trans
+(** Retrieve the input variables appearing as arguments of [nn_apply]. *)
 
 val meta_input : Theory.meta
 (** Indicate the input position. *)
diff --git a/src/verification.ml b/src/verification.ml
index fa183c52..bad75796 100644
--- a/src/verification.ml
+++ b/src/verification.ml
@@ -254,7 +254,8 @@ let answer_on_dataset limit config env config_prover driver dataset_csv task =
   let prover_answer = combine_prover_answers answers in
   (prover_answer, None)
 
-let answer_generic limit config prover config_prover driver task =
+let answer_generic limit env config prover config_prover driver task =
+  let task = Proof_strategy.apply_native_nn_prover env task in
   let task = Driver.prepare_task driver task in
   let nn_file =
     match Task.on_meta_excl Utils.meta_nn_filename task with
@@ -286,7 +287,7 @@ let call_prover ~limit config env prover config_prover driver dataset_csv task =
       let dataset_csv = Option.value_exn dataset_csv in
       answer_on_dataset limit config env config_prover driver dataset_csv task
     | Marabou | Pyrat | CVC5 ->
-      answer_generic limit config prover config_prover driver task
+      answer_generic limit env config prover config_prover driver task
   in
   Logs.app (fun m ->
     m "@[Goal %a:@ %a%a@]" Pretty.print_pr (Task.task_goal task)
-- 
GitLab