From 77c43c8f7ade2c5bb47b373e10e411f610ff026d Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Thu, 11 May 2023 16:12:36 +0200
Subject: [PATCH] [proof_strategy] Strategy for native nn provers splits
 top-level conjunctions in goal formula.

---
 src/proof_strategy.ml  | 54 +++++++++++++++++++-----------------------
 src/proof_strategy.mli |  4 ++--
 src/verification.ml    | 12 +++++-----
 3 files changed, 33 insertions(+), 37 deletions(-)

diff --git a/src/proof_strategy.ml b/src/proof_strategy.ml
index 47047e2..ee69042 100644
--- a/src/proof_strategy.ml
+++ b/src/proof_strategy.ml
@@ -20,9 +20,10 @@
 (*                                                                        *)
 (**************************************************************************)
 
+open Base
 open Why3
 
-let do_count_nn_ls ~lookup =
+let set_of_nn_ls ~lookup sls =
   let rec aux acc term =
     let acc = Term.t_fold aux acc term in
     match term.t_node with
@@ -30,38 +31,33 @@ let do_count_nn_ls ~lookup =
       match lookup ls with None -> acc | Some _ -> Term.Sls.add ls acc)
     | _ -> acc
   in
-  Trans.bind
-    (Trans.fold_decl
-       (fun decl acc -> Decl.decl_fold aux acc decl)
-       Term.Sls.empty)
-    (fun s -> Trans.return (Term.Sls.cardinal s))
+  Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) sls
 
-let apply_classic_prover env task =
-  let nb_nn_apply =
-    let count_nn_apply = do_count_nn_ls ~lookup:Language.lookup_loaded_nets in
-    Trans.apply count_nn_apply task
+let do_apply_prover ~lookup ~trans tasks =
+  let set_nn_ls =
+    List.fold tasks ~init:Term.Sls.empty ~f:(fun accum task ->
+      Trans.apply (set_of_nn_ls ~lookup accum) task)
   in
-  match nb_nn_apply with
-  | 0 -> task
-  | 1 -> Trans.apply (Nn2smt.trans env) task
+  let count_nn_ls = Term.Sls.cardinal set_nn_ls in
+  match count_nn_ls with
+  | 0 -> tasks
+  | 1 -> List.map tasks ~f:(Trans.apply trans)
   | _ ->
     invalid_arg "Two or more neural network applications are not supported yet"
 
+let apply_classic_prover env task =
+  let lookup = Language.lookup_loaded_nets in
+  let trans = Nn2smt.trans env in
+  do_apply_prover ~lookup ~trans [ task ]
+
 let apply_native_nn_prover env task =
-  let nb_nn_applications =
-    let count_nn_applications = do_count_nn_ls ~lookup:Language.lookup_nn in
-    Trans.apply count_nn_applications task
+  let lookup = Language.lookup_nn in
+  let trans =
+    Trans.seq
+      [
+        Introduction.introduce_premises;
+        Native_nn_prover.trans_nn_application env;
+      ]
   in
-  match nb_nn_applications with
-  | 0 -> task
-  | 1 ->
-    Trans.(
-      apply
-        (seq
-           [
-             Introduction.introduce_premises;
-             Native_nn_prover.trans_nn_application env;
-           ]))
-      task
-  | _ ->
-    invalid_arg "Two or more neural network applications are not supported yet"
+  let tasks = Trans.apply Split_goal.split_goal_full task in
+  do_apply_prover ~lookup ~trans tasks
diff --git a/src/proof_strategy.mli b/src/proof_strategy.mli
index fb35660..1ec1f82 100644
--- a/src/proof_strategy.mli
+++ b/src/proof_strategy.mli
@@ -22,8 +22,8 @@
 
 open Why3
 
-val apply_classic_prover : Env.env -> Task.task -> Task.task
+val apply_classic_prover : Env.env -> Task.task -> Task.task list
 (** Detect and translate applications of neural networks into SMT-LIB. *)
 
-val apply_native_nn_prover : Env.env -> Task.task -> Task.task
+val apply_native_nn_prover : Env.env -> Task.task -> Task.task list
 (** Detect and execute applications of neural networks. *)
diff --git a/src/verification.ml b/src/verification.ml
index 9a52272..48f0a47 100644
--- a/src/verification.ml
+++ b/src/verification.ml
@@ -223,8 +223,8 @@ let answer_dataset limit config env prover config_prover driver dataset task =
   in
   (prover_answer, additional_info)
 
-let answer_generic limit config prover config_prover driver task =
-  let tasks = Trans.apply Split_goal.split_goal_full task in
+let answer_generic limit config env prover config_prover driver ~strategy task =
+  let tasks = strategy env task in
   let answers =
     List.concat_map tasks ~f:(fun task ->
       let task = Driver.prepare_task driver task in
@@ -263,11 +263,11 @@ let call_prover ~cwd ~limit config env prover config_prover driver ?dataset task
       answer_dataset limit config env prover config_prover driver dataset task
     | Marabou | Pyrat | Nnenum ->
       let task = Interpretation.interpret_task ~cwd env task in
-      let task = Proof_strategy.apply_native_nn_prover env task in
-      answer_generic limit config prover config_prover driver task
+      let strategy = Proof_strategy.apply_native_nn_prover in
+      answer_generic limit config env prover config_prover driver ~strategy task
     | CVC5 ->
-      let task = Proof_strategy.apply_classic_prover env task in
-      answer_generic limit config prover config_prover driver task
+      let strategy = Proof_strategy.apply_classic_prover in
+      answer_generic limit config env prover config_prover driver ~strategy task
   in
   let id = Task.task_goal task in
   { id; prover_answer; additional_info }
-- 
GitLab