From ffaac063dfd5c7104248a936cf0512ffcc0ef0d7 Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Mon, 22 Nov 2021 22:24:33 +0100
Subject: [PATCH] WIP: Printer for marabou.

---
 config/caisar-detection-data.conf       |   2 +-
 config/drivers/marabou.drv              | 187 ++++++++++++++++++++++++
 config/dune                             |   1 +
 src/main.ml                             |   8 +-
 src/printers/marabou.ml                 | 110 ++++++++++++++
 src/transformations/native_nn_prover.ml |  42 +++---
 src/transformations/vars_on_lhs.ml      |  46 ++++++
 src/verification.ml                     |   1 +
 8 files changed, 374 insertions(+), 23 deletions(-)
 create mode 100644 src/printers/marabou.ml
 create mode 100644 src/transformations/vars_on_lhs.ml

diff --git a/config/caisar-detection-data.conf b/config/caisar-detection-data.conf
index e4b4b2a9..9b73c49d 100644
--- a/config/caisar-detection-data.conf
+++ b/config/caisar-detection-data.conf
@@ -16,7 +16,7 @@ exec = "Marabou"
 version_switch = "--version"
 version_regexp = "\\([0-9.+]+\\)"
 version_ok  = "1.0.+"
-command = "%e --timeout %t %f"
+command = "%e %{nnet-onnx} %f"
 driver = "caisar_drivers/marabou.drv"
 use_at_auto_level = 1
 
diff --git a/config/drivers/marabou.drv b/config/drivers/marabou.drv
index 8a2487dd..dddc85b8 100644
--- a/config/drivers/marabou.drv
+++ b/config/drivers/marabou.drv
@@ -1 +1,188 @@
 (* Why3 drivers for Marabou *)
+
+printer "marabou"
+filename "%f-%t-%g.why"
+
+valid "^[Ss]at"
+invalid "^[Uu]nsat"
+timeout "^[Tt]imeout"
+unknown "^[Uu]nknown" ""
+
+transformation "inline_trivial"
+transformation "introduce_premises"
+transformation "eliminate_builtin"
+transformation "simplify_formula"
+transformation "native_nn_prover"
+transformation "vars_on_lhs"
+
+theory BuiltIn
+  syntax type int   "int"
+  syntax type real  "real"
+
+  syntax predicate (=)  "(%1 = %2)"
+
+  meta "eliminate_algebraic" "keep_enums"
+  meta "eliminate_algebraic" "keep_recs"
+
+end
+
+theory int.Int
+
+  prelude "(* this is a prelude for Alt-Ergo integer arithmetic *)"
+
+  syntax function zero "0"
+  syntax function one  "1"
+
+  syntax function (+)  "(%1 + %2)"
+  syntax function (-)  "(%1 - %2)"
+  syntax function (*)  "(%1 * %2)"
+  syntax function (-_) "(-%1)"
+
+  meta "invalid trigger" predicate (<=)
+  meta "invalid trigger" predicate (<)
+  meta "invalid trigger" predicate (>=)
+  meta "invalid trigger" predicate (>)
+
+  syntax predicate (<=) "(%1 <= %2)"
+  syntax predicate (<)  "(%1 <  %2)"
+  syntax predicate (>=) "(%1 >= %2)"
+  syntax predicate (>)  "(%1 >  %2)"
+
+  remove prop MulComm.Comm
+  remove prop MulAssoc.Assoc
+  remove prop Unit_def_l
+  remove prop Unit_def_r
+  remove prop Inv_def_l
+  remove prop Inv_def_r
+  remove prop Assoc
+  remove prop Mul_distr_l
+  remove prop Mul_distr_r
+  remove prop Comm
+  remove prop Unitary
+  remove prop Refl
+  remove prop Trans
+  remove prop Total
+  remove prop Antisymm
+  remove prop NonTrivialRing
+  remove prop CompatOrderAdd
+  remove prop ZeroLessOne
+
+end
+
+theory int.EuclideanDivision
+
+  syntax function div "(%1 / %2)"
+  syntax function mod "(%1 % %2)"
+
+end
+
+theory int.ComputerDivision
+
+  use for_drivers.ComputerOfEuclideanDivision
+
+end
+
+
+theory real.Real
+
+  prelude "(* this is a prelude for Alt-Ergo real arithmetic *)"
+
+  syntax function zero "0.0"
+  syntax function one  "1.0"
+
+  syntax function (+)  "(%1 + %2)"
+  syntax function (-)  "(%1 - %2)"
+  syntax function (*)  "(%1 * %2)"
+  syntax function (/)  "(%1 / %2)"
+  syntax function (-_) "(-%1)"
+  syntax function inv  "(1.0 / %1)"
+
+  meta "invalid trigger" predicate (<=)
+  meta "invalid trigger" predicate (<)
+  meta "invalid trigger" predicate (>=)
+  meta "invalid trigger" predicate (>)
+
+  syntax predicate (<=) "(%1 <= %2)"
+  syntax predicate (<)  "(%1 <  %2)"
+  syntax predicate (>=) "(%1 >= %2)"
+  syntax predicate (>)  "(%1 >  %2)"
+
+  remove prop MulComm.Comm
+  remove prop MulAssoc.Assoc
+  remove prop Unit_def_l
+  remove prop Unit_def_r
+  remove prop Inv_def_l
+  remove prop Inv_def_r
+  remove prop Assoc
+  remove prop Mul_distr_l
+  remove prop Mul_distr_r
+  remove prop Comm
+  remove prop Unitary
+  remove prop Refl
+  remove prop Trans
+  remove prop Total
+  remove prop Antisymm
+  remove prop Inverse
+  remove prop NonTrivialRing
+  remove prop CompatOrderAdd
+  remove prop ZeroLessOne
+
+end
+
+theory ieee_float.Float64
+
+  syntax function (.+)  "(%1 + %2)"
+  syntax function (.-)  "(%1 - %2)"
+  syntax function (.*)  "(%1 * %2)"
+  syntax function (./)  "(%1 / %2)"
+  syntax function (.-_) "(-%1)"
+
+  syntax predicate le "%1 <= %2"
+  syntax predicate lt  "%1 <  %2"
+  syntax predicate ge "%1 >= %2"
+  syntax predicate gt  "%1 >  %2"
+
+
+end
+
+theory real.RealInfix
+
+  syntax function (+.)  "(%1 + %2)"
+  syntax function (-.)  "(%1 - %2)"
+  syntax function ( *.) "(%1 * %2)"
+  syntax function (/.)  "(%1 / %2)"
+  syntax function (-._) "(-%1)"
+
+  meta "invalid trigger" predicate (<=.)
+  meta "invalid trigger" predicate (<.)
+  meta "invalid trigger" predicate (>=.)
+  meta "invalid trigger" predicate (>.)
+
+  syntax predicate (<=.) "(%1 <= %2)"
+  syntax predicate (<.)  "(%1 <  %2)"
+  syntax predicate (>=.) "(%1 >= %2)"
+  syntax predicate (>.)  "(%1 >  %2)"
+
+end
+
+theory Bool
+  syntax type     bool  "bool"
+  syntax function True  "true"
+  syntax function False "false"
+end
+
+theory Tuple0
+  syntax type     tuple0 "unit"
+  syntax function Tuple0 "void"
+end
+
+theory algebra.AC
+  meta AC function op
+  remove prop Comm
+  remove prop Assoc
+end
+
+theory ieee_float.Float64
+  syntax predicate is_not_nan ""
+  remove allprops
+end
diff --git a/config/dune b/config/dune
index bd74fc9e..e73ccd82 100644
--- a/config/dune
+++ b/config/dune
@@ -2,5 +2,6 @@
  (section (site (caisar config)))
  (files caisar-detection-data.conf
         (drivers/pyrat.drv as drivers/pyrat.drv)
+        (drivers/marabou.drv as drivers/marabou.drv)
  )
 (package caisar))
diff --git a/src/main.ml b/src/main.ml
index 0894eb3d..a86c5a94 100644
--- a/src/main.ml
+++ b/src/main.ml
@@ -9,9 +9,13 @@ open Cmdliner
 
 let caisar = "caisar"
 
-let () = Native_nn_prover.init ()
+let () =
+  Native_nn_prover.init ();
+  Vars_on_lhs.init ()
 
-let () = Pyrat.init ()
+let () =
+  Pyrat.init ();
+  Marabou.init ()
 
 (* -- Logs. *)
 
diff --git a/src/printers/marabou.ml b/src/printers/marabou.ml
new file mode 100644
index 00000000..28a6c84d
--- /dev/null
+++ b/src/printers/marabou.ml
@@ -0,0 +1,110 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(**************************************************************************)
+
+type info = {
+  info_syn : Why3.Printer.syntax_map;
+  variables : string Why3.Term.Hls.t;
+}
+
+let number_format =
+  {
+    Why3.Number.long_int_support = `Default;
+    Why3.Number.negative_int_support = `Default;
+    Why3.Number.dec_int_support = `Default;
+    Why3.Number.hex_int_support = `Unsupported;
+    Why3.Number.oct_int_support = `Unsupported;
+    Why3.Number.bin_int_support = `Unsupported;
+    Why3.Number.negative_real_support = `Default;
+    Why3.Number.dec_real_support = `Default;
+    Why3.Number.hex_real_support = `Unsupported;
+    Why3.Number.frac_real_support = `Unsupported (fun _ _ -> assert false);
+  }
+
+let rec print_term info fmt t =
+  let open Why3 in
+  match t.Term.t_node with
+  | Tbinop ((Timplies | Tiff | Tor), _, _)
+  | Tnot _ | Ttrue | Tfalse | Tvar _ | Tlet _ | Tif _ | Tcase _ | Tquant _
+  | Teps _ ->
+    Printer.unsupportedTerm t "Not supported by Marabou"
+  | Tbinop (Tand, _, _) -> assert false (* Should appear only at top-level. *)
+  | Tconst c -> Constant.(print number_format unsupported_escape) fmt c
+  | Tapp (ls, l) -> (
+    match Printer.query_syntax info.info_syn ls.ls_name with
+    | Some s -> Printer.syntax_arguments s (print_term info) fmt l
+    | None -> (
+      match (Term.Hls.find_opt info.variables ls, l) with
+      | Some s, [] -> Fmt.string fmt s
+      | _ -> Printer.unsupportedTerm t "Unknown variable(s)"))
+
+let rec print_top_level_term info fmt t =
+  let open Why3 in
+  (* Don't print things we don't know. *)
+  let t_is_known =
+    Term.t_s_all
+      (fun _ -> true)
+      (fun ls ->
+        Ident.Mid.mem ls.ls_name info.info_syn || Term.Hls.mem info.variables ls)
+  in
+  match t.Term.t_node with
+  | Tquant _ -> ()
+  | Tbinop (Tand, t1, t2) ->
+    if t_is_known t1 && t_is_known t2
+    then
+      Fmt.pf fmt "%a%a"
+        (print_top_level_term info)
+        t1
+        (print_top_level_term info)
+        t2
+  | _ -> if t_is_known t then Fmt.pf fmt "%a@." (print_term info) t
+
+let print_decl info fmt d =
+  let open Why3 in
+  match d.Decl.d_node with
+  | Dtype _ -> ()
+  | Ddata _ -> ()
+  | Dparam _ -> ()
+  | Dlogic _ -> ()
+  | Dind _ -> ()
+  | Dprop (Decl.Plemma, _, _) -> assert false
+  | Dprop (Decl.Paxiom, _, f) -> print_top_level_term info fmt f
+  | Dprop (Decl.Pgoal, _, f) -> print_top_level_term info fmt f
+
+let rec print_tdecl info fmt task =
+  let open Why3 in
+  match task with
+  | None -> ()
+  | Some { Task.task_prev; task_decl; _ } -> (
+    print_tdecl info fmt task_prev;
+    match task_decl.Theory.td_node with
+    | Use _ | Clone _ -> ()
+    | Meta (meta, l) when Theory.meta_equal meta Native_nn_prover.meta_input
+      -> (
+      match l with
+      | [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "x%i" i)
+      | _ -> assert false)
+    | Meta (meta, l) when Theory.meta_equal meta Native_nn_prover.meta_output
+      -> (
+      match l with
+      | [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "y%i" i)
+      | _ -> assert false)
+    | Meta (_, _) -> ()
+    | Decl d -> print_decl info fmt d)
+
+let print_task args ?old:_ fmt task =
+  let open Why3 in
+  let info =
+    {
+      info_syn = Discriminate.get_syntax_map task;
+      variables = Term.Hls.create 10;
+    }
+  in
+  Printer.print_prelude fmt args.Printer.prelude;
+  print_tdecl info fmt task
+
+let init () =
+  Why3.Printer.register_printer ~desc:"Printer for the Marabou prover."
+    "marabou" print_task
diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index fd3a3800..95b89b07 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -78,27 +78,29 @@ let simplify_goal env input_variables =
       | Decl decl ->
         let meta = ref [] in
         let hls = ref [] in
-        let map term =
-          let term = aux meta hls term in
-          if List.is_empty !hls
-          then term
-          else
-            let known =
-              List.fold !hls ~init:task_hd.task_known ~f:(fun acc (d, _, _) ->
-                Decl.known_add_decl acc d)
-            in
-            let engine =
-              Reduction_engine.create
-                {
-                  compute_defs = false;
-                  compute_builtin = true;
-                  compute_def_set = Term.Sls.empty;
-                }
-                env known
-            in
-            Reduction_engine.normalize ~limit:100 engine Term.Mvs.empty term
+        let decl =
+          Decl.decl_map
+            (fun term ->
+              let term = aux meta hls term in
+              if List.is_empty !hls
+              then term
+              else
+                let known =
+                  List.fold !hls ~init:task_hd.task_known
+                    ~f:(fun acc (d, _, _) -> Decl.known_add_decl acc d)
+                in
+                let engine =
+                  Reduction_engine.create
+                    {
+                      compute_defs = false;
+                      compute_builtin = true;
+                      compute_def_set = Term.Sls.empty;
+                    }
+                    env known
+                in
+                Reduction_engine.normalize ~limit:100 engine Term.Mvs.empty term)
+            decl
         in
-        let decl = Decl.decl_map map decl in
         let acc =
           List.fold !hls ~init:acc ~f:(fun acc (d, ls, i) ->
             let task = Task.add_decl acc d in
diff --git a/src/transformations/vars_on_lhs.ml b/src/transformations/vars_on_lhs.ml
new file mode 100644
index 00000000..021ac66a
--- /dev/null
+++ b/src/transformations/vars_on_lhs.ml
@@ -0,0 +1,46 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(**************************************************************************)
+
+let make_rt env =
+  let open Why3 in
+  let th = Env.read_theory env [ "ieee_float" ] "Float64" in
+  let le = Theory.ns_find_ls th.th_export [ "le" ] in
+  let lt = Theory.ns_find_ls th.th_export [ "lt" ] in
+  let ge = Theory.ns_find_ls th.th_export [ "ge" ] in
+  let gt = Theory.ns_find_ls th.th_export [ "gt" ] in
+  let rec rt t =
+    let t = Term.t_map rt t in
+    match t.t_node with
+    | Tapp
+        ( ls,
+          [
+            ({ t_node = Tconst _; _ } as const);
+            ({ t_node = Tapp (_, []); _ } as var);
+          ] ) ->
+      Fmt.pr "HERE: %a!@." Pretty.print_term t;
+      if Term.ls_equal ls le
+      then Term.ps_app ge [ var; const ]
+      else if Term.ls_equal ls ge
+      then Term.ps_app le [ var; const ]
+      else if Term.ls_equal ls lt
+      then Term.ps_app gt [ var; const ]
+      else if Term.ls_equal ls gt
+      then Term.ps_app lt [ var; const ]
+      else t
+    | _ -> t
+  in
+  rt
+
+let vars_on_lhs env =
+  let rt = make_rt env in
+  Why3.Trans.rewrite rt None
+
+let init () =
+  Why3.Trans.register_env_transform
+    ~desc:
+      "Transformation for provers that need variables on the left-hand-side of \
+       logic symbols."
+    "vars_on_lhs" vars_on_lhs
diff --git a/src/verification.ml b/src/verification.ml
index 1aeefd3a..cb38ab34 100644
--- a/src/verification.ml
+++ b/src/verification.ml
@@ -49,6 +49,7 @@ let call_prover ~limit prover driver task =
 
 let verify format loadpath ?memlimit ~prover file =
   let open Why3 in
+  (* Debug.(set_flag (lookup_flag "call_prover")); *)
   let env, config = create_env loadpath in
   let steplimit = None in
   let timeout = None in
-- 
GitLab