Skip to content
Snippets Groups Projects
Commit eed0c97f authored by Julien Girard-Satabin's avatar Julien Girard-Satabin
Browse files

[SAVer] Handling of a "robust_to" predicate for SAVer.

parent c620b39f
No related branches found
No related tags found
No related merge requests found
all: all:
dune build --root=. @install caisar.opam nnet.opam onnx.opam dune build --root=. @install caisar.opam nnet.opam onnx.opam ovo.opam
test: test:
dune runtest --root=. dune runtest --root=.
......
...@@ -2,14 +2,28 @@ ...@@ -2,14 +2,28 @@
## How to add a solver ## How to add a solver
Make sure the solver is added to your systems for tests. Make sure the solver is added to your systems for tests.
Simplest way to do that is to add the binary directly to the
root of CAISAR repo.
1. Create a `new_solver.drv` in `config/drivers/`. 1. Create a `new_solver.drv` in `config/drivers/`.
A driver is a serie of whyml modules describing the theories A driver is a serie of whyml modules describing the theories
the solver will use during its call by why3. the solver will use during its call by why3. It is also
parsing the output of the solver and sending it to why3
inner representation.
1. Add a new field inside of `config/caisar-detection-data.conf`. 1. Add a new field inside of `config/caisar-detection-data.conf`.
Here, you only need to define the name of the solver's Here, you need to define the name of the solver's
executable as well as the supported versions for CAISAR. executable as well as the supported versions for CAISAR.
You also need to define the command that the solver will
execute. There are several Why3 built-in identifiers:
* %f is for a file
* %e is for the executable
We also added custom identifiers: %{nnet-onnx} and %{svm}.
Those identifiers will be replaced in `src/verification.ml`
by their actual value (the model filename).
Now, the solver will be recognized by CAISAR. However, in 1. Now, the solver will be recognized by CAISAR. However, in
order to exploit it, you should write a printer of its order to exploit it, you should write a printer of its
output in `src/printers/`. TODO output in `src/printers/`. This printer should have a `init`
function that must be called at the top of `src/main.ml`.
A printer is something that, given a Why3 formula, transform
it into something the solver can use.
...@@ -36,6 +36,6 @@ exec = "saver" ...@@ -36,6 +36,6 @@ exec = "saver"
version_switch = "--version 2>&1 | cat" version_switch = "--version 2>&1 | cat"
version_regexp = "\\(v[0-9.]+\\)" version_regexp = "\\(v[0-9.]+\\)"
version_ok = "v1.0" version_ok = "v1.0"
command = "%e %{svm} %{data}" command = "%e %{svm} %{dataset} hybrid l_inf %{epsilon}"
driver = "caisar_drivers/saver.drv" driver = "caisar_drivers/saver.drv"
use_at_auto_level = 1 use_at_auto_level = 1
...@@ -12,7 +12,6 @@ transformation "inline_trivial" ...@@ -12,7 +12,6 @@ transformation "inline_trivial"
transformation "introduce_premises" transformation "introduce_premises"
transformation "eliminate_builtin" transformation "eliminate_builtin"
transformation "simplify_formula" transformation "simplify_formula"
transformation "native_nn_prover"
theory BuiltIn theory BuiltIn
syntax type int "int" syntax type int "int"
......
...@@ -3,5 +3,6 @@ ...@@ -3,5 +3,6 @@
(files caisar-detection-data.conf (files caisar-detection-data.conf
(drivers/pyrat.drv as drivers/pyrat.drv) (drivers/pyrat.drv as drivers/pyrat.drv)
(drivers/marabou.drv as drivers/marabou.drv) (drivers/marabou.drv as drivers/marabou.drv)
(drivers/saver.drv as drivers/saver.drv)
) )
(package caisar)) (package caisar))
...@@ -43,29 +43,27 @@ let register_nn_as_tuple nb_inputs nb_outputs filename env = ...@@ -43,29 +43,27 @@ let register_nn_as_tuple nb_inputs nb_outputs filename env =
in in
Wstdlib.Mstr.singleton "NNasTuple" (Pmodule.close_module th_uc) Wstdlib.Mstr.singleton "NNasTuple" (Pmodule.close_module th_uc)
let register_asarray nb_inputs nb_outputs filename env = let register_svm nb_inputs nb_outputs filename env =
let open Why3 in let open Why3 in
let net = Pmodule.read_module env [ "caisar" ] "IOShape" in let net = Pmodule.read_module env [ "caisar" ] "SVM" in
let ioshape_input_type = let asarray_input_type =
Ty.ty_app Theory.(ns_find_ts net.mod_theory.th_export [ "input_type" ]) [] Ty.ty_app Theory.(ns_find_ts net.mod_theory.th_export [ "input_type" ]) []
in in
let id_as_array = Ident.id_fresh "AsArray" in let id_as_array = Ident.id_fresh "SVM" in
let th_uc = Pmodule.create_module env id_as_array in let th_uc = Pmodule.create_module env id_as_array in
let th_uc = Pmodule.use_export th_uc net in let th_uc = Pmodule.use_export th_uc net in
let ls_svm_apply = let ls_svm_apply =
let f _ = ioshape_input_type in
Term.create_fsymbol Term.create_fsymbol
(Ident.id_fresh "svm_apply") (Ident.id_fresh "svm_apply")
(List.init nb_inputs ~f) [] (Ty.ty_func asarray_input_type asarray_input_type)
(Ty.ty_tuple (List.init nb_outputs ~f))
in in
Why3.Term.Hls.add loaded_nets ls_svm_apply Why3.Term.Hls.add loaded_nets ls_svm_apply
{ filename; nb_inputs; nb_outputs; ty_data = ioshape_input_type }; { filename; nb_inputs; nb_outputs; ty_data = asarray_input_type };
let th_uc = let th_uc =
Pmodule.add_pdecl ~vc:false th_uc Pmodule.add_pdecl ~vc:false th_uc
(Pdecl.create_pure_decl (Decl.create_param_decl ls_svm_apply)) (Pdecl.create_pure_decl (Decl.create_param_decl ls_svm_apply))
in in
Wstdlib.Mstr.singleton "AsArray" (Pmodule.close_module th_uc) Wstdlib.Mstr.singleton "SVM" (Pmodule.close_module th_uc)
let nnet_parser env _ filename _ = let nnet_parser env _ filename _ =
let model = Nnet.parse filename in let model = Nnet.parse filename in
...@@ -84,7 +82,7 @@ let ovo_parser env _ filename _ = ...@@ -84,7 +82,7 @@ let ovo_parser env _ filename _ =
let model = Ovo.parse filename in let model = Ovo.parse filename in
match model with match model with
| Error s -> Loc.errorm "%s" s | Error s -> Loc.errorm "%s" s
| Ok model -> register_asarray model.n_inputs model.n_outputs filename env | Ok model -> register_svm model.n_inputs model.n_outputs filename env
let register_nnet_support () = let register_nnet_support () =
Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language
......
...@@ -15,7 +15,8 @@ let () = ...@@ -15,7 +15,8 @@ let () =
let () = let () =
Pyrat.init (); Pyrat.init ();
Marabou.init () Marabou.init ();
Saver.init()
(* -- Logs. *) (* -- Logs. *)
......
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(**************************************************************************)
type info = {
info_syn : Why3.Printer.syntax_map;
variables : string Why3.Term.Hls.t;
}
let number_format =
{
Why3.Number.long_int_support = `Default;
Why3.Number.negative_int_support = `Default;
Why3.Number.dec_int_support = `Default;
Why3.Number.hex_int_support = `Unsupported;
Why3.Number.oct_int_support = `Unsupported;
Why3.Number.bin_int_support = `Unsupported;
Why3.Number.negative_real_support = `Default;
Why3.Number.dec_real_support = `Default;
Why3.Number.hex_real_support = `Unsupported;
Why3.Number.frac_real_support = `Unsupported (fun _ _ -> assert false);
}
let rec print_term info fmt t =
let open Why3 in
match t.Term.t_node with
| Tbinop ((Timplies | Tiff), _, _)
| Tnot _ | Ttrue | Tfalse | Tvar _ | Tlet _ | Tif _ | Tcase _ | Tquant _
| Teps _
| Tbinop (Tor, _, _) ->
Printer.unsupportedTerm t "Not supported by SAVER"
| Tbinop (Tand, _, _) -> assert false (* Should appear only at top-level. *)
| Tconst c -> Constant.(print number_format unsupported_escape) fmt c
| Tapp (ls, l) -> (
match Printer.query_syntax info.info_syn ls.ls_name with
| Some s ->
(* Only print constants in a csv-like manner when
* encountering an equality. Constant is expected to
* be on the right side of the declaration *)
if String.contains s '=' then print_term info fmt (List.nth l 1) else ()
| None -> ())
let rec print_top_level_term info fmt t =
let open Why3 in
(* Don't print things we don't know. *)
let t_is_known =
Term.t_s_all
(fun _ -> true)
(fun ls ->
Ident.Mid.mem ls.ls_name info.info_syn || Term.Hls.mem info.variables ls)
in
match t.Term.t_node with
| Tquant _ -> ()
| Tbinop (Tand, t1, t2) ->
if t_is_known t1 && t_is_known t2
then
Fmt.pf fmt "%a%a"
(print_top_level_term info)
t1
(print_top_level_term info)
t2
| _ -> if t_is_known t then Fmt.pf fmt ",%a" (print_term info) t
let print_decl info fmt d =
let open Why3 in
match d.Decl.d_node with
| Dtype _ -> ()
| Ddata _ -> ()
| Dparam _ -> ()
| Dlogic _ -> ()
| Dind _ -> ()
| Dprop (Decl.Plemma, _, _) -> assert false
| Dprop (Decl.Paxiom, _, f) -> print_top_level_term info fmt f
| Dprop (Decl.Pgoal, _, f) -> print_top_level_term info fmt f
let rec print_tdecl info fmt task =
let open Why3 in
match task with
| None -> ()
| Some { Task.task_prev; task_decl; _ } -> (
print_tdecl info fmt task_prev;
match task_decl.Theory.td_node with
| Use _ | Clone _ -> ()
| Meta (meta, l) when Theory.meta_equal meta Native_nn_prover.meta_input
-> (
match l with
| [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "x%i" i)
| _ -> assert false)
| Meta (meta, l) when Theory.meta_equal meta Native_nn_prover.meta_output
-> (
match l with
| [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "y%i" i)
| _ -> assert false)
| Meta (_, _) -> ()
| Decl d -> print_decl info fmt d)
let print_task args ?old:_ fmt task =
let open Why3 in
let info =
{
info_syn = Discriminate.get_syntax_map task;
variables = Term.Hls.create 10;
}
in
Printer.print_prelude fmt args.Printer.prelude;
Format.fprintf fmt "# 1 3\n1";
print_tdecl info fmt task
let init () =
Why3.Printer.register_printer ~desc:"Printer for the SAVer prover." "saver"
print_task
...@@ -45,6 +45,9 @@ let create_env loadpath = ...@@ -45,6 +45,9 @@ let create_env loadpath =
config ) config )
let nnet_or_onnx = Re__Core.(compile (str "%{nnet-onnx}")) let nnet_or_onnx = Re__Core.(compile (str "%{nnet-onnx}"))
let svm = Re__Core.(compile (str "%{svm}"))
let dataset = Re__Core.(compile (str "%{dataset}"))
let epsilon = Re__Core.(compile (str "%{epsilon}"))
let combine_prover_answers answers = let combine_prover_answers answers =
List.fold_left answers ~init:Call_provers.Valid ~f:(fun acc r -> List.fold_left answers ~init:Call_provers.Valid ~f:(fun acc r ->
...@@ -52,35 +55,124 @@ let combine_prover_answers answers = ...@@ -52,35 +55,124 @@ let combine_prover_answers answers =
| Call_provers.Valid, r | r, Call_provers.Valid -> r | Call_provers.Valid, r | r, Call_provers.Valid -> r
| _ -> acc) | _ -> acc)
let call_prover ~limit config (prover : Whyconf.config_prover) driver task = let handle_task_saver task env command =
let task_prepared = Driver.prepare_task driver task in let open Why3 in
let tasks = let dataset_filename = "test_data.csv" in
(* We make [tasks] as a list (ie, conjunction) of disjunctions. *) let goal = Task.task_goal_fmla task in
if String.equal prover.prover.prover_name "Marabou" let eps, svm_filename =
then Trans.apply Split_goal.split_goal_full task_prepared match goal.t_node with
else [ task_prepared ] | Tquant (Tforall, b) -> (
in let _, _, pred = Term.t_open_quant b in
let command = Whyconf.get_complete_command ~with_steps:false prover in let svm_t = Pmodule.read_module env [ "caisar" ] "SVM" in
let nn_file = let robust_to_predicate =
match Task.on_meta_excl Native_nn_prover.meta_nn_filename task_prepared with let open Theory in
| Some [ MAstr nn_file ] -> nn_file ns_find_ls svm_t.mod_theory.th_export [ "robust_to" ]
| Some _ -> assert false (* By construction of the meta. *) in
| None -> invalid_arg (Fmt.str "No neural network model found in task") match pred.t_node with
| Term.Tapp
( ls,
[
{ t_node = Tapp (svm_app_sym, _); _ }; _; { t_node = Tconst e; _ };
] ) ->
if Term.ls_equal ls robust_to_predicate
then
let eps = Fmt.str "%a" Constant.print_def e in
let svm_filename =
match Language.lookup_loaded_nets svm_app_sym with
| Some t -> t.filename
| None -> failwith "Svm file not found in environment."
in
(eps, svm_filename)
else failwith "Wrong predicate found."
(* no other predicate than robust_to is supported *)
| _ -> failwith "Unsupported by SAVer.")
| _ -> failwith "Unsupported by SAVer."
in in
let nn_file = Unix.realpath nn_file in let svm_file = Filename.concat (Caml.Sys.getcwd ()) svm_filename in
let command = Re__Core.replace_string nnet_or_onnx ~by:nn_file command in let dataset_file = Filename.concat (Caml.Sys.getcwd ()) dataset_filename in
let call_prover_on_task task_prepared = let command = Re__Core.replace_string svm ~by:svm_file command in
let command = Re__Core.replace_string dataset ~by:dataset_file command in
let command = Re__Core.replace_string epsilon ~by:eps command in
command
let call_prover ~limit config (prover : Why3.Whyconf.config_prover) driver env
task =
let open Why3 in
if String.equal prover.prover.prover_name "SAVer"
then
let command = Whyconf.get_complete_command ~with_steps:false prover in
let command = handle_task_saver task env command in
let res_parser =
{
Call_provers.prp_regexps =
[ ("NeverMatch", Call_provers.Failure "Should not happen in CAISAR") ];
prp_timeregexps = [];
prp_stepregexps = [];
prp_exitcodes = [];
prp_model_parser = Model_parser.lookup_model_parser "no_model";
}
in
let prover_call = let prover_call =
Driver.prove_task_prepared ~libdir:(Whyconf.libdir config) Call_provers.call_on_buffer ~libdir:(Whyconf.libdir config)
~datadir:(Whyconf.datadir config) ~command ~limit driver task_prepared ~datadir:(Whyconf.datadir config) ~command ~limit ~res_parser
~filename:"foo.txt" ~get_counterexmp:false ~gen_new_file:false
~printing_info:Printer.default_printing_info (Buffer.create 10)
in in
let prover_result = Call_provers.wait_on_call prover_call in let prover_result = Call_provers.wait_on_call prover_call in
prover_result.pr_answer let answer =
in match prover_result.pr_answer with
let answers = List.map tasks ~f:call_prover_on_task in | Call_provers.HighFailure -> (
let answer = combine_prover_answers answers in let pr_output = prover_result.pr_output in
Fmt.pr "Goal %a: %a@." Pretty.print_pr (Task.task_goal task) let matcher =
Call_provers.print_prover_answer answer Re__Pcre.regexp
"\\[SUMMARY\\]\\s*(\\d+)\\s*[0-9.]+\\s*[0-9.]+\\s*\\d+\\s*(\\d+)\\s*\\d"
in
match Re__Core.exec_opt matcher pr_output with
| Some g ->
if Int.of_string (Re__Core.Group.get g 1)
= Int.of_string (Re__Core.Group.get g 2)
then Call_provers.Valid
else Call_provers.Invalid
| None -> assert false
(* Any other answer than HighFailure should
* never happen as we do not define
* anything in SAVer's driver *))
| _ -> assert false
in
Fmt.pr "Goal %a: %a@." Pretty.print_pr (Task.task_goal task)
Call_provers.print_prover_answer answer
else
let task_prepared = Driver.prepare_task driver task in
let tasks =
(* We make [tasks] as a list (ie, conjunction) of disjunctions. *)
if String.equal prover.prover.prover_name "Marabou"
then Trans.apply Split_goal.split_goal_full task_prepared
else [ task_prepared ]
in
let command = Whyconf.get_complete_command ~with_steps:false prover in
let nn_file =
match
Task.on_meta_excl Native_nn_prover.meta_nn_filename task_prepared
with
| Some [ MAstr nn_file ] -> nn_file
| Some _ -> assert false (* By construction of the meta. *)
| None -> invalid_arg (Fmt.str "No neural network model found in task")
in
let nn_file = Unix.realpath nn_file in
let command = Re__Core.replace_string nnet_or_onnx ~by:nn_file command in
let call_prover_on_task task_prepared =
let prover_call =
Driver.prove_task_prepared ~libdir:(Whyconf.libdir config)
~datadir:(Whyconf.datadir config) ~command ~limit driver task_prepared
in
let prover_result = Call_provers.wait_on_call prover_call in
prover_result.pr_answer
in
let answers = List.map tasks ~f:call_prover_on_task in
let answer = combine_prover_answers answers in
Fmt.pr "Goal %a: %a@." Pretty.print_pr (Task.task_goal task)
Call_provers.print_prover_answer answer
let verify ?(debug = false) format loadpath ?memlimit ?timeout prover file = let verify ?(debug = false) format loadpath ?memlimit ?timeout prover file =
if debug then Debug.set_flag (Debug.lookup_flag "call_prover"); if debug then Debug.set_flag (Debug.lookup_flag "call_prover");
...@@ -125,6 +217,6 @@ let verify ?(debug = false) format loadpath ?memlimit ?timeout prover file = ...@@ -125,6 +217,6 @@ let verify ?(debug = false) format loadpath ?memlimit ?timeout prover file =
Driver.load_driver_absolute env file prover.extra_drivers Driver.load_driver_absolute env file prover.extra_drivers
in in
List.iter List.iter
~f:(call_prover ~limit (Whyconf.get_main config) prover driver) ~f:(call_prover ~limit (Whyconf.get_main config) prover driver env)
tasks) tasks)
mstr_theory mstr_theory
...@@ -2,3 +2,10 @@ theory NN ...@@ -2,3 +2,10 @@ theory NN
use ieee_float.Float64 use ieee_float.Float64
type input_type = t type input_type = t
end end
theory SVM
use ieee_float.Float64
use array.Array
type input_type = array t
predicate robust_to (input_type -> input_type) (input_type) (t)
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment