From ffaac063dfd5c7104248a936cf0512ffcc0ef0d7 Mon Sep 17 00:00:00 2001 From: Michele Alberti <michele.alberti@cea.fr> Date: Mon, 22 Nov 2021 22:24:33 +0100 Subject: [PATCH] WIP: Printer for marabou. --- config/caisar-detection-data.conf | 2 +- config/drivers/marabou.drv | 187 ++++++++++++++++++++++++ config/dune | 1 + src/main.ml | 8 +- src/printers/marabou.ml | 110 ++++++++++++++ src/transformations/native_nn_prover.ml | 42 +++--- src/transformations/vars_on_lhs.ml | 46 ++++++ src/verification.ml | 1 + 8 files changed, 374 insertions(+), 23 deletions(-) create mode 100644 src/printers/marabou.ml create mode 100644 src/transformations/vars_on_lhs.ml diff --git a/config/caisar-detection-data.conf b/config/caisar-detection-data.conf index e4b4b2a9..9b73c49d 100644 --- a/config/caisar-detection-data.conf +++ b/config/caisar-detection-data.conf @@ -16,7 +16,7 @@ exec = "Marabou" version_switch = "--version" version_regexp = "\\([0-9.+]+\\)" version_ok = "1.0.+" -command = "%e --timeout %t %f" +command = "%e %{nnet-onnx} %f" driver = "caisar_drivers/marabou.drv" use_at_auto_level = 1 diff --git a/config/drivers/marabou.drv b/config/drivers/marabou.drv index 8a2487dd..dddc85b8 100644 --- a/config/drivers/marabou.drv +++ b/config/drivers/marabou.drv @@ -1 +1,188 @@ (* Why3 drivers for Marabou *) + +printer "marabou" +filename "%f-%t-%g.why" + +valid "^[Ss]at" +invalid "^[Uu]nsat" +timeout "^[Tt]imeout" +unknown "^[Uu]nknown" "" + +transformation "inline_trivial" +transformation "introduce_premises" +transformation "eliminate_builtin" +transformation "simplify_formula" +transformation "native_nn_prover" +transformation "vars_on_lhs" + +theory BuiltIn + syntax type int "int" + syntax type real "real" + + syntax predicate (=) "(%1 = %2)" + + meta "eliminate_algebraic" "keep_enums" + meta "eliminate_algebraic" "keep_recs" + +end + +theory int.Int + + prelude "(* this is a prelude for Alt-Ergo integer arithmetic *)" + + syntax function zero "0" + syntax function one "1" + + syntax function (+) "(%1 + %2)" + syntax function (-) "(%1 - %2)" + syntax function (*) "(%1 * %2)" + syntax function (-_) "(-%1)" + + meta "invalid trigger" predicate (<=) + meta "invalid trigger" predicate (<) + meta "invalid trigger" predicate (>=) + meta "invalid trigger" predicate (>) + + syntax predicate (<=) "(%1 <= %2)" + syntax predicate (<) "(%1 < %2)" + syntax predicate (>=) "(%1 >= %2)" + syntax predicate (>) "(%1 > %2)" + + remove prop MulComm.Comm + remove prop MulAssoc.Assoc + remove prop Unit_def_l + remove prop Unit_def_r + remove prop Inv_def_l + remove prop Inv_def_r + remove prop Assoc + remove prop Mul_distr_l + remove prop Mul_distr_r + remove prop Comm + remove prop Unitary + remove prop Refl + remove prop Trans + remove prop Total + remove prop Antisymm + remove prop NonTrivialRing + remove prop CompatOrderAdd + remove prop ZeroLessOne + +end + +theory int.EuclideanDivision + + syntax function div "(%1 / %2)" + syntax function mod "(%1 % %2)" + +end + +theory int.ComputerDivision + + use for_drivers.ComputerOfEuclideanDivision + +end + + +theory real.Real + + prelude "(* this is a prelude for Alt-Ergo real arithmetic *)" + + syntax function zero "0.0" + syntax function one "1.0" + + syntax function (+) "(%1 + %2)" + syntax function (-) "(%1 - %2)" + syntax function (*) "(%1 * %2)" + syntax function (/) "(%1 / %2)" + syntax function (-_) "(-%1)" + syntax function inv "(1.0 / %1)" + + meta "invalid trigger" predicate (<=) + meta "invalid trigger" predicate (<) + meta "invalid trigger" predicate (>=) + meta "invalid trigger" predicate (>) + + syntax predicate (<=) "(%1 <= %2)" + syntax predicate (<) "(%1 < %2)" + syntax predicate (>=) "(%1 >= %2)" + syntax predicate (>) "(%1 > %2)" + + remove prop MulComm.Comm + remove prop MulAssoc.Assoc + remove prop Unit_def_l + remove prop Unit_def_r + remove prop Inv_def_l + remove prop Inv_def_r + remove prop Assoc + remove prop Mul_distr_l + remove prop Mul_distr_r + remove prop Comm + remove prop Unitary + remove prop Refl + remove prop Trans + remove prop Total + remove prop Antisymm + remove prop Inverse + remove prop NonTrivialRing + remove prop CompatOrderAdd + remove prop ZeroLessOne + +end + +theory ieee_float.Float64 + + syntax function (.+) "(%1 + %2)" + syntax function (.-) "(%1 - %2)" + syntax function (.*) "(%1 * %2)" + syntax function (./) "(%1 / %2)" + syntax function (.-_) "(-%1)" + + syntax predicate le "%1 <= %2" + syntax predicate lt "%1 < %2" + syntax predicate ge "%1 >= %2" + syntax predicate gt "%1 > %2" + + +end + +theory real.RealInfix + + syntax function (+.) "(%1 + %2)" + syntax function (-.) "(%1 - %2)" + syntax function ( *.) "(%1 * %2)" + syntax function (/.) "(%1 / %2)" + syntax function (-._) "(-%1)" + + meta "invalid trigger" predicate (<=.) + meta "invalid trigger" predicate (<.) + meta "invalid trigger" predicate (>=.) + meta "invalid trigger" predicate (>.) + + syntax predicate (<=.) "(%1 <= %2)" + syntax predicate (<.) "(%1 < %2)" + syntax predicate (>=.) "(%1 >= %2)" + syntax predicate (>.) "(%1 > %2)" + +end + +theory Bool + syntax type bool "bool" + syntax function True "true" + syntax function False "false" +end + +theory Tuple0 + syntax type tuple0 "unit" + syntax function Tuple0 "void" +end + +theory algebra.AC + meta AC function op + remove prop Comm + remove prop Assoc +end + +theory ieee_float.Float64 + syntax predicate is_not_nan "" + remove allprops +end diff --git a/config/dune b/config/dune index bd74fc9e..e73ccd82 100644 --- a/config/dune +++ b/config/dune @@ -2,5 +2,6 @@ (section (site (caisar config))) (files caisar-detection-data.conf (drivers/pyrat.drv as drivers/pyrat.drv) + (drivers/marabou.drv as drivers/marabou.drv) ) (package caisar)) diff --git a/src/main.ml b/src/main.ml index 0894eb3d..a86c5a94 100644 --- a/src/main.ml +++ b/src/main.ml @@ -9,9 +9,13 @@ open Cmdliner let caisar = "caisar" -let () = Native_nn_prover.init () +let () = + Native_nn_prover.init (); + Vars_on_lhs.init () -let () = Pyrat.init () +let () = + Pyrat.init (); + Marabou.init () (* -- Logs. *) diff --git a/src/printers/marabou.ml b/src/printers/marabou.ml new file mode 100644 index 00000000..28a6c84d --- /dev/null +++ b/src/printers/marabou.ml @@ -0,0 +1,110 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(**************************************************************************) + +type info = { + info_syn : Why3.Printer.syntax_map; + variables : string Why3.Term.Hls.t; +} + +let number_format = + { + Why3.Number.long_int_support = `Default; + Why3.Number.negative_int_support = `Default; + Why3.Number.dec_int_support = `Default; + Why3.Number.hex_int_support = `Unsupported; + Why3.Number.oct_int_support = `Unsupported; + Why3.Number.bin_int_support = `Unsupported; + Why3.Number.negative_real_support = `Default; + Why3.Number.dec_real_support = `Default; + Why3.Number.hex_real_support = `Unsupported; + Why3.Number.frac_real_support = `Unsupported (fun _ _ -> assert false); + } + +let rec print_term info fmt t = + let open Why3 in + match t.Term.t_node with + | Tbinop ((Timplies | Tiff | Tor), _, _) + | Tnot _ | Ttrue | Tfalse | Tvar _ | Tlet _ | Tif _ | Tcase _ | Tquant _ + | Teps _ -> + Printer.unsupportedTerm t "Not supported by Marabou" + | Tbinop (Tand, _, _) -> assert false (* Should appear only at top-level. *) + | Tconst c -> Constant.(print number_format unsupported_escape) fmt c + | Tapp (ls, l) -> ( + match Printer.query_syntax info.info_syn ls.ls_name with + | Some s -> Printer.syntax_arguments s (print_term info) fmt l + | None -> ( + match (Term.Hls.find_opt info.variables ls, l) with + | Some s, [] -> Fmt.string fmt s + | _ -> Printer.unsupportedTerm t "Unknown variable(s)")) + +let rec print_top_level_term info fmt t = + let open Why3 in + (* Don't print things we don't know. *) + let t_is_known = + Term.t_s_all + (fun _ -> true) + (fun ls -> + Ident.Mid.mem ls.ls_name info.info_syn || Term.Hls.mem info.variables ls) + in + match t.Term.t_node with + | Tquant _ -> () + | Tbinop (Tand, t1, t2) -> + if t_is_known t1 && t_is_known t2 + then + Fmt.pf fmt "%a%a" + (print_top_level_term info) + t1 + (print_top_level_term info) + t2 + | _ -> if t_is_known t then Fmt.pf fmt "%a@." (print_term info) t + +let print_decl info fmt d = + let open Why3 in + match d.Decl.d_node with + | Dtype _ -> () + | Ddata _ -> () + | Dparam _ -> () + | Dlogic _ -> () + | Dind _ -> () + | Dprop (Decl.Plemma, _, _) -> assert false + | Dprop (Decl.Paxiom, _, f) -> print_top_level_term info fmt f + | Dprop (Decl.Pgoal, _, f) -> print_top_level_term info fmt f + +let rec print_tdecl info fmt task = + let open Why3 in + match task with + | None -> () + | Some { Task.task_prev; task_decl; _ } -> ( + print_tdecl info fmt task_prev; + match task_decl.Theory.td_node with + | Use _ | Clone _ -> () + | Meta (meta, l) when Theory.meta_equal meta Native_nn_prover.meta_input + -> ( + match l with + | [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "x%i" i) + | _ -> assert false) + | Meta (meta, l) when Theory.meta_equal meta Native_nn_prover.meta_output + -> ( + match l with + | [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "y%i" i) + | _ -> assert false) + | Meta (_, _) -> () + | Decl d -> print_decl info fmt d) + +let print_task args ?old:_ fmt task = + let open Why3 in + let info = + { + info_syn = Discriminate.get_syntax_map task; + variables = Term.Hls.create 10; + } + in + Printer.print_prelude fmt args.Printer.prelude; + print_tdecl info fmt task + +let init () = + Why3.Printer.register_printer ~desc:"Printer for the Marabou prover." + "marabou" print_task diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index fd3a3800..95b89b07 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -78,27 +78,29 @@ let simplify_goal env input_variables = | Decl decl -> let meta = ref [] in let hls = ref [] in - let map term = - let term = aux meta hls term in - if List.is_empty !hls - then term - else - let known = - List.fold !hls ~init:task_hd.task_known ~f:(fun acc (d, _, _) -> - Decl.known_add_decl acc d) - in - let engine = - Reduction_engine.create - { - compute_defs = false; - compute_builtin = true; - compute_def_set = Term.Sls.empty; - } - env known - in - Reduction_engine.normalize ~limit:100 engine Term.Mvs.empty term + let decl = + Decl.decl_map + (fun term -> + let term = aux meta hls term in + if List.is_empty !hls + then term + else + let known = + List.fold !hls ~init:task_hd.task_known + ~f:(fun acc (d, _, _) -> Decl.known_add_decl acc d) + in + let engine = + Reduction_engine.create + { + compute_defs = false; + compute_builtin = true; + compute_def_set = Term.Sls.empty; + } + env known + in + Reduction_engine.normalize ~limit:100 engine Term.Mvs.empty term) + decl in - let decl = Decl.decl_map map decl in let acc = List.fold !hls ~init:acc ~f:(fun acc (d, ls, i) -> let task = Task.add_decl acc d in diff --git a/src/transformations/vars_on_lhs.ml b/src/transformations/vars_on_lhs.ml new file mode 100644 index 00000000..021ac66a --- /dev/null +++ b/src/transformations/vars_on_lhs.ml @@ -0,0 +1,46 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(**************************************************************************) + +let make_rt env = + let open Why3 in + let th = Env.read_theory env [ "ieee_float" ] "Float64" in + let le = Theory.ns_find_ls th.th_export [ "le" ] in + let lt = Theory.ns_find_ls th.th_export [ "lt" ] in + let ge = Theory.ns_find_ls th.th_export [ "ge" ] in + let gt = Theory.ns_find_ls th.th_export [ "gt" ] in + let rec rt t = + let t = Term.t_map rt t in + match t.t_node with + | Tapp + ( ls, + [ + ({ t_node = Tconst _; _ } as const); + ({ t_node = Tapp (_, []); _ } as var); + ] ) -> + Fmt.pr "HERE: %a!@." Pretty.print_term t; + if Term.ls_equal ls le + then Term.ps_app ge [ var; const ] + else if Term.ls_equal ls ge + then Term.ps_app le [ var; const ] + else if Term.ls_equal ls lt + then Term.ps_app gt [ var; const ] + else if Term.ls_equal ls gt + then Term.ps_app lt [ var; const ] + else t + | _ -> t + in + rt + +let vars_on_lhs env = + let rt = make_rt env in + Why3.Trans.rewrite rt None + +let init () = + Why3.Trans.register_env_transform + ~desc: + "Transformation for provers that need variables on the left-hand-side of \ + logic symbols." + "vars_on_lhs" vars_on_lhs diff --git a/src/verification.ml b/src/verification.ml index 1aeefd3a..cb38ab34 100644 --- a/src/verification.ml +++ b/src/verification.ml @@ -49,6 +49,7 @@ let call_prover ~limit prover driver task = let verify format loadpath ?memlimit ~prover file = let open Why3 in + (* Debug.(set_flag (lookup_flag "call_prover")); *) let env, config = create_env loadpath in let steplimit = None in let timeout = None in -- GitLab