Skip to content
Snippets Groups Projects
Commit 9718c499 authored by Michele Alberti's avatar Michele Alberti
Browse files

[SAVer] Extend SAVer support for correct and conditional robust predicates.

parent 80a03afa
No related branches found
No related tags found
No related merge requests found
......@@ -40,51 +40,48 @@ let handle_ovo_line ~f in_channel =
~f:(fun s -> try Some (f (String.strip s)) with _ -> None)
(Csv.next in_channel)
(* Skip the header part, ie comments, of the OVO format. *)
let skip_ovo_header filename in_channel =
let exception End_of_header in
let pos_in = ref (Stdlib.pos_in in_channel) in
try
while true do
let line = Stdlib.input_line in_channel in
if not (Str.string_match (Str.regexp "//") line 0)
then raise End_of_header
else pos_in := Stdlib.pos_in in_channel
done;
assert false
with
| End_of_header ->
(* At this point we have read one line past the header part: seek back. *)
Stdlib.seek_in in_channel !pos_in;
Ok ()
| End_of_file ->
Error (Format.sprintf "OVO model not found in file `%s'." filename)
(* Handle ovo first line: either 'ovo' or 'ovo x y' with 'x' and 'y' positive
integer numbers. *)
let handle_ovo_first_line in_channel =
let ovo_format_error_on_first_line = ovo_format_error "first line" in
match Csv.next in_channel with
| [ first_line ] -> (
let in_channel = Csv.of_string ~separator:' ' first_line in
match Csv.next in_channel with
| [ "ovo"; n_is; n_os ] ->
let n_is = Int.of_string (String.strip n_is) in
let n_os = Int.of_string (String.strip n_os) in
Ok (Some (n_is, n_os))
| [ "ovo" ] -> Ok None
| _ -> ovo_format_error_on_first_line
| exception End_of_file -> ovo_format_error_on_first_line)
| _ -> ovo_format_error_on_first_line
| exception End_of_file -> ovo_format_error_on_first_line
(* Retrieve inputs and outputs size. *)
let handle_ovo_basic_info in_channel =
let handle_ovo_basic_info ~descr in_channel =
match handle_ovo_line ~f:Int.of_string in_channel with
| [ dim ] -> Ok dim
| _ -> ovo_format_error "first"
| exception End_of_file -> ovo_format_error "first"
(* Skip unused flag. *)
let handle_ovo_unused_flag in_channel =
try
let _ = Csv.next in_channel in
Ok ()
with End_of_file -> ovo_format_error "second"
| _ -> ovo_format_error descr
| exception End_of_file -> ovo_format_error descr
(* Retrieves [filename] OVO model metadata and weights wrt OVO format
specification, which is described here:
https://github.com/abstract-machine-learning/saver#classifier-format. *)
let parse_in_channel filename in_channel =
https://github.com/abstract-machine-learning/saver#classifier-format. In
practice, the first line may specify the input/output size values or not. In
the latter case, we search for input/output size values in the second/third
lines respectively. *)
let parse_in_channel in_channel =
let open Result in
try
skip_ovo_header filename in_channel >>= fun () ->
let in_channel = Csv.of_channel in_channel in
handle_ovo_unused_flag in_channel >>= fun () ->
handle_ovo_basic_info in_channel >>= fun n_is ->
handle_ovo_basic_info in_channel >>= fun n_os ->
(handle_ovo_first_line in_channel >>= function
| Some (n_is, n_os) -> Ok (n_is, n_os)
| None ->
handle_ovo_basic_info ~descr:"input size" in_channel >>= fun n_is ->
handle_ovo_basic_info ~descr:"output size" in_channel >>= fun n_os ->
Ok (n_is, n_os))
>>= fun (n_is, n_os) ->
Csv.close_in in_channel;
Ok { n_inputs = n_is; n_outputs = n_os }
with
......@@ -96,4 +93,4 @@ let parse filename =
let in_channel = Stdlib.open_in filename in
Fun.protect
~finally:(fun () -> Stdlib.close_in in_channel)
(fun () -> parse_in_channel filename in_channel)
(fun () -> parse_in_channel in_channel)
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(* Copyright (C) 2022 *)
(* CEA (Commissariat à l'énergie atomique et aux énergies *)
(* alternatives) *)
(* *)
(* You can redistribute it and/or modify it under the terms of the GNU *)
(* Lesser General Public License as published by the Free Software *)
(* Foundation, version 2.1. *)
(* *)
(* It is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
(* GNU Lesser General Public License for more details. *)
(* *)
(* See the GNU Lesser General Public License version 2.1 *)
(* for more details (enclosed in the file licenses/LGPLv2.1). *)
(* *)
(**************************************************************************)
open Why3
open Base
type predicate_kind = Correct | Robust | CondRobust
type interpreted_task = {
svm_filename : string;
predicate_kind : predicate_kind;
eps : string;
}
let find_predicate_ls env p =
let dataset_th = Pmodule.read_module env [ "caisar" ] "DataSet" in
Theory.ns_find_ls dataset_th.mod_theory.th_export [ p ]
let interpret_predicate env ls =
let correct_predicate = find_predicate_ls env "correct" in
let robust_predicate = find_predicate_ls env "robust" in
let cond_robust_predicate = find_predicate_ls env "cond_robust" in
if Term.ls_equal ls correct_predicate
then Correct
else if Term.ls_equal ls robust_predicate
then Robust
else if Term.ls_equal ls cond_robust_predicate
then CondRobust
else failwith (Fmt.str "Unsupported by SAVer: %a" Pretty.print_ls ls)
let interpret_task env task =
let goal = Task.task_goal_fmla task in
match goal.t_node with
| Term.Tapp
( ls,
[
{ t_node = Tapp (ls_svm_apply, _); _ };
_dataset;
{ t_node = Tconst e; _ };
] ) ->
let predicate_kind = interpret_predicate env ls in
let eps = Fmt.str "%a" Constant.print_def e in
let svm_filename =
match Language.lookup_loaded_svms ls_svm_apply with
| Some t -> Unix.realpath t.filename
| None -> invalid_arg "No SVM model found in task"
in
{ svm_filename; predicate_kind; eps }
| _ ->
(* No other term is supported. *)
failwith "Unsupported term by SAVer"
let svm = Re__Core.(compile (str "%{svm}"))
let dataset = Re__Core.(compile (str "%{dataset}"))
let epsilon = Re__Core.(compile (str "%{epsilon}"))
let abstraction = Re__Core.(compile (str "%{abstraction}"))
let distance = Re__Core.(compile (str "%{distance}"))
let build_command config_prover dataset_filename interpreted_task =
let command = Whyconf.get_complete_command ~with_steps:false config_prover in
let params =
[
(svm, interpreted_task.svm_filename);
(dataset, Unix.realpath dataset_filename);
(epsilon, interpreted_task.eps);
(distance, "l_inf");
(abstraction, "hybrid");
]
in
List.fold params ~init:command ~f:(fun cmd (param, by) ->
Re__Core.replace_string param ~by cmd)
type answer = {
prover_answer : Call_provers.prover_answer;
nb_total : int;
nb_proved : int;
}
(* SAVer output is made of 7 columns separated by (multiple) space(s) of the
following form:
[SUMMARY] integer float float integer integer integer
We are interested in recovering the integers. The following regexp matches
groups of one or more digits by means of '(\\d+)'. The latter is used 4 times
for the 4 columns of integers we are interested in. The 1st group reports the
number of total data points in the dataset, 2nd group reports the number of
correct data points, 3rd group reports the number of robust data points, 4th
group reports the number of conditionally robust (ie, correct and robust)
data points. *)
let re_saver_output =
Re__Pcre.regexp
"\\[SUMMARY\\]\\s*(\\d+)\\s*[0-9.]+\\s*[0-9.]+\\s*(\\d)+\\s*(\\d+)\\s*(\\d+)"
(* The dataset size is the first group of digits matched by
[re_saver_output]. *)
let re_group_number_dataset_size = 1
(* Regexp group number as matched by [re_saver_output] for each kind of
predicate. *)
let re_group_number_of_predicate_kind predicate_kind =
match predicate_kind with Correct -> 2 | Robust -> 3 | CondRobust -> 4
let build_answer pred_kind prover_result =
match prover_result.Call_provers.pr_answer with
| Call_provers.HighFailure -> (
let pr_output = prover_result.pr_output in
match Re__Core.exec_opt re_saver_output pr_output with
| Some re_group ->
let nb_total =
Int.of_string (Re__Core.Group.get re_group re_group_number_dataset_size)
in
let nb_proved =
let re_group_number = re_group_number_of_predicate_kind pred_kind in
Int.of_string (Re__Core.Group.get re_group re_group_number)
in
let prover_answer =
if nb_total = nb_proved
then Call_provers.Valid
else Call_provers.Invalid
in
{ prover_answer; nb_total; nb_proved }
| None -> failwith "Cannot interpret the output provided by SAVer")
| _ ->
(* Any other answer than HighFailure should never happen as we do not define
anything in SAVer's driver. *)
assert false
let call limit config env config_prover ~dataset task =
let interpreted_task = interpret_task env task in
let command = build_command config_prover dataset interpreted_task in
let prover_call =
let res_parser =
{
Call_provers.prp_regexps =
[ ("NeverMatch", Call_provers.Failure "Should not happen in CAISAR") ];
prp_timeregexps = [];
prp_stepregexps = [];
prp_exitcodes = [];
prp_model_parser = Model_parser.lookup_model_parser "no_model";
}
in
Call_provers.call_on_buffer ~libdir:(Whyconf.libdir config)
~datadir:(Whyconf.datadir config) ~command ~limit ~res_parser
~filename:" " ~get_counterexmp:false ~gen_new_file:false
~printing_info:Printer.default_printing_info (Buffer.create 10)
in
let prover_result = Call_provers.wait_on_call prover_call in
build_answer interpreted_task.predicate_kind prover_result
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(* Copyright (C) 2022 *)
(* CEA (Commissariat à l'énergie atomique et aux énergies *)
(* alternatives) *)
(* *)
(* You can redistribute it and/or modify it under the terms of the GNU *)
(* Lesser General Public License as published by the Free Software *)
(* Foundation, version 2.1. *)
(* *)
(* It is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
(* GNU Lesser General Public License for more details. *)
(* *)
(* See the GNU Lesser General Public License version 2.1 *)
(* for more details (enclosed in the file licenses/LGPLv2.1). *)
(* *)
(**************************************************************************)
open Why3
type answer = {
prover_answer : Call_provers.prover_answer;
nb_total : int; (* Total number of data points. *)
nb_proved : int; (* Number of data points verifying the property. *)
}
val call :
Call_provers.resource_limit ->
Whyconf.main ->
Env.env ->
Whyconf.config_prover ->
dataset:string ->
Task.task ->
answer
......@@ -61,11 +61,6 @@ let create_env ?(debug = false) loadpath =
config )
let nnet_or_onnx = Re__Core.(compile (str "%{nnet-onnx}"))
let svm = Re__Core.(compile (str "%{svm}"))
let dataset = Re__Core.(compile (str "%{dataset}"))
let epsilon = Re__Core.(compile (str "%{epsilon}"))
let abstraction = Re__Core.(compile (str "%{abstraction}"))
let distance = Re__Core.(compile (str "%{distance}"))
let combine_prover_answers answers =
List.fold_left answers ~init:Call_provers.Valid ~f:(fun acc r ->
......@@ -74,86 +69,15 @@ let combine_prover_answers answers =
| _ -> acc)
let answer_saver limit config env config_prover dataset_csv task =
let handle_task_saver task env dataset_csv command =
let dataset_filename =
match dataset_csv with
| None -> invalid_arg "No dataset provided for SAVer"
| Some s -> s
in
let robust_predicate =
let dataset_th = Pmodule.read_module env [ "caisar" ] "DataSet" in
Theory.ns_find_ls dataset_th.mod_theory.th_export [ "robust" ]
in
let goal = Task.task_goal_fmla task in
let eps, svm_filename =
match goal.t_node with
| Term.Tapp
( ls,
[
{ t_node = Tapp (ls_svm_apply, _); _ };
_dataset;
{ t_node = Tconst e; _ };
] ) ->
if Term.ls_equal ls robust_predicate
then
let eps = Fmt.str "%a" Constant.print_def e in
let svm_filename =
match Language.lookup_loaded_svms ls_svm_apply with
| Some t -> t.filename
| None -> invalid_arg "No SVM model found in task"
in
(eps, svm_filename)
else failwith "Wrong predicate found"
| _ ->
(* no other predicate than robust_to is supported *)
failwith "Unsupported predicate by SAVer"
in
let svm_file = Unix.realpath svm_filename in
let dataset_file = Unix.realpath dataset_filename in
let command = Re__Core.replace_string svm ~by:svm_file command in
let command = Re__Core.replace_string dataset ~by:dataset_file command in
let command = Re__Core.replace_string epsilon ~by:eps command in
let command = Re__Core.replace_string distance ~by:"l_inf" command in
let command = Re__Core.replace_string abstraction ~by:"hybrid" command in
command
in
let command = Whyconf.get_complete_command ~with_steps:false config_prover in
let command = handle_task_saver task env dataset_csv command in
let res_parser =
{
Call_provers.prp_regexps =
[ ("NeverMatch", Call_provers.Failure "Should not happen in CAISAR") ];
prp_timeregexps = [];
prp_stepregexps = [];
prp_exitcodes = [];
prp_model_parser = Model_parser.lookup_model_parser "no_model";
}
let dataset =
match dataset_csv with
| None -> invalid_arg "No dataset provided for SAVer"
| Some filename -> filename
in
let prover_call =
Call_provers.call_on_buffer ~libdir:(Whyconf.libdir config)
~datadir:(Whyconf.datadir config) ~command ~limit ~res_parser
~filename:" " ~get_counterexmp:false ~gen_new_file:false
~printing_info:Printer.default_printing_info (Buffer.create 10)
in
let prover_result = Call_provers.wait_on_call prover_call in
match prover_result.pr_answer with
| Call_provers.HighFailure -> (
let pr_output = prover_result.pr_output in
let matcher =
Re__Pcre.regexp
"\\[SUMMARY\\]\\s*(\\d+)\\s*[0-9.]+\\s*[0-9.]+\\s*\\d+\\s*(\\d+)\\s*\\d"
in
match Re__Core.exec_opt matcher pr_output with
| Some g ->
if Int.of_string (Re__Core.Group.get g 1)
= Int.of_string (Re__Core.Group.get g 2)
then Call_provers.Valid
else Call_provers.Invalid
| None -> Call_provers.HighFailure)
| _ ->
(* Any other answer than HighFailure should never happen as we do not define
anything in SAVer's driver. *)
assert false
let answer = Saver.call limit config env config_prover ~dataset task in
let prover_answer = answer.prover_answer in
let additional_info = Fmt.str "(%d/%d)" answer.nb_proved answer.nb_total in
(prover_answer, Some additional_info)
let answer_generic limit config prover config_prover driver task =
let task_prepared = Driver.prepare_task driver task in
......@@ -182,11 +106,11 @@ let answer_generic limit config prover config_prover driver task =
prover_result.pr_answer
in
let answers = List.map tasks ~f:call_prover_on_task in
let answer = combine_prover_answers answers in
answer
let prover_answer = combine_prover_answers answers in
(prover_answer, None)
let call_prover ~limit config env prover config_prover driver dataset_csv task =
let prover_answer =
let prover_answer, additional_info =
match prover with
| Prover.Saver ->
answer_saver limit config env config_prover dataset_csv task
......@@ -194,8 +118,10 @@ let call_prover ~limit config env prover config_prover driver dataset_csv task =
answer_generic limit config prover config_prover driver task
in
Logs.app (fun m ->
m "@[Goal %a:@ %a@]" Pretty.print_pr (Task.task_goal task)
Call_provers.print_prover_answer prover_answer)
m "@[Goal %a:@ %a%a@]" Pretty.print_pr (Task.task_goal task)
Call_provers.print_prover_answer prover_answer
Fmt.(option ~none:nop (any " " ++ string))
additional_info)
let verify ?(debug = false) format loadpath ?memlimit ?timeout prover
?dataset_csv file =
......
......@@ -57,15 +57,33 @@ theory DataSet
constant dataset: dataset
predicate linfty_distance (n: int) (a: features) (b: features) (eps: t) =
forall i: int. 0 <= i < n -> .- eps .< a[i] .- b[i] .< eps
predicate linfty_distance (a: features) (b: features) (eps: t) =
a.length = b.length /\
forall i: int. 0 <= i < a.length -> .- eps .< a[i] .- b[i] .< eps
predicate correct_at (m: model) (d: datum) (eps: t) =
forall x': features.
let (x, y) = d in
linfty_distance x x' eps ->
y = predict m x'
predicate robust_at (m: model) (d: datum) (eps: t) =
forall x': features.
let (x, _) = d in
linfty_distance x x' eps ->
predict m x = predict m x'
predicate cond_robust_at (m: model) (d: datum) (eps: t) =
correct_at m d eps /\ robust_at m d eps
predicate correct (m: model) (d: dataset) (eps: t) =
forall i: int. 0 <= i < d.data.length -> correct_at m d.data[i] eps
predicate robust (m: model) (d: dataset) (eps: t) =
forall i: int. 0 <= i < d.data.length ->
forall x': features.
let (x, _) = d.data[i] in
linfty_distance d.nb_features x x' eps ->
predict m x = predict m x'
forall i: int. 0 <= i < d.data.length -> robust_at m d.data[i] eps
predicate cond_robust (m: model) (d: dataset) (eps: t) =
correct m d eps /\ robust m d eps
end
theory SVM
......
......@@ -9,5 +9,5 @@ case $1 in
echo "SVM: $1"
echo "Goal:"
cat $2
echo "Unknown"
echo "[SUMMARY] 2 8 0.017000 1 2 0"
esac
......@@ -29,9 +29,15 @@ Test verify
> theory T
> use TestSVM.SVMasArray
> use ieee_float.Float64
> use int.Int
> use array.Array
> use caisar.DataSet
>
> goal G: robust svm_apply dataset (8.0:t)
> goal H: correct svm_apply dataset (8.0:t)
> goal I: cond_robust svm_apply dataset (8.0:t)
> end
> EOF
[caisar] Goal G: High failure
[caisar] Goal G: Valid (2/2)
[caisar] Goal H: Invalid (1/2)
[caisar] Goal I: Invalid (0/2)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment