From 9718c499e8bfe4e95c36f3c50d032a7fa1cb6091 Mon Sep 17 00:00:00 2001 From: Michele Alberti <michele.alberti@cea.fr> Date: Thu, 15 Sep 2022 17:21:06 +0200 Subject: [PATCH] [SAVer] Extend SAVer support for correct and conditional robust predicates. --- lib/ovo/ovo.ml | 69 +++++++++--------- src/saver.ml | 168 ++++++++++++++++++++++++++++++++++++++++++++ src/saver.mli | 38 ++++++++++ src/verification.ml | 104 ++++----------------------- stdlib/caisar.mlw | 32 +++++++-- tests/bin/saver | 2 +- tests/simple_ovo.t | 8 ++- 7 files changed, 287 insertions(+), 134 deletions(-) create mode 100644 src/saver.ml create mode 100644 src/saver.mli diff --git a/lib/ovo/ovo.ml b/lib/ovo/ovo.ml index fcc63082..f05331d3 100644 --- a/lib/ovo/ovo.ml +++ b/lib/ovo/ovo.ml @@ -40,51 +40,48 @@ let handle_ovo_line ~f in_channel = ~f:(fun s -> try Some (f (String.strip s)) with _ -> None) (Csv.next in_channel) -(* Skip the header part, ie comments, of the OVO format. *) -let skip_ovo_header filename in_channel = - let exception End_of_header in - let pos_in = ref (Stdlib.pos_in in_channel) in - try - while true do - let line = Stdlib.input_line in_channel in - if not (Str.string_match (Str.regexp "//") line 0) - then raise End_of_header - else pos_in := Stdlib.pos_in in_channel - done; - assert false - with - | End_of_header -> - (* At this point we have read one line past the header part: seek back. *) - Stdlib.seek_in in_channel !pos_in; - Ok () - | End_of_file -> - Error (Format.sprintf "OVO model not found in file `%s'." filename) +(* Handle ovo first line: either 'ovo' or 'ovo x y' with 'x' and 'y' positive + integer numbers. *) +let handle_ovo_first_line in_channel = + let ovo_format_error_on_first_line = ovo_format_error "first line" in + match Csv.next in_channel with + | [ first_line ] -> ( + let in_channel = Csv.of_string ~separator:' ' first_line in + match Csv.next in_channel with + | [ "ovo"; n_is; n_os ] -> + let n_is = Int.of_string (String.strip n_is) in + let n_os = Int.of_string (String.strip n_os) in + Ok (Some (n_is, n_os)) + | [ "ovo" ] -> Ok None + | _ -> ovo_format_error_on_first_line + | exception End_of_file -> ovo_format_error_on_first_line) + | _ -> ovo_format_error_on_first_line + | exception End_of_file -> ovo_format_error_on_first_line (* Retrieve inputs and outputs size. *) -let handle_ovo_basic_info in_channel = +let handle_ovo_basic_info ~descr in_channel = match handle_ovo_line ~f:Int.of_string in_channel with | [ dim ] -> Ok dim - | _ -> ovo_format_error "first" - | exception End_of_file -> ovo_format_error "first" - -(* Skip unused flag. *) -let handle_ovo_unused_flag in_channel = - try - let _ = Csv.next in_channel in - Ok () - with End_of_file -> ovo_format_error "second" + | _ -> ovo_format_error descr + | exception End_of_file -> ovo_format_error descr (* Retrieves [filename] OVO model metadata and weights wrt OVO format specification, which is described here: - https://github.com/abstract-machine-learning/saver#classifier-format. *) -let parse_in_channel filename in_channel = + https://github.com/abstract-machine-learning/saver#classifier-format. In + practice, the first line may specify the input/output size values or not. In + the latter case, we search for input/output size values in the second/third + lines respectively. *) +let parse_in_channel in_channel = let open Result in try - skip_ovo_header filename in_channel >>= fun () -> let in_channel = Csv.of_channel in_channel in - handle_ovo_unused_flag in_channel >>= fun () -> - handle_ovo_basic_info in_channel >>= fun n_is -> - handle_ovo_basic_info in_channel >>= fun n_os -> + (handle_ovo_first_line in_channel >>= function + | Some (n_is, n_os) -> Ok (n_is, n_os) + | None -> + handle_ovo_basic_info ~descr:"input size" in_channel >>= fun n_is -> + handle_ovo_basic_info ~descr:"output size" in_channel >>= fun n_os -> + Ok (n_is, n_os)) + >>= fun (n_is, n_os) -> Csv.close_in in_channel; Ok { n_inputs = n_is; n_outputs = n_os } with @@ -96,4 +93,4 @@ let parse filename = let in_channel = Stdlib.open_in filename in Fun.protect ~finally:(fun () -> Stdlib.close_in in_channel) - (fun () -> parse_in_channel filename in_channel) + (fun () -> parse_in_channel in_channel) diff --git a/src/saver.ml b/src/saver.ml new file mode 100644 index 00000000..2c563a46 --- /dev/null +++ b/src/saver.ml @@ -0,0 +1,168 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +open Why3 +open Base + +type predicate_kind = Correct | Robust | CondRobust + +type interpreted_task = { + svm_filename : string; + predicate_kind : predicate_kind; + eps : string; +} + +let find_predicate_ls env p = + let dataset_th = Pmodule.read_module env [ "caisar" ] "DataSet" in + Theory.ns_find_ls dataset_th.mod_theory.th_export [ p ] + +let interpret_predicate env ls = + let correct_predicate = find_predicate_ls env "correct" in + let robust_predicate = find_predicate_ls env "robust" in + let cond_robust_predicate = find_predicate_ls env "cond_robust" in + if Term.ls_equal ls correct_predicate + then Correct + else if Term.ls_equal ls robust_predicate + then Robust + else if Term.ls_equal ls cond_robust_predicate + then CondRobust + else failwith (Fmt.str "Unsupported by SAVer: %a" Pretty.print_ls ls) + +let interpret_task env task = + let goal = Task.task_goal_fmla task in + match goal.t_node with + | Term.Tapp + ( ls, + [ + { t_node = Tapp (ls_svm_apply, _); _ }; + _dataset; + { t_node = Tconst e; _ }; + ] ) -> + let predicate_kind = interpret_predicate env ls in + let eps = Fmt.str "%a" Constant.print_def e in + let svm_filename = + match Language.lookup_loaded_svms ls_svm_apply with + | Some t -> Unix.realpath t.filename + | None -> invalid_arg "No SVM model found in task" + in + { svm_filename; predicate_kind; eps } + | _ -> + (* No other term is supported. *) + failwith "Unsupported term by SAVer" + +let svm = Re__Core.(compile (str "%{svm}")) +let dataset = Re__Core.(compile (str "%{dataset}")) +let epsilon = Re__Core.(compile (str "%{epsilon}")) +let abstraction = Re__Core.(compile (str "%{abstraction}")) +let distance = Re__Core.(compile (str "%{distance}")) + +let build_command config_prover dataset_filename interpreted_task = + let command = Whyconf.get_complete_command ~with_steps:false config_prover in + let params = + [ + (svm, interpreted_task.svm_filename); + (dataset, Unix.realpath dataset_filename); + (epsilon, interpreted_task.eps); + (distance, "l_inf"); + (abstraction, "hybrid"); + ] + in + List.fold params ~init:command ~f:(fun cmd (param, by) -> + Re__Core.replace_string param ~by cmd) + +type answer = { + prover_answer : Call_provers.prover_answer; + nb_total : int; + nb_proved : int; +} + +(* SAVer output is made of 7 columns separated by (multiple) space(s) of the + following form: + + [SUMMARY] integer float float integer integer integer + + We are interested in recovering the integers. The following regexp matches + groups of one or more digits by means of '(\\d+)'. The latter is used 4 times + for the 4 columns of integers we are interested in. The 1st group reports the + number of total data points in the dataset, 2nd group reports the number of + correct data points, 3rd group reports the number of robust data points, 4th + group reports the number of conditionally robust (ie, correct and robust) + data points. *) +let re_saver_output = + Re__Pcre.regexp + "\\[SUMMARY\\]\\s*(\\d+)\\s*[0-9.]+\\s*[0-9.]+\\s*(\\d)+\\s*(\\d+)\\s*(\\d+)" + +(* The dataset size is the first group of digits matched by + [re_saver_output]. *) +let re_group_number_dataset_size = 1 + +(* Regexp group number as matched by [re_saver_output] for each kind of + predicate. *) +let re_group_number_of_predicate_kind predicate_kind = + match predicate_kind with Correct -> 2 | Robust -> 3 | CondRobust -> 4 + +let build_answer pred_kind prover_result = + match prover_result.Call_provers.pr_answer with + | Call_provers.HighFailure -> ( + let pr_output = prover_result.pr_output in + match Re__Core.exec_opt re_saver_output pr_output with + | Some re_group -> + let nb_total = + Int.of_string (Re__Core.Group.get re_group re_group_number_dataset_size) + in + let nb_proved = + let re_group_number = re_group_number_of_predicate_kind pred_kind in + Int.of_string (Re__Core.Group.get re_group re_group_number) + in + let prover_answer = + if nb_total = nb_proved + then Call_provers.Valid + else Call_provers.Invalid + in + { prover_answer; nb_total; nb_proved } + | None -> failwith "Cannot interpret the output provided by SAVer") + | _ -> + (* Any other answer than HighFailure should never happen as we do not define + anything in SAVer's driver. *) + assert false + +let call limit config env config_prover ~dataset task = + let interpreted_task = interpret_task env task in + let command = build_command config_prover dataset interpreted_task in + let prover_call = + let res_parser = + { + Call_provers.prp_regexps = + [ ("NeverMatch", Call_provers.Failure "Should not happen in CAISAR") ]; + prp_timeregexps = []; + prp_stepregexps = []; + prp_exitcodes = []; + prp_model_parser = Model_parser.lookup_model_parser "no_model"; + } + in + Call_provers.call_on_buffer ~libdir:(Whyconf.libdir config) + ~datadir:(Whyconf.datadir config) ~command ~limit ~res_parser + ~filename:" " ~get_counterexmp:false ~gen_new_file:false + ~printing_info:Printer.default_printing_info (Buffer.create 10) + in + let prover_result = Call_provers.wait_on_call prover_call in + build_answer interpreted_task.predicate_kind prover_result diff --git a/src/saver.mli b/src/saver.mli new file mode 100644 index 00000000..2bec2d23 --- /dev/null +++ b/src/saver.mli @@ -0,0 +1,38 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +open Why3 + +type answer = { + prover_answer : Call_provers.prover_answer; + nb_total : int; (* Total number of data points. *) + nb_proved : int; (* Number of data points verifying the property. *) +} + +val call : + Call_provers.resource_limit -> + Whyconf.main -> + Env.env -> + Whyconf.config_prover -> + dataset:string -> + Task.task -> + answer diff --git a/src/verification.ml b/src/verification.ml index c455308f..2d12c9dd 100644 --- a/src/verification.ml +++ b/src/verification.ml @@ -61,11 +61,6 @@ let create_env ?(debug = false) loadpath = config ) let nnet_or_onnx = Re__Core.(compile (str "%{nnet-onnx}")) -let svm = Re__Core.(compile (str "%{svm}")) -let dataset = Re__Core.(compile (str "%{dataset}")) -let epsilon = Re__Core.(compile (str "%{epsilon}")) -let abstraction = Re__Core.(compile (str "%{abstraction}")) -let distance = Re__Core.(compile (str "%{distance}")) let combine_prover_answers answers = List.fold_left answers ~init:Call_provers.Valid ~f:(fun acc r -> @@ -74,86 +69,15 @@ let combine_prover_answers answers = | _ -> acc) let answer_saver limit config env config_prover dataset_csv task = - let handle_task_saver task env dataset_csv command = - let dataset_filename = - match dataset_csv with - | None -> invalid_arg "No dataset provided for SAVer" - | Some s -> s - in - let robust_predicate = - let dataset_th = Pmodule.read_module env [ "caisar" ] "DataSet" in - Theory.ns_find_ls dataset_th.mod_theory.th_export [ "robust" ] - in - let goal = Task.task_goal_fmla task in - let eps, svm_filename = - match goal.t_node with - | Term.Tapp - ( ls, - [ - { t_node = Tapp (ls_svm_apply, _); _ }; - _dataset; - { t_node = Tconst e; _ }; - ] ) -> - if Term.ls_equal ls robust_predicate - then - let eps = Fmt.str "%a" Constant.print_def e in - let svm_filename = - match Language.lookup_loaded_svms ls_svm_apply with - | Some t -> t.filename - | None -> invalid_arg "No SVM model found in task" - in - (eps, svm_filename) - else failwith "Wrong predicate found" - | _ -> - (* no other predicate than robust_to is supported *) - failwith "Unsupported predicate by SAVer" - in - let svm_file = Unix.realpath svm_filename in - let dataset_file = Unix.realpath dataset_filename in - let command = Re__Core.replace_string svm ~by:svm_file command in - let command = Re__Core.replace_string dataset ~by:dataset_file command in - let command = Re__Core.replace_string epsilon ~by:eps command in - let command = Re__Core.replace_string distance ~by:"l_inf" command in - let command = Re__Core.replace_string abstraction ~by:"hybrid" command in - command - in - let command = Whyconf.get_complete_command ~with_steps:false config_prover in - let command = handle_task_saver task env dataset_csv command in - let res_parser = - { - Call_provers.prp_regexps = - [ ("NeverMatch", Call_provers.Failure "Should not happen in CAISAR") ]; - prp_timeregexps = []; - prp_stepregexps = []; - prp_exitcodes = []; - prp_model_parser = Model_parser.lookup_model_parser "no_model"; - } + let dataset = + match dataset_csv with + | None -> invalid_arg "No dataset provided for SAVer" + | Some filename -> filename in - let prover_call = - Call_provers.call_on_buffer ~libdir:(Whyconf.libdir config) - ~datadir:(Whyconf.datadir config) ~command ~limit ~res_parser - ~filename:" " ~get_counterexmp:false ~gen_new_file:false - ~printing_info:Printer.default_printing_info (Buffer.create 10) - in - let prover_result = Call_provers.wait_on_call prover_call in - match prover_result.pr_answer with - | Call_provers.HighFailure -> ( - let pr_output = prover_result.pr_output in - let matcher = - Re__Pcre.regexp - "\\[SUMMARY\\]\\s*(\\d+)\\s*[0-9.]+\\s*[0-9.]+\\s*\\d+\\s*(\\d+)\\s*\\d" - in - match Re__Core.exec_opt matcher pr_output with - | Some g -> - if Int.of_string (Re__Core.Group.get g 1) - = Int.of_string (Re__Core.Group.get g 2) - then Call_provers.Valid - else Call_provers.Invalid - | None -> Call_provers.HighFailure) - | _ -> - (* Any other answer than HighFailure should never happen as we do not define - anything in SAVer's driver. *) - assert false + let answer = Saver.call limit config env config_prover ~dataset task in + let prover_answer = answer.prover_answer in + let additional_info = Fmt.str "(%d/%d)" answer.nb_proved answer.nb_total in + (prover_answer, Some additional_info) let answer_generic limit config prover config_prover driver task = let task_prepared = Driver.prepare_task driver task in @@ -182,11 +106,11 @@ let answer_generic limit config prover config_prover driver task = prover_result.pr_answer in let answers = List.map tasks ~f:call_prover_on_task in - let answer = combine_prover_answers answers in - answer + let prover_answer = combine_prover_answers answers in + (prover_answer, None) let call_prover ~limit config env prover config_prover driver dataset_csv task = - let prover_answer = + let prover_answer, additional_info = match prover with | Prover.Saver -> answer_saver limit config env config_prover dataset_csv task @@ -194,8 +118,10 @@ let call_prover ~limit config env prover config_prover driver dataset_csv task = answer_generic limit config prover config_prover driver task in Logs.app (fun m -> - m "@[Goal %a:@ %a@]" Pretty.print_pr (Task.task_goal task) - Call_provers.print_prover_answer prover_answer) + m "@[Goal %a:@ %a%a@]" Pretty.print_pr (Task.task_goal task) + Call_provers.print_prover_answer prover_answer + Fmt.(option ~none:nop (any " " ++ string)) + additional_info) let verify ?(debug = false) format loadpath ?memlimit ?timeout prover ?dataset_csv file = diff --git a/stdlib/caisar.mlw b/stdlib/caisar.mlw index 675b1ab9..7d02870e 100644 --- a/stdlib/caisar.mlw +++ b/stdlib/caisar.mlw @@ -57,15 +57,33 @@ theory DataSet constant dataset: dataset - predicate linfty_distance (n: int) (a: features) (b: features) (eps: t) = - forall i: int. 0 <= i < n -> .- eps .< a[i] .- b[i] .< eps + predicate linfty_distance (a: features) (b: features) (eps: t) = + a.length = b.length /\ + forall i: int. 0 <= i < a.length -> .- eps .< a[i] .- b[i] .< eps + + predicate correct_at (m: model) (d: datum) (eps: t) = + forall x': features. + let (x, y) = d in + linfty_distance x x' eps -> + y = predict m x' + + predicate robust_at (m: model) (d: datum) (eps: t) = + forall x': features. + let (x, _) = d in + linfty_distance x x' eps -> + predict m x = predict m x' + + predicate cond_robust_at (m: model) (d: datum) (eps: t) = + correct_at m d eps /\ robust_at m d eps + + predicate correct (m: model) (d: dataset) (eps: t) = + forall i: int. 0 <= i < d.data.length -> correct_at m d.data[i] eps predicate robust (m: model) (d: dataset) (eps: t) = - forall i: int. 0 <= i < d.data.length -> - forall x': features. - let (x, _) = d.data[i] in - linfty_distance d.nb_features x x' eps -> - predict m x = predict m x' + forall i: int. 0 <= i < d.data.length -> robust_at m d.data[i] eps + + predicate cond_robust (m: model) (d: dataset) (eps: t) = + correct m d eps /\ robust m d eps end theory SVM diff --git a/tests/bin/saver b/tests/bin/saver index afa49d80..012a62a7 100644 --- a/tests/bin/saver +++ b/tests/bin/saver @@ -9,5 +9,5 @@ case $1 in echo "SVM: $1" echo "Goal:" cat $2 - echo "Unknown" + echo "[SUMMARY] 2 8 0.017000 1 2 0" esac diff --git a/tests/simple_ovo.t b/tests/simple_ovo.t index 1e02da73..cd851827 100644 --- a/tests/simple_ovo.t +++ b/tests/simple_ovo.t @@ -29,9 +29,15 @@ Test verify > theory T > use TestSVM.SVMasArray > use ieee_float.Float64 + > use int.Int + > use array.Array > use caisar.DataSet > > goal G: robust svm_apply dataset (8.0:t) + > goal H: correct svm_apply dataset (8.0:t) + > goal I: cond_robust svm_apply dataset (8.0:t) > end > EOF - [caisar] Goal G: High failure + [caisar] Goal G: Valid (2/2) + [caisar] Goal H: Invalid (1/2) + [caisar] Goal I: Invalid (0/2) -- GitLab