From b3c28279d53e4984416f716a1bc6f2003a40b7a1 Mon Sep 17 00:00:00 2001 From: Michele Alberti <michele.alberti@cea.fr> Date: Wed, 30 Nov 2022 17:36:16 +0100 Subject: [PATCH] [verification][printers] Allow Marabou to be launched on a dataset verification. Modify the printer and transformations for Marabou to accept working on reals. --- config/drivers/marabou.drv | 8 +-- src/printers/marabou.ml | 88 ++++++++++++++++++++---------- src/transformations/vars_on_lhs.ml | 54 +++++++++++++----- src/verification.ml | 31 ++++++++--- 4 files changed, 127 insertions(+), 54 deletions(-) diff --git a/config/drivers/marabou.drv b/config/drivers/marabou.drv index 25eb4ea3..098c2aa6 100644 --- a/config/drivers/marabou.drv +++ b/config/drivers/marabou.drv @@ -123,10 +123,10 @@ theory real.Real meta "invalid trigger" predicate (>=) meta "invalid trigger" predicate (>) - syntax predicate (<=) "(%1 <= %2)" - syntax predicate (<) "(%1 < %2)" - syntax predicate (>=) "(%1 >= %2)" - syntax predicate (>) "(%1 > %2)" + syntax predicate (<=) "%1 <= %2" + syntax predicate (<) "%1 < %2" + syntax predicate (>=) "%1 >= %2" + syntax predicate (>) "%1 > %2" remove prop MulComm.Comm remove prop MulAssoc.Assoc diff --git a/src/printers/marabou.ml b/src/printers/marabou.ml index 4a58c113..67fc24f9 100644 --- a/src/printers/marabou.ml +++ b/src/printers/marabou.ml @@ -22,11 +22,16 @@ open Why3 -type info = { +type relops = { le : Term.lsymbol; ge : Term.lsymbol; lt : Term.lsymbol; gt : Term.lsymbol; +} + +type info = { + ls_rel_real : relops; + ls_rel_float : relops; info_syn : Printer.syntax_map; variables : string Term.Hls.t; } @@ -70,13 +75,17 @@ let rec print_term info fmt t = Term.Hls.find_opt info.variables ls2 ) with | Some s1, Some s2 -> - if Term.ls_equal ls info.le + if Term.ls_equal ls info.ls_rel_float.le + || Term.ls_equal ls info.ls_rel_real.le then Fmt.pf fmt "+%s -%s <= 0" s1 s2 - else if Term.ls_equal ls info.ge + else if Term.ls_equal ls info.ls_rel_float.ge + || Term.ls_equal ls info.ls_rel_real.ge then Fmt.pf fmt "+%s -%s >= 0" s1 s2 - else if Term.ls_equal ls info.lt + else if Term.ls_equal ls info.ls_rel_float.lt + || Term.ls_equal ls info.ls_rel_real.lt then Fmt.pf fmt "+%s -%s < 0" s1 s2 - else if Term.ls_equal ls info.gt + else if Term.ls_equal ls info.ls_rel_float.gt + || Term.ls_equal ls info.ls_rel_real.gt then Fmt.pf fmt "+%s -%s > 0" s1 s2 else Printer.unsupportedTerm t "Unknown relational operator" | _ -> Printer.unsupportedTerm t "Unknown variable(s)") @@ -116,24 +125,38 @@ let rec negate_term info t = Term.t_and (negate_term info t1) (negate_term info t2) | Tapp (ls, [ t1; t2 ]) -> let tt = [ t1; t2 ] in - if Term.ls_equal ls info.le - then Term.ps_app info.gt tt - else if Term.ls_equal ls info.ge - then Term.ps_app info.lt tt - else if Term.ls_equal ls info.lt - then Term.ps_app info.ge tt - else if Term.ls_equal ls info.gt - then Term.ps_app info.le tt - else Printer.unsupportedTerm t "Cannot negate such term" + (* Negate float relational symbols. *) + let ls_neg = + if Term.ls_equal ls info.ls_rel_float.le + then info.ls_rel_float.gt + else if Term.ls_equal ls info.ls_rel_float.ge + then info.ls_rel_float.lt + else if Term.ls_equal ls info.ls_rel_float.lt + then info.ls_rel_float.ge + else if Term.ls_equal ls info.ls_rel_float.gt + then info.ls_rel_float.le + else ls + in + (* Negate real relational symbols. *) + let ls_neg = + if Term.ls_equal ls info.ls_rel_real.le + then info.ls_rel_real.gt + else if Term.ls_equal ls info.ls_rel_real.ge + then info.ls_rel_real.lt + else if Term.ls_equal ls info.ls_rel_real.lt + then info.ls_rel_real.ge + else if Term.ls_equal ls info.ls_rel_real.gt + then info.ls_rel_real.le + else ls_neg + in + if Term.ls_equal ls_neg ls + then Printer.unsupportedTerm t "Cannot negate such term" + else Term.ps_app ls_neg tt | _ -> Printer.unsupportedTerm t "Cannot negate such term" let print_decl info fmt d = match d.Decl.d_node with - | Dtype _ -> () - | Ddata _ -> () - | Dparam _ -> () - | Dlogic _ -> () - | Dind _ -> () + | Dtype _ | Ddata _ | Dparam _ | Dlogic _ | Dind _ -> () | Dprop (Decl.Plemma, _, _) -> assert false | Dprop (Decl.Paxiom, _, f) -> print_top_level_term info fmt f | Dprop (Decl.Pgoal, _, f) -> @@ -158,17 +181,26 @@ let rec print_tdecl info fmt task = | Decl d -> print_decl info fmt d) let print_task args ?old:_ fmt task = - let th = Env.read_theory args.Printer.env [ "ieee_float" ] "Float64" in - let le = Theory.ns_find_ls th.th_export [ "le" ] in - let lt = Theory.ns_find_ls th.th_export [ "lt" ] in - let ge = Theory.ns_find_ls th.th_export [ "ge" ] in - let gt = Theory.ns_find_ls th.th_export [ "gt" ] in + let ls_rel_real = + let th = Env.read_theory args.Printer.env [ "real" ] "Real" in + let le = Theory.ns_find_ls th.th_export [ Ident.op_infix "<=" ] in + let lt = Theory.ns_find_ls th.th_export [ Ident.op_infix "<" ] in + let ge = Theory.ns_find_ls th.th_export [ Ident.op_infix ">=" ] in + let gt = Theory.ns_find_ls th.th_export [ Ident.op_infix ">" ] in + { le; ge; lt; gt } + in + let ls_rel_float = + let th = Env.read_theory args.Printer.env [ "ieee_float" ] "Float64" in + let le = Theory.ns_find_ls th.th_export [ "le" ] in + let lt = Theory.ns_find_ls th.th_export [ "lt" ] in + let ge = Theory.ns_find_ls th.th_export [ "ge" ] in + let gt = Theory.ns_find_ls th.th_export [ "gt" ] in + { le; ge; lt; gt } + in let info = { - le; - lt; - ge; - gt; + ls_rel_real; + ls_rel_float; info_syn = Discriminate.get_syntax_map task; variables = Term.Hls.create 10; } diff --git a/src/transformations/vars_on_lhs.ml b/src/transformations/vars_on_lhs.ml index 6bed9204..909a4087 100644 --- a/src/transformations/vars_on_lhs.ml +++ b/src/transformations/vars_on_lhs.ml @@ -26,10 +26,15 @@ open Base let make_rt env = let th = Env.read_theory env [ "ieee_float" ] "Float64" in let t = Theory.ns_find_ts th.th_export [ "t" ] in - let le = Theory.ns_find_ls th.th_export [ "le" ] in - let lt = Theory.ns_find_ls th.th_export [ "lt" ] in - let ge = Theory.ns_find_ls th.th_export [ "ge" ] in - let gt = Theory.ns_find_ls th.th_export [ "gt" ] in + let le_float = Theory.ns_find_ls th.th_export [ "le" ] in + let lt_float = Theory.ns_find_ls th.th_export [ "lt" ] in + let ge_float = Theory.ns_find_ls th.th_export [ "ge" ] in + let gt_float = Theory.ns_find_ls th.th_export [ "gt" ] in + let th = Env.read_theory env [ "real" ] "Real" in + let le_real = Theory.ns_find_ls th.th_export [ Ident.op_infix "<=" ] in + let lt_real = Theory.ns_find_ls th.th_export [ Ident.op_infix "<" ] in + let ge_real = Theory.ns_find_ls th.th_export [ Ident.op_infix ">=" ] in + let gt_real = Theory.ns_find_ls th.th_export [ Ident.op_infix ">" ] in let rec rt t = let t = Term.t_map rt t in match t.t_node with @@ -40,19 +45,40 @@ let make_rt env = ({ t_node = Tapp (_, []); _ } as var); ] ) -> let tt = [ var; const ] in - if Term.ls_equal ls le - then Term.ps_app ge tt - else if Term.ls_equal ls ge - then Term.ps_app le tt - else if Term.ls_equal ls lt - then Term.ps_app gt tt - else if Term.ls_equal ls gt - then Term.ps_app lt tt - else t + let ls_rel = + if Term.ls_equal ls le_float + then ge_float + else if Term.ls_equal ls ge_float + then le_float + else if Term.ls_equal ls lt_float + then gt_float + else if Term.ls_equal ls gt_float + then lt_float + else ls + in + let ls_rel = + if Term.ls_equal ls le_real + then ge_real + else if Term.ls_equal ls ge_real + then le_real + else if Term.ls_equal ls lt_real + then gt_real + else if Term.ls_equal ls gt_real + then lt_real + else ls_rel + in + if Term.ls_equal ls_rel ls then t else Term.ps_app ls_rel tt | _ -> t in let task = - List.fold [ le; lt; ge; gt ] ~init:(Task.add_ty_decl None t) + List.fold + [ le_float; lt_float; ge_float; gt_float ] + ~init:(Task.add_ty_decl None t) ~f:Task.add_param_decl + in + let task = + List.fold + [ le_real; lt_real; ge_real; gt_real ] + ~init:(Task.add_ty_decl task Ty.ts_real) ~f:Task.add_param_decl in (rt, task) diff --git a/src/verification.ml b/src/verification.ml index bad75796..96dba8fd 100644 --- a/src/verification.ml +++ b/src/verification.ml @@ -223,7 +223,8 @@ let combine_prover_answers answers = | Call_provers.Valid, r | r, Call_provers.Valid -> r | _ -> acc) -let answer_on_dataset limit config env config_prover driver dataset_csv task = +let answer_on_dataset limit config env prover config_prover driver dataset_csv + task = let dataset_predicate = let on_model ls = let message = @@ -239,22 +240,35 @@ let answer_on_dataset limit config env config_prover driver dataset_csv task = in Dataset.interpret_predicate env ~on_model ~on_dataset task in - let tasks = + let dataset_tasks = + (* One task for dataset element. *) Dataset.tasks_of_nn_csv_predicate env dataset_predicate |> List.map ~f:(Driver.prepare_task driver) in + let tasks = + (* We turn each task in [dataset_tasks] into a list (ie, conjunction) of + disjunctions of tasks. *) + match prover with + | Prover.Marabou -> List.map ~f:(Trans.apply Split.split_all) dataset_tasks + | _ -> [ dataset_tasks ] + in let command = Whyconf.get_complete_command ~with_steps:false config_prover in let command = let nn_file = Unix.realpath dataset_predicate.model.filename in Re__Core.replace_string nnet_or_onnx ~by:nn_file command in let answers = - List.map tasks ~f:(call_prover_on_task limit config command driver) + List.map tasks ~f:(fun dataset_elt_tasks -> + let dataset_elt_answers = + List.map dataset_elt_tasks + ~f:(call_prover_on_task limit config command driver) + in + combine_prover_answers dataset_elt_answers) in let prover_answer = combine_prover_answers answers in (prover_answer, None) -let answer_generic limit env config prover config_prover driver task = +let answer_generic limit config env prover config_prover driver task = let task = Proof_strategy.apply_native_nn_prover env task in let task = Driver.prepare_task driver task in let nn_file = @@ -264,7 +278,7 @@ let answer_generic limit env config prover config_prover driver task = | None -> invalid_arg "No neural network model found in task" in let tasks = - (* We make [tasks] as a list (ie, conjunction) of disjunctions. *) + (* We turn [task] into a list (ie, conjunction) of disjunctions of tasks. *) match prover with | Prover.Marabou -> Trans.apply Split.split_all task | Pyrat -> Trans.apply Split.split_premises task @@ -283,11 +297,12 @@ let call_prover ~limit config env prover config_prover driver dataset_csv task = match prover with | Prover.Saver -> answer_saver limit config env config_prover dataset_csv task - | Pyrat when Option.is_some dataset_csv -> + | (Marabou | Pyrat) when Option.is_some dataset_csv -> let dataset_csv = Option.value_exn dataset_csv in - answer_on_dataset limit config env config_prover driver dataset_csv task + answer_on_dataset limit config env prover config_prover driver dataset_csv + task | Marabou | Pyrat | CVC5 -> - answer_generic limit env config prover config_prover driver task + answer_generic limit config env prover config_prover driver task in Logs.app (fun m -> m "@[Goal %a:@ %a%a@]" Pretty.print_pr (Task.task_goal task) -- GitLab