From b3c28279d53e4984416f716a1bc6f2003a40b7a1 Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Wed, 30 Nov 2022 17:36:16 +0100
Subject: [PATCH] [verification][printers] Allow Marabou to be launched on a
 dataset verification.

Modify the printer and transformations for Marabou to accept working on reals.
---
 config/drivers/marabou.drv         |  8 +--
 src/printers/marabou.ml            | 88 ++++++++++++++++++++----------
 src/transformations/vars_on_lhs.ml | 54 +++++++++++++-----
 src/verification.ml                | 31 ++++++++---
 4 files changed, 127 insertions(+), 54 deletions(-)

diff --git a/config/drivers/marabou.drv b/config/drivers/marabou.drv
index 25eb4ea3..098c2aa6 100644
--- a/config/drivers/marabou.drv
+++ b/config/drivers/marabou.drv
@@ -123,10 +123,10 @@ theory real.Real
   meta "invalid trigger" predicate (>=)
   meta "invalid trigger" predicate (>)
 
-  syntax predicate (<=) "(%1 <= %2)"
-  syntax predicate (<)  "(%1 <  %2)"
-  syntax predicate (>=) "(%1 >= %2)"
-  syntax predicate (>)  "(%1 >  %2)"
+  syntax predicate (<=) "%1 <= %2"
+  syntax predicate (<)  "%1 < %2"
+  syntax predicate (>=) "%1 >= %2"
+  syntax predicate (>)  "%1 > %2"
 
   remove prop MulComm.Comm
   remove prop MulAssoc.Assoc
diff --git a/src/printers/marabou.ml b/src/printers/marabou.ml
index 4a58c113..67fc24f9 100644
--- a/src/printers/marabou.ml
+++ b/src/printers/marabou.ml
@@ -22,11 +22,16 @@
 
 open Why3
 
-type info = {
+type relops = {
   le : Term.lsymbol;
   ge : Term.lsymbol;
   lt : Term.lsymbol;
   gt : Term.lsymbol;
+}
+
+type info = {
+  ls_rel_real : relops;
+  ls_rel_float : relops;
   info_syn : Printer.syntax_map;
   variables : string Term.Hls.t;
 }
@@ -70,13 +75,17 @@ let rec print_term info fmt t =
         Term.Hls.find_opt info.variables ls2 )
     with
     | Some s1, Some s2 ->
-      if Term.ls_equal ls info.le
+      if Term.ls_equal ls info.ls_rel_float.le
+         || Term.ls_equal ls info.ls_rel_real.le
       then Fmt.pf fmt "+%s -%s <= 0" s1 s2
-      else if Term.ls_equal ls info.ge
+      else if Term.ls_equal ls info.ls_rel_float.ge
+              || Term.ls_equal ls info.ls_rel_real.ge
       then Fmt.pf fmt "+%s -%s >= 0" s1 s2
-      else if Term.ls_equal ls info.lt
+      else if Term.ls_equal ls info.ls_rel_float.lt
+              || Term.ls_equal ls info.ls_rel_real.lt
       then Fmt.pf fmt "+%s -%s < 0" s1 s2
-      else if Term.ls_equal ls info.gt
+      else if Term.ls_equal ls info.ls_rel_float.gt
+              || Term.ls_equal ls info.ls_rel_real.gt
       then Fmt.pf fmt "+%s -%s > 0" s1 s2
       else Printer.unsupportedTerm t "Unknown relational operator"
     | _ -> Printer.unsupportedTerm t "Unknown variable(s)")
@@ -116,24 +125,38 @@ let rec negate_term info t =
     Term.t_and (negate_term info t1) (negate_term info t2)
   | Tapp (ls, [ t1; t2 ]) ->
     let tt = [ t1; t2 ] in
-    if Term.ls_equal ls info.le
-    then Term.ps_app info.gt tt
-    else if Term.ls_equal ls info.ge
-    then Term.ps_app info.lt tt
-    else if Term.ls_equal ls info.lt
-    then Term.ps_app info.ge tt
-    else if Term.ls_equal ls info.gt
-    then Term.ps_app info.le tt
-    else Printer.unsupportedTerm t "Cannot negate such term"
+    (* Negate float relational symbols. *)
+    let ls_neg =
+      if Term.ls_equal ls info.ls_rel_float.le
+      then info.ls_rel_float.gt
+      else if Term.ls_equal ls info.ls_rel_float.ge
+      then info.ls_rel_float.lt
+      else if Term.ls_equal ls info.ls_rel_float.lt
+      then info.ls_rel_float.ge
+      else if Term.ls_equal ls info.ls_rel_float.gt
+      then info.ls_rel_float.le
+      else ls
+    in
+    (* Negate real relational symbols. *)
+    let ls_neg =
+      if Term.ls_equal ls info.ls_rel_real.le
+      then info.ls_rel_real.gt
+      else if Term.ls_equal ls info.ls_rel_real.ge
+      then info.ls_rel_real.lt
+      else if Term.ls_equal ls info.ls_rel_real.lt
+      then info.ls_rel_real.ge
+      else if Term.ls_equal ls info.ls_rel_real.gt
+      then info.ls_rel_real.le
+      else ls_neg
+    in
+    if Term.ls_equal ls_neg ls
+    then Printer.unsupportedTerm t "Cannot negate such term"
+    else Term.ps_app ls_neg tt
   | _ -> Printer.unsupportedTerm t "Cannot negate such term"
 
 let print_decl info fmt d =
   match d.Decl.d_node with
-  | Dtype _ -> ()
-  | Ddata _ -> ()
-  | Dparam _ -> ()
-  | Dlogic _ -> ()
-  | Dind _ -> ()
+  | Dtype _ | Ddata _ | Dparam _ | Dlogic _ | Dind _ -> ()
   | Dprop (Decl.Plemma, _, _) -> assert false
   | Dprop (Decl.Paxiom, _, f) -> print_top_level_term info fmt f
   | Dprop (Decl.Pgoal, _, f) ->
@@ -158,17 +181,26 @@ let rec print_tdecl info fmt task =
     | Decl d -> print_decl info fmt d)
 
 let print_task args ?old:_ fmt task =
-  let th = Env.read_theory args.Printer.env [ "ieee_float" ] "Float64" in
-  let le = Theory.ns_find_ls th.th_export [ "le" ] in
-  let lt = Theory.ns_find_ls th.th_export [ "lt" ] in
-  let ge = Theory.ns_find_ls th.th_export [ "ge" ] in
-  let gt = Theory.ns_find_ls th.th_export [ "gt" ] in
+  let ls_rel_real =
+    let th = Env.read_theory args.Printer.env [ "real" ] "Real" in
+    let le = Theory.ns_find_ls th.th_export [ Ident.op_infix "<=" ] in
+    let lt = Theory.ns_find_ls th.th_export [ Ident.op_infix "<" ] in
+    let ge = Theory.ns_find_ls th.th_export [ Ident.op_infix ">=" ] in
+    let gt = Theory.ns_find_ls th.th_export [ Ident.op_infix ">" ] in
+    { le; ge; lt; gt }
+  in
+  let ls_rel_float =
+    let th = Env.read_theory args.Printer.env [ "ieee_float" ] "Float64" in
+    let le = Theory.ns_find_ls th.th_export [ "le" ] in
+    let lt = Theory.ns_find_ls th.th_export [ "lt" ] in
+    let ge = Theory.ns_find_ls th.th_export [ "ge" ] in
+    let gt = Theory.ns_find_ls th.th_export [ "gt" ] in
+    { le; ge; lt; gt }
+  in
   let info =
     {
-      le;
-      lt;
-      ge;
-      gt;
+      ls_rel_real;
+      ls_rel_float;
       info_syn = Discriminate.get_syntax_map task;
       variables = Term.Hls.create 10;
     }
diff --git a/src/transformations/vars_on_lhs.ml b/src/transformations/vars_on_lhs.ml
index 6bed9204..909a4087 100644
--- a/src/transformations/vars_on_lhs.ml
+++ b/src/transformations/vars_on_lhs.ml
@@ -26,10 +26,15 @@ open Base
 let make_rt env =
   let th = Env.read_theory env [ "ieee_float" ] "Float64" in
   let t = Theory.ns_find_ts th.th_export [ "t" ] in
-  let le = Theory.ns_find_ls th.th_export [ "le" ] in
-  let lt = Theory.ns_find_ls th.th_export [ "lt" ] in
-  let ge = Theory.ns_find_ls th.th_export [ "ge" ] in
-  let gt = Theory.ns_find_ls th.th_export [ "gt" ] in
+  let le_float = Theory.ns_find_ls th.th_export [ "le" ] in
+  let lt_float = Theory.ns_find_ls th.th_export [ "lt" ] in
+  let ge_float = Theory.ns_find_ls th.th_export [ "ge" ] in
+  let gt_float = Theory.ns_find_ls th.th_export [ "gt" ] in
+  let th = Env.read_theory env [ "real" ] "Real" in
+  let le_real = Theory.ns_find_ls th.th_export [ Ident.op_infix "<=" ] in
+  let lt_real = Theory.ns_find_ls th.th_export [ Ident.op_infix "<" ] in
+  let ge_real = Theory.ns_find_ls th.th_export [ Ident.op_infix ">=" ] in
+  let gt_real = Theory.ns_find_ls th.th_export [ Ident.op_infix ">" ] in
   let rec rt t =
     let t = Term.t_map rt t in
     match t.t_node with
@@ -40,19 +45,40 @@ let make_rt env =
             ({ t_node = Tapp (_, []); _ } as var);
           ] ) ->
       let tt = [ var; const ] in
-      if Term.ls_equal ls le
-      then Term.ps_app ge tt
-      else if Term.ls_equal ls ge
-      then Term.ps_app le tt
-      else if Term.ls_equal ls lt
-      then Term.ps_app gt tt
-      else if Term.ls_equal ls gt
-      then Term.ps_app lt tt
-      else t
+      let ls_rel =
+        if Term.ls_equal ls le_float
+        then ge_float
+        else if Term.ls_equal ls ge_float
+        then le_float
+        else if Term.ls_equal ls lt_float
+        then gt_float
+        else if Term.ls_equal ls gt_float
+        then lt_float
+        else ls
+      in
+      let ls_rel =
+        if Term.ls_equal ls le_real
+        then ge_real
+        else if Term.ls_equal ls ge_real
+        then le_real
+        else if Term.ls_equal ls lt_real
+        then gt_real
+        else if Term.ls_equal ls gt_real
+        then lt_real
+        else ls_rel
+      in
+      if Term.ls_equal ls_rel ls then t else Term.ps_app ls_rel tt
     | _ -> t
   in
   let task =
-    List.fold [ le; lt; ge; gt ] ~init:(Task.add_ty_decl None t)
+    List.fold
+      [ le_float; lt_float; ge_float; gt_float ]
+      ~init:(Task.add_ty_decl None t) ~f:Task.add_param_decl
+  in
+  let task =
+    List.fold
+      [ le_real; lt_real; ge_real; gt_real ]
+      ~init:(Task.add_ty_decl task Ty.ts_real)
       ~f:Task.add_param_decl
   in
   (rt, task)
diff --git a/src/verification.ml b/src/verification.ml
index bad75796..96dba8fd 100644
--- a/src/verification.ml
+++ b/src/verification.ml
@@ -223,7 +223,8 @@ let combine_prover_answers answers =
     | Call_provers.Valid, r | r, Call_provers.Valid -> r
     | _ -> acc)
 
-let answer_on_dataset limit config env config_prover driver dataset_csv task =
+let answer_on_dataset limit config env prover config_prover driver dataset_csv
+  task =
   let dataset_predicate =
     let on_model ls =
       let message =
@@ -239,22 +240,35 @@ let answer_on_dataset limit config env config_prover driver dataset_csv task =
     in
     Dataset.interpret_predicate env ~on_model ~on_dataset task
   in
-  let tasks =
+  let dataset_tasks =
+    (* One task for dataset element. *)
     Dataset.tasks_of_nn_csv_predicate env dataset_predicate
     |> List.map ~f:(Driver.prepare_task driver)
   in
+  let tasks =
+    (* We turn each task in [dataset_tasks] into a list (ie, conjunction) of
+       disjunctions of tasks. *)
+    match prover with
+    | Prover.Marabou -> List.map ~f:(Trans.apply Split.split_all) dataset_tasks
+    | _ -> [ dataset_tasks ]
+  in
   let command = Whyconf.get_complete_command ~with_steps:false config_prover in
   let command =
     let nn_file = Unix.realpath dataset_predicate.model.filename in
     Re__Core.replace_string nnet_or_onnx ~by:nn_file command
   in
   let answers =
-    List.map tasks ~f:(call_prover_on_task limit config command driver)
+    List.map tasks ~f:(fun dataset_elt_tasks ->
+      let dataset_elt_answers =
+        List.map dataset_elt_tasks
+          ~f:(call_prover_on_task limit config command driver)
+      in
+      combine_prover_answers dataset_elt_answers)
   in
   let prover_answer = combine_prover_answers answers in
   (prover_answer, None)
 
-let answer_generic limit env config prover config_prover driver task =
+let answer_generic limit config env prover config_prover driver task =
   let task = Proof_strategy.apply_native_nn_prover env task in
   let task = Driver.prepare_task driver task in
   let nn_file =
@@ -264,7 +278,7 @@ let answer_generic limit env config prover config_prover driver task =
     | None -> invalid_arg "No neural network model found in task"
   in
   let tasks =
-    (* We make [tasks] as a list (ie, conjunction) of disjunctions. *)
+    (* We turn [task] into a list (ie, conjunction) of disjunctions of tasks. *)
     match prover with
     | Prover.Marabou -> Trans.apply Split.split_all task
     | Pyrat -> Trans.apply Split.split_premises task
@@ -283,11 +297,12 @@ let call_prover ~limit config env prover config_prover driver dataset_csv task =
     match prover with
     | Prover.Saver ->
       answer_saver limit config env config_prover dataset_csv task
-    | Pyrat when Option.is_some dataset_csv ->
+    | (Marabou | Pyrat) when Option.is_some dataset_csv ->
       let dataset_csv = Option.value_exn dataset_csv in
-      answer_on_dataset limit config env config_prover driver dataset_csv task
+      answer_on_dataset limit config env prover config_prover driver dataset_csv
+        task
     | Marabou | Pyrat | CVC5 ->
-      answer_generic limit env config prover config_prover driver task
+      answer_generic limit config env prover config_prover driver task
   in
   Logs.app (fun m ->
     m "@[Goal %a:@ %a%a@]" Pretty.print_pr (Task.task_goal task)
-- 
GitLab