Skip to content
Snippets Groups Projects
Commit 43d11f5a authored by Michele Alberti's avatar Michele Alberti
Browse files

Merge branch 'feature/michele/extend-saver-support' into 'master'

Extend SAVer support

See merge request laiser/caisar!37
parents b171a3f9 93097cb8
No related branches found
No related tags found
No related merge requests found
......@@ -40,51 +40,48 @@ let handle_ovo_line ~f in_channel =
~f:(fun s -> try Some (f (String.strip s)) with _ -> None)
(Csv.next in_channel)
(* Skip the header part, ie comments, of the OVO format. *)
let skip_ovo_header filename in_channel =
let exception End_of_header in
let pos_in = ref (Stdlib.pos_in in_channel) in
try
while true do
let line = Stdlib.input_line in_channel in
if not (Str.string_match (Str.regexp "//") line 0)
then raise End_of_header
else pos_in := Stdlib.pos_in in_channel
done;
assert false
with
| End_of_header ->
(* At this point we have read one line past the header part: seek back. *)
Stdlib.seek_in in_channel !pos_in;
Ok ()
| End_of_file ->
Error (Format.sprintf "OVO model not found in file `%s'." filename)
(* Handle ovo first line: either 'ovo' or 'ovo x y' with 'x' and 'y' positive
integer numbers. *)
let handle_ovo_first_line in_channel =
let ovo_format_error_on_first_line = ovo_format_error "first line" in
match Csv.next in_channel with
| [ first_line ] -> (
let in_channel = Csv.of_string ~separator:' ' first_line in
match Csv.next in_channel with
| [ "ovo"; n_is; n_os ] ->
let n_is = Int.of_string (String.strip n_is) in
let n_os = Int.of_string (String.strip n_os) in
Ok (Some (n_is, n_os))
| [ "ovo" ] -> Ok None
| _ -> ovo_format_error_on_first_line
| exception End_of_file -> ovo_format_error_on_first_line)
| _ -> ovo_format_error_on_first_line
| exception End_of_file -> ovo_format_error_on_first_line
(* Retrieve inputs and outputs size. *)
let handle_ovo_basic_info in_channel =
let handle_ovo_basic_info ~descr in_channel =
match handle_ovo_line ~f:Int.of_string in_channel with
| [ dim ] -> Ok dim
| _ -> ovo_format_error "first"
| exception End_of_file -> ovo_format_error "first"
(* Skip unused flag. *)
let handle_ovo_unused_flag in_channel =
try
let _ = Csv.next in_channel in
Ok ()
with End_of_file -> ovo_format_error "second"
| _ -> ovo_format_error descr
| exception End_of_file -> ovo_format_error descr
(* Retrieves [filename] OVO model metadata and weights wrt OVO format
specification, which is described here:
https://github.com/abstract-machine-learning/saver#classifier-format. *)
let parse_in_channel filename in_channel =
https://github.com/abstract-machine-learning/saver#classifier-format. In
practice, the first line may specify the input/output size values or not. In
the latter case, we search for input/output size values in the second/third
lines respectively. *)
let parse_in_channel in_channel =
let open Result in
try
skip_ovo_header filename in_channel >>= fun () ->
let in_channel = Csv.of_channel in_channel in
handle_ovo_unused_flag in_channel >>= fun () ->
handle_ovo_basic_info in_channel >>= fun n_is ->
handle_ovo_basic_info in_channel >>= fun n_os ->
(handle_ovo_first_line in_channel >>= function
| Some (n_is, n_os) -> Ok (n_is, n_os)
| None ->
handle_ovo_basic_info ~descr:"input size" in_channel >>= fun n_is ->
handle_ovo_basic_info ~descr:"output size" in_channel >>= fun n_os ->
Ok (n_is, n_os))
>>= fun (n_is, n_os) ->
Csv.close_in in_channel;
Ok { n_inputs = n_is; n_outputs = n_os }
with
......@@ -96,4 +93,4 @@ let parse filename =
let in_channel = Stdlib.open_in filename in
Fun.protect
~finally:(fun () -> Stdlib.close_in in_channel)
(fun () -> parse_in_channel filename in_channel)
(fun () -> parse_in_channel in_channel)
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(* Copyright (C) 2022 *)
(* CEA (Commissariat à l'énergie atomique et aux énergies *)
(* alternatives) *)
(* *)
(* You can redistribute it and/or modify it under the terms of the GNU *)
(* Lesser General Public License as published by the Free Software *)
(* Foundation, version 2.1. *)
(* *)
(* It is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
(* GNU Lesser General Public License for more details. *)
(* *)
(* See the GNU Lesser General Public License version 2.1 *)
(* for more details (enclosed in the file licenses/LGPLv2.1). *)
(* *)
(**************************************************************************)
open Why3
open Base
let svm = Re__Core.(compile (str "%{svm}"))
let dataset = Re__Core.(compile (str "%{dataset}"))
let epsilon = Re__Core.(compile (str "%{epsilon}"))
let abstraction = Re__Core.(compile (str "%{abstraction}"))
let distance = Re__Core.(compile (str "%{distance}"))
let build_command config_prover svm_filename dataset_filename eps =
let command = Whyconf.get_complete_command ~with_steps:false config_prover in
let params =
[
(svm, Unix.realpath svm_filename);
(dataset, Unix.realpath dataset_filename);
(epsilon, Option.(value (map ~f:Dataset.string_of_eps eps)) ~default:"0");
(distance, "l_inf");
(abstraction, "hybrid");
]
in
List.fold params ~init:command ~f:(fun cmd (param, by) ->
Re__Core.replace_string param ~by cmd)
type answer = {
prover_answer : Call_provers.prover_answer;
nb_total : int;
nb_proved : int;
}
(* SAVer output is made of 7 columns separated by (multiple) space(s) of the
following form:
[SUMMARY] integer float float integer integer integer
We are interested in recovering the integers. The following regexp matches
groups of one or more digits by means of '(\\d+)'. The latter is used 4 times
for the 4 columns of integers we are interested in. The 1st group reports the
number of total data points in the dataset, 2nd group reports the number of
correct data points, 3rd group reports the number of robust data points, 4th
group reports the number of conditionally robust (ie, correct and robust)
data points. *)
let re_saver_output =
Re__Pcre.regexp
"\\[SUMMARY\\]\\s*(\\d+)\\s*[0-9.]+\\s*[0-9.]+\\s*(\\d)+\\s*(\\d+)\\s*(\\d+)"
(* The dataset size is the first group of digits matched by
[re_saver_output]. *)
let re_group_number_dataset_size = 1
(* Regexp group number as matched by [re_saver_output] for each predicate. *)
let re_group_number_of_predicate = function
| Dataset.Correct -> 2
| Robust _ -> 3
| CondRobust _ -> 4
let negative_prover_answer_of_predicate = function
| Dataset.Correct -> Call_provers.Invalid
| Robust _ | CondRobust _ -> Call_provers.Unknown ""
let build_answer predicate prover_result =
match prover_result.Call_provers.pr_answer with
| Call_provers.HighFailure -> (
let pr_output = prover_result.pr_output in
match Re__Core.exec_opt re_saver_output pr_output with
| Some re_group ->
let nb_total =
Int.of_string (Re__Core.Group.get re_group re_group_number_dataset_size)
in
let nb_proved =
let re_group_number = re_group_number_of_predicate predicate in
Int.of_string (Re__Core.Group.get re_group re_group_number)
in
let prover_answer =
if nb_total = nb_proved
then Call_provers.Valid
else negative_prover_answer_of_predicate predicate
in
{ prover_answer; nb_total; nb_proved }
| None -> failwith "Cannot interpret the output provided by SAVer")
| _ ->
(* Any other answer than HighFailure should never happen as we do not define
anything in SAVer's driver. *)
assert false
let call limit config env config_prover ~dataset task =
let dataset_task = Dataset.interpret env Language.lookup_loaded_svms task in
let svm = dataset_task.model.filename in
let eps =
match dataset_task.predicate with
| Dataset.Correct -> None
| Robust e | CondRobust e -> Some e
in
let command = build_command config_prover svm dataset eps in
let prover_call =
let res_parser =
{
Call_provers.prp_regexps =
[ ("NeverMatch", Call_provers.Failure "Should not happen in CAISAR") ];
prp_timeregexps = [];
prp_stepregexps = [];
prp_exitcodes = [];
prp_model_parser = Model_parser.lookup_model_parser "no_model";
}
in
Call_provers.call_on_buffer ~libdir:(Whyconf.libdir config)
~datadir:(Whyconf.datadir config) ~command ~limit ~res_parser
~filename:" " ~get_counterexmp:false ~gen_new_file:false
~printing_info:Printer.default_printing_info (Buffer.create 10)
in
let prover_result = Call_provers.wait_on_call prover_call in
build_answer dataset_task.predicate prover_result
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(* Copyright (C) 2022 *)
(* CEA (Commissariat à l'énergie atomique et aux énergies *)
(* alternatives) *)
(* *)
(* You can redistribute it and/or modify it under the terms of the GNU *)
(* Lesser General Public License as published by the Free Software *)
(* Foundation, version 2.1. *)
(* *)
(* It is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
(* GNU Lesser General Public License for more details. *)
(* *)
(* See the GNU Lesser General Public License version 2.1 *)
(* for more details (enclosed in the file licenses/LGPLv2.1). *)
(* *)
(**************************************************************************)
open Why3
type answer = {
prover_answer : Call_provers.prover_answer;
nb_total : int; (** Total number of data points. *)
nb_proved : int; (** Number of data points verifying the property. *)
}
val call :
Call_provers.resource_limit ->
Whyconf.main ->
Env.env ->
Whyconf.config_prover ->
dataset:string ->
Task.task ->
answer
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(* Copyright (C) 2022 *)
(* CEA (Commissariat à l'énergie atomique et aux énergies *)
(* alternatives) *)
(* *)
(* You can redistribute it and/or modify it under the terms of the GNU *)
(* Lesser General Public License as published by the Free Software *)
(* Foundation, version 2.1. *)
(* *)
(* It is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
(* GNU Lesser General Public License for more details. *)
(* *)
(* See the GNU Lesser General Public License version 2.1 *)
(* for more details (enclosed in the file licenses/LGPLv2.1). *)
(* *)
(**************************************************************************)
open Why3
open Base
type eps = Constant.constant
type predicate = Correct | Robust of eps | CondRobust of eps
type 'a task = { model : 'a; predicate : predicate }
let string_of_eps eps = Fmt.str "%a" Constant.print_def eps
let find_predicate_ls env p =
let dataset_th =
Pmodule.read_module env [ "caisar" ] "DataSetClassification"
in
Theory.ns_find_ls dataset_th.mod_theory.th_export [ p ]
let failwith_unsupported_term t =
failwith (Fmt.str "Unsupported term in %a" Pretty.print_term t)
let failwith_unsupported_ls ls =
failwith (Fmt.str "Unsupported logic symbol %a" Pretty.print_ls ls)
let interpret env lookup task =
let term = Task.task_goal_fmla task in
match term.t_node with
| Term.Tapp (ls, { t_node = Tapp (ls_svm_apply, _); _ } :: _dataset :: tt) ->
let predicate =
match tt with
| [] ->
let correct_predicate = find_predicate_ls env "correct" in
if Term.ls_equal ls correct_predicate
then Correct
else failwith_unsupported_ls ls
| [ { t_node = Tconst e; _ } ] ->
let robust_predicate = find_predicate_ls env "robust" in
let cond_robust_predicate = find_predicate_ls env "cond_robust" in
if Term.ls_equal ls robust_predicate
then Robust e
else if Term.ls_equal ls cond_robust_predicate
then CondRobust e
else failwith_unsupported_ls ls
| _ -> failwith_unsupported_term term
in
let model =
match lookup ls_svm_apply with
| Some t -> t
| None -> invalid_arg "No model found in task"
in
{ model; predicate }
| _ ->
(* No other term node is supported. *)
failwith_unsupported_term term
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(* Copyright (C) 2022 *)
(* CEA (Commissariat à l'énergie atomique et aux énergies *)
(* alternatives) *)
(* *)
(* You can redistribute it and/or modify it under the terms of the GNU *)
(* Lesser General Public License as published by the Free Software *)
(* Foundation, version 2.1. *)
(* *)
(* It is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
(* GNU Lesser General Public License for more details. *)
(* *)
(* See the GNU Lesser General Public License version 2.1 *)
(* for more details (enclosed in the file licenses/LGPLv2.1). *)
(* *)
(**************************************************************************)
open Why3
type eps
type predicate = private Correct | Robust of eps | CondRobust of eps
type 'a task = private { model : 'a; predicate : predicate }
val string_of_eps : eps -> string
val interpret : Env.env -> (Term.lsymbol -> 'a option) -> Task.task -> 'a task
......@@ -61,11 +61,6 @@ let create_env ?(debug = false) loadpath =
config )
let nnet_or_onnx = Re__Core.(compile (str "%{nnet-onnx}"))
let svm = Re__Core.(compile (str "%{svm}"))
let dataset = Re__Core.(compile (str "%{dataset}"))
let epsilon = Re__Core.(compile (str "%{epsilon}"))
let abstraction = Re__Core.(compile (str "%{abstraction}"))
let distance = Re__Core.(compile (str "%{distance}"))
let combine_prover_answers answers =
List.fold_left answers ~init:Call_provers.Valid ~f:(fun acc r ->
......@@ -74,94 +69,15 @@ let combine_prover_answers answers =
| _ -> acc)
let answer_saver limit config env config_prover dataset_csv task =
let handle_task_saver task env dataset_csv command =
let dataset_filename =
match dataset_csv with
| None -> invalid_arg "No dataset provided for SAVer"
| Some s -> s
in
let goal = Task.task_goal_fmla task in
let eps, svm_filename =
match goal.t_node with
| Tquant (Tforall, b) -> (
let _, _, pred = Term.t_open_quant b in
let svm_t = Pmodule.read_module env [ "caisar" ] "SVM" in
let robust_to_predicate =
let open Theory in
ns_find_ls svm_t.mod_theory.th_export [ "robust_to" ]
in
match pred.t_node with
| Term.Tapp
( ls,
[
{ t_node = Tapp (svm_app_sym, _); _ };
_;
{ t_node = Tconst e; _ };
] ) ->
if Term.ls_equal ls robust_to_predicate
then
let eps = Fmt.str "%a" Constant.print_def e in
let svm_filename =
match Language.lookup_loaded_svms svm_app_sym with
| Some t -> t.filename
| None -> invalid_arg "No SVM model found in task"
in
(eps, svm_filename)
else failwith "Wrong predicate found"
| _ ->
(* no other predicate than robust_to is supported *)
failwith "Unsupported predicate by SAVer")
| _ -> failwith "Unsupported predicate by SAVer"
in
let svm_file = Unix.realpath svm_filename in
let dataset_file = Unix.realpath dataset_filename in
let command = Re__Core.replace_string svm ~by:svm_file command in
let command = Re__Core.replace_string dataset ~by:dataset_file command in
let command = Re__Core.replace_string epsilon ~by:eps command in
let command = Re__Core.replace_string distance ~by:"l_inf" command in
let command = Re__Core.replace_string abstraction ~by:"hybrid" command in
command
in
let command = Whyconf.get_complete_command ~with_steps:false config_prover in
let command = handle_task_saver task env dataset_csv command in
let res_parser =
{
Call_provers.prp_regexps =
[ ("NeverMatch", Call_provers.Failure "Should not happen in CAISAR") ];
prp_timeregexps = [];
prp_stepregexps = [];
prp_exitcodes = [];
prp_model_parser = Model_parser.lookup_model_parser "no_model";
}
in
let prover_call =
Call_provers.call_on_buffer ~libdir:(Whyconf.libdir config)
~datadir:(Whyconf.datadir config) ~command ~limit ~res_parser
~filename:" " ~get_counterexmp:false ~gen_new_file:false
~printing_info:Printer.default_printing_info (Buffer.create 10)
in
let prover_result = Call_provers.wait_on_call prover_call in
let answer =
match prover_result.pr_answer with
| Call_provers.HighFailure -> (
let pr_output = prover_result.pr_output in
let matcher =
Re__Pcre.regexp
"\\[SUMMARY\\]\\s*(\\d+)\\s*[0-9.]+\\s*[0-9.]+\\s*\\d+\\s*(\\d+)\\s*\\d"
in
match Re__Core.exec_opt matcher pr_output with
| Some g ->
if Int.of_string (Re__Core.Group.get g 1)
= Int.of_string (Re__Core.Group.get g 2)
then Call_provers.Valid
else Call_provers.Invalid
| None -> Call_provers.HighFailure)
| _ ->
(* Any other answer than HighFailure should never happen as we do not
define anything in SAVer's driver. *)
assert false
let dataset =
match dataset_csv with
| None -> invalid_arg "No dataset provided for SAVer"
| Some filename -> filename
in
answer
let answer = SAVer.call limit config env config_prover ~dataset task in
let prover_answer = answer.prover_answer in
let additional_info = Fmt.str "(%d/%d)" answer.nb_proved answer.nb_total in
(prover_answer, Some additional_info)
let answer_generic limit config prover config_prover driver task =
let task_prepared = Driver.prepare_task driver task in
......@@ -190,11 +106,11 @@ let answer_generic limit config prover config_prover driver task =
prover_result.pr_answer
in
let answers = List.map tasks ~f:call_prover_on_task in
let answer = combine_prover_answers answers in
answer
let prover_answer = combine_prover_answers answers in
(prover_answer, None)
let call_prover ~limit config env prover config_prover driver dataset_csv task =
let prover_answer =
let prover_answer, additional_info =
match prover with
| Prover.Saver ->
answer_saver limit config env config_prover dataset_csv task
......@@ -202,8 +118,10 @@ let call_prover ~limit config env prover config_prover driver dataset_csv task =
answer_generic limit config prover config_prover driver task
in
Logs.app (fun m ->
m "@[Goal %a:@ %a@]" Pretty.print_pr (Task.task_goal task)
Call_provers.print_prover_answer prover_answer)
m "@[Goal %a:@ %a%a@]" Pretty.print_pr (Task.task_goal task)
Call_provers.print_prover_answer prover_answer
Fmt.(option ~none:nop (any " " ++ string))
additional_info)
let verify ?(debug = false) format loadpath ?memlimit ?timeout prover
?dataset_csv file =
......
......@@ -25,22 +25,66 @@ theory NN
type input_type = t
end
theory SVM
theory Model
use ieee_float.Float64
use int.Int
use array.Array
type model = {
nb_inputs: int;
nb_outputs: int;
}
function predict: model -> array t -> int
end
theory DataSetClassification
use ieee_float.Float64
use int.Int
use array.Array
use Model
type features = array t
type class = int
type input_type = int -> t
type output_type = int
type datum = (features, class)
type svm = {
apply : input_type -> output_type;
nb_inputs : int;
nb_classes : int;
type dataset = {
nb_features: int;
nb_classes: int;
data: array datum
}
predicate linfty_distance (a: input_type) (b: input_type) (eps:t) (n: int) =
forall i. 0 <= i < n -> .- eps .< a i .- b i .< eps
constant dataset: dataset
predicate robust_to (svm: svm) (a: input_type) (eps: t) =
forall b. linfty_distance a b eps svm.nb_inputs -> svm.apply a = svm.apply b
predicate linfty_distance (a: features) (b: features) (eps: t) =
a.length = b.length /\
forall i: int. 0 <= i < a.length -> .- eps .< a[i] .- b[i] .< eps
predicate correct_at (m: model) (d: datum) =
let (x, y) = d in
y = predict m x
predicate robust_at (m: model) (d: datum) (eps: t) =
forall x': features.
let (x, _) = d in
linfty_distance x x' eps ->
predict m x = predict m x'
predicate cond_robust_at (m: model) (d: datum) (eps: t) =
correct_at m d /\ robust_at m d eps
predicate correct (m: model) (d: dataset) =
forall i: int. 0 <= i < d.data.length -> correct_at m d.data[i]
predicate robust (m: model) (d: dataset) (eps: t) =
forall i: int. 0 <= i < d.data.length -> robust_at m d.data[i] eps
predicate cond_robust (m: model) (d: dataset) (eps: t) =
correct m d /\ robust m d eps
end
theory SVM
use Model
type svm = model
end
......@@ -9,5 +9,5 @@ case $1 in
echo "SVM: $1"
echo "Goal:"
cat $2
echo "Unknown"
echo "[SUMMARY] 2 8 0.017000 1 2 0"
esac
......@@ -29,9 +29,14 @@ Test verify
> theory T
> use TestSVM.SVMasArray
> use ieee_float.Float64
> use caisar.SVM
> use int.Int
> use caisar.DataSetClassification
>
> goal G: forall a : input_type. robust_to svm_apply a (8.0:t)
> goal G: robust svm_apply dataset (8.0:t)
> goal H: correct svm_apply dataset
> goal I: cond_robust svm_apply dataset (8.0:t)
> end
> EOF
[caisar] Goal G: High failure
[caisar] Goal G: Valid (2/2)
[caisar] Goal H: Invalid (1/2)
[caisar] Goal I: Unknown () (0/2)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment