diff --git a/lib/ovo/ovo.ml b/lib/ovo/ovo.ml index fcc630820d6a213ef9895b7146f7571d9dbf4678..f05331d3c1ff8ea78ae6466bc75741c892b30371 100644 --- a/lib/ovo/ovo.ml +++ b/lib/ovo/ovo.ml @@ -40,51 +40,48 @@ let handle_ovo_line ~f in_channel = ~f:(fun s -> try Some (f (String.strip s)) with _ -> None) (Csv.next in_channel) -(* Skip the header part, ie comments, of the OVO format. *) -let skip_ovo_header filename in_channel = - let exception End_of_header in - let pos_in = ref (Stdlib.pos_in in_channel) in - try - while true do - let line = Stdlib.input_line in_channel in - if not (Str.string_match (Str.regexp "//") line 0) - then raise End_of_header - else pos_in := Stdlib.pos_in in_channel - done; - assert false - with - | End_of_header -> - (* At this point we have read one line past the header part: seek back. *) - Stdlib.seek_in in_channel !pos_in; - Ok () - | End_of_file -> - Error (Format.sprintf "OVO model not found in file `%s'." filename) +(* Handle ovo first line: either 'ovo' or 'ovo x y' with 'x' and 'y' positive + integer numbers. *) +let handle_ovo_first_line in_channel = + let ovo_format_error_on_first_line = ovo_format_error "first line" in + match Csv.next in_channel with + | [ first_line ] -> ( + let in_channel = Csv.of_string ~separator:' ' first_line in + match Csv.next in_channel with + | [ "ovo"; n_is; n_os ] -> + let n_is = Int.of_string (String.strip n_is) in + let n_os = Int.of_string (String.strip n_os) in + Ok (Some (n_is, n_os)) + | [ "ovo" ] -> Ok None + | _ -> ovo_format_error_on_first_line + | exception End_of_file -> ovo_format_error_on_first_line) + | _ -> ovo_format_error_on_first_line + | exception End_of_file -> ovo_format_error_on_first_line (* Retrieve inputs and outputs size. *) -let handle_ovo_basic_info in_channel = +let handle_ovo_basic_info ~descr in_channel = match handle_ovo_line ~f:Int.of_string in_channel with | [ dim ] -> Ok dim - | _ -> ovo_format_error "first" - | exception End_of_file -> ovo_format_error "first" - -(* Skip unused flag. *) -let handle_ovo_unused_flag in_channel = - try - let _ = Csv.next in_channel in - Ok () - with End_of_file -> ovo_format_error "second" + | _ -> ovo_format_error descr + | exception End_of_file -> ovo_format_error descr (* Retrieves [filename] OVO model metadata and weights wrt OVO format specification, which is described here: - https://github.com/abstract-machine-learning/saver#classifier-format. *) -let parse_in_channel filename in_channel = + https://github.com/abstract-machine-learning/saver#classifier-format. In + practice, the first line may specify the input/output size values or not. In + the latter case, we search for input/output size values in the second/third + lines respectively. *) +let parse_in_channel in_channel = let open Result in try - skip_ovo_header filename in_channel >>= fun () -> let in_channel = Csv.of_channel in_channel in - handle_ovo_unused_flag in_channel >>= fun () -> - handle_ovo_basic_info in_channel >>= fun n_is -> - handle_ovo_basic_info in_channel >>= fun n_os -> + (handle_ovo_first_line in_channel >>= function + | Some (n_is, n_os) -> Ok (n_is, n_os) + | None -> + handle_ovo_basic_info ~descr:"input size" in_channel >>= fun n_is -> + handle_ovo_basic_info ~descr:"output size" in_channel >>= fun n_os -> + Ok (n_is, n_os)) + >>= fun (n_is, n_os) -> Csv.close_in in_channel; Ok { n_inputs = n_is; n_outputs = n_os } with @@ -96,4 +93,4 @@ let parse filename = let in_channel = Stdlib.open_in filename in Fun.protect ~finally:(fun () -> Stdlib.close_in in_channel) - (fun () -> parse_in_channel filename in_channel) + (fun () -> parse_in_channel in_channel) diff --git a/src/SAVer.ml b/src/SAVer.ml new file mode 100644 index 0000000000000000000000000000000000000000..64958a1bbb925d55da085aca7279d917059f47c9 --- /dev/null +++ b/src/SAVer.ml @@ -0,0 +1,133 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +open Why3 +open Base + +let svm = Re__Core.(compile (str "%{svm}")) +let dataset = Re__Core.(compile (str "%{dataset}")) +let epsilon = Re__Core.(compile (str "%{epsilon}")) +let abstraction = Re__Core.(compile (str "%{abstraction}")) +let distance = Re__Core.(compile (str "%{distance}")) + +let build_command config_prover svm_filename dataset_filename eps = + let command = Whyconf.get_complete_command ~with_steps:false config_prover in + let params = + [ + (svm, Unix.realpath svm_filename); + (dataset, Unix.realpath dataset_filename); + (epsilon, Option.(value (map ~f:Dataset.string_of_eps eps)) ~default:"0"); + (distance, "l_inf"); + (abstraction, "hybrid"); + ] + in + List.fold params ~init:command ~f:(fun cmd (param, by) -> + Re__Core.replace_string param ~by cmd) + +type answer = { + prover_answer : Call_provers.prover_answer; + nb_total : int; + nb_proved : int; +} + +(* SAVer output is made of 7 columns separated by (multiple) space(s) of the + following form: + + [SUMMARY] integer float float integer integer integer + + We are interested in recovering the integers. The following regexp matches + groups of one or more digits by means of '(\\d+)'. The latter is used 4 times + for the 4 columns of integers we are interested in. The 1st group reports the + number of total data points in the dataset, 2nd group reports the number of + correct data points, 3rd group reports the number of robust data points, 4th + group reports the number of conditionally robust (ie, correct and robust) + data points. *) +let re_saver_output = + Re__Pcre.regexp + "\\[SUMMARY\\]\\s*(\\d+)\\s*[0-9.]+\\s*[0-9.]+\\s*(\\d)+\\s*(\\d+)\\s*(\\d+)" + +(* The dataset size is the first group of digits matched by + [re_saver_output]. *) +let re_group_number_dataset_size = 1 + +(* Regexp group number as matched by [re_saver_output] for each predicate. *) +let re_group_number_of_predicate = function + | Dataset.Correct -> 2 + | Robust _ -> 3 + | CondRobust _ -> 4 + +let negative_prover_answer_of_predicate = function + | Dataset.Correct -> Call_provers.Invalid + | Robust _ | CondRobust _ -> Call_provers.Unknown "" + +let build_answer predicate prover_result = + match prover_result.Call_provers.pr_answer with + | Call_provers.HighFailure -> ( + let pr_output = prover_result.pr_output in + match Re__Core.exec_opt re_saver_output pr_output with + | Some re_group -> + let nb_total = + Int.of_string (Re__Core.Group.get re_group re_group_number_dataset_size) + in + let nb_proved = + let re_group_number = re_group_number_of_predicate predicate in + Int.of_string (Re__Core.Group.get re_group re_group_number) + in + let prover_answer = + if nb_total = nb_proved + then Call_provers.Valid + else negative_prover_answer_of_predicate predicate + in + { prover_answer; nb_total; nb_proved } + | None -> failwith "Cannot interpret the output provided by SAVer") + | _ -> + (* Any other answer than HighFailure should never happen as we do not define + anything in SAVer's driver. *) + assert false + +let call limit config env config_prover ~dataset task = + let dataset_task = Dataset.interpret env Language.lookup_loaded_svms task in + let svm = dataset_task.model.filename in + let eps = + match dataset_task.predicate with + | Dataset.Correct -> None + | Robust e | CondRobust e -> Some e + in + let command = build_command config_prover svm dataset eps in + let prover_call = + let res_parser = + { + Call_provers.prp_regexps = + [ ("NeverMatch", Call_provers.Failure "Should not happen in CAISAR") ]; + prp_timeregexps = []; + prp_stepregexps = []; + prp_exitcodes = []; + prp_model_parser = Model_parser.lookup_model_parser "no_model"; + } + in + Call_provers.call_on_buffer ~libdir:(Whyconf.libdir config) + ~datadir:(Whyconf.datadir config) ~command ~limit ~res_parser + ~filename:" " ~get_counterexmp:false ~gen_new_file:false + ~printing_info:Printer.default_printing_info (Buffer.create 10) + in + let prover_result = Call_provers.wait_on_call prover_call in + build_answer dataset_task.predicate prover_result diff --git a/src/SAVer.mli b/src/SAVer.mli new file mode 100644 index 0000000000000000000000000000000000000000..2119b52d22cdacc5703155d261141de6b073267c --- /dev/null +++ b/src/SAVer.mli @@ -0,0 +1,38 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +open Why3 + +type answer = { + prover_answer : Call_provers.prover_answer; + nb_total : int; (** Total number of data points. *) + nb_proved : int; (** Number of data points verifying the property. *) +} + +val call : + Call_provers.resource_limit -> + Whyconf.main -> + Env.env -> + Whyconf.config_prover -> + dataset:string -> + Task.task -> + answer diff --git a/src/dataset.ml b/src/dataset.ml new file mode 100644 index 0000000000000000000000000000000000000000..f0bd3383501d9dcc5dfdc94135c2b172635c525e --- /dev/null +++ b/src/dataset.ml @@ -0,0 +1,73 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +open Why3 +open Base + +type eps = Constant.constant +type predicate = Correct | Robust of eps | CondRobust of eps +type 'a task = { model : 'a; predicate : predicate } + +let string_of_eps eps = Fmt.str "%a" Constant.print_def eps + +let find_predicate_ls env p = + let dataset_th = + Pmodule.read_module env [ "caisar" ] "DataSetClassification" + in + Theory.ns_find_ls dataset_th.mod_theory.th_export [ p ] + +let failwith_unsupported_term t = + failwith (Fmt.str "Unsupported term in %a" Pretty.print_term t) + +let failwith_unsupported_ls ls = + failwith (Fmt.str "Unsupported logic symbol %a" Pretty.print_ls ls) + +let interpret env lookup task = + let term = Task.task_goal_fmla task in + match term.t_node with + | Term.Tapp (ls, { t_node = Tapp (ls_svm_apply, _); _ } :: _dataset :: tt) -> + let predicate = + match tt with + | [] -> + let correct_predicate = find_predicate_ls env "correct" in + if Term.ls_equal ls correct_predicate + then Correct + else failwith_unsupported_ls ls + | [ { t_node = Tconst e; _ } ] -> + let robust_predicate = find_predicate_ls env "robust" in + let cond_robust_predicate = find_predicate_ls env "cond_robust" in + if Term.ls_equal ls robust_predicate + then Robust e + else if Term.ls_equal ls cond_robust_predicate + then CondRobust e + else failwith_unsupported_ls ls + | _ -> failwith_unsupported_term term + in + let model = + match lookup ls_svm_apply with + | Some t -> t + | None -> invalid_arg "No model found in task" + in + { model; predicate } + | _ -> + (* No other term node is supported. *) + failwith_unsupported_term term diff --git a/src/dataset.mli b/src/dataset.mli new file mode 100644 index 0000000000000000000000000000000000000000..e80c775cc044b33e3fde7c1a682e2dea25e6f00a --- /dev/null +++ b/src/dataset.mli @@ -0,0 +1,30 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +open Why3 + +type eps +type predicate = private Correct | Robust of eps | CondRobust of eps +type 'a task = private { model : 'a; predicate : predicate } + +val string_of_eps : eps -> string +val interpret : Env.env -> (Term.lsymbol -> 'a option) -> Task.task -> 'a task diff --git a/src/verification.ml b/src/verification.ml index b08b2d1662037ac3770f27fc6772be018864c1a2..836ff1bcb8c628832d6ed9732a2b28518451ac92 100644 --- a/src/verification.ml +++ b/src/verification.ml @@ -61,11 +61,6 @@ let create_env ?(debug = false) loadpath = config ) let nnet_or_onnx = Re__Core.(compile (str "%{nnet-onnx}")) -let svm = Re__Core.(compile (str "%{svm}")) -let dataset = Re__Core.(compile (str "%{dataset}")) -let epsilon = Re__Core.(compile (str "%{epsilon}")) -let abstraction = Re__Core.(compile (str "%{abstraction}")) -let distance = Re__Core.(compile (str "%{distance}")) let combine_prover_answers answers = List.fold_left answers ~init:Call_provers.Valid ~f:(fun acc r -> @@ -74,94 +69,15 @@ let combine_prover_answers answers = | _ -> acc) let answer_saver limit config env config_prover dataset_csv task = - let handle_task_saver task env dataset_csv command = - let dataset_filename = - match dataset_csv with - | None -> invalid_arg "No dataset provided for SAVer" - | Some s -> s - in - let goal = Task.task_goal_fmla task in - let eps, svm_filename = - match goal.t_node with - | Tquant (Tforall, b) -> ( - let _, _, pred = Term.t_open_quant b in - let svm_t = Pmodule.read_module env [ "caisar" ] "SVM" in - let robust_to_predicate = - let open Theory in - ns_find_ls svm_t.mod_theory.th_export [ "robust_to" ] - in - match pred.t_node with - | Term.Tapp - ( ls, - [ - { t_node = Tapp (svm_app_sym, _); _ }; - _; - { t_node = Tconst e; _ }; - ] ) -> - if Term.ls_equal ls robust_to_predicate - then - let eps = Fmt.str "%a" Constant.print_def e in - let svm_filename = - match Language.lookup_loaded_svms svm_app_sym with - | Some t -> t.filename - | None -> invalid_arg "No SVM model found in task" - in - (eps, svm_filename) - else failwith "Wrong predicate found" - | _ -> - (* no other predicate than robust_to is supported *) - failwith "Unsupported predicate by SAVer") - | _ -> failwith "Unsupported predicate by SAVer" - in - let svm_file = Unix.realpath svm_filename in - let dataset_file = Unix.realpath dataset_filename in - let command = Re__Core.replace_string svm ~by:svm_file command in - let command = Re__Core.replace_string dataset ~by:dataset_file command in - let command = Re__Core.replace_string epsilon ~by:eps command in - let command = Re__Core.replace_string distance ~by:"l_inf" command in - let command = Re__Core.replace_string abstraction ~by:"hybrid" command in - command - in - let command = Whyconf.get_complete_command ~with_steps:false config_prover in - let command = handle_task_saver task env dataset_csv command in - let res_parser = - { - Call_provers.prp_regexps = - [ ("NeverMatch", Call_provers.Failure "Should not happen in CAISAR") ]; - prp_timeregexps = []; - prp_stepregexps = []; - prp_exitcodes = []; - prp_model_parser = Model_parser.lookup_model_parser "no_model"; - } - in - let prover_call = - Call_provers.call_on_buffer ~libdir:(Whyconf.libdir config) - ~datadir:(Whyconf.datadir config) ~command ~limit ~res_parser - ~filename:" " ~get_counterexmp:false ~gen_new_file:false - ~printing_info:Printer.default_printing_info (Buffer.create 10) - in - let prover_result = Call_provers.wait_on_call prover_call in - let answer = - match prover_result.pr_answer with - | Call_provers.HighFailure -> ( - let pr_output = prover_result.pr_output in - let matcher = - Re__Pcre.regexp - "\\[SUMMARY\\]\\s*(\\d+)\\s*[0-9.]+\\s*[0-9.]+\\s*\\d+\\s*(\\d+)\\s*\\d" - in - match Re__Core.exec_opt matcher pr_output with - | Some g -> - if Int.of_string (Re__Core.Group.get g 1) - = Int.of_string (Re__Core.Group.get g 2) - then Call_provers.Valid - else Call_provers.Invalid - | None -> Call_provers.HighFailure) - | _ -> - (* Any other answer than HighFailure should never happen as we do not - define anything in SAVer's driver. *) - assert false + let dataset = + match dataset_csv with + | None -> invalid_arg "No dataset provided for SAVer" + | Some filename -> filename in - answer + let answer = SAVer.call limit config env config_prover ~dataset task in + let prover_answer = answer.prover_answer in + let additional_info = Fmt.str "(%d/%d)" answer.nb_proved answer.nb_total in + (prover_answer, Some additional_info) let answer_generic limit config prover config_prover driver task = let task_prepared = Driver.prepare_task driver task in @@ -190,11 +106,11 @@ let answer_generic limit config prover config_prover driver task = prover_result.pr_answer in let answers = List.map tasks ~f:call_prover_on_task in - let answer = combine_prover_answers answers in - answer + let prover_answer = combine_prover_answers answers in + (prover_answer, None) let call_prover ~limit config env prover config_prover driver dataset_csv task = - let prover_answer = + let prover_answer, additional_info = match prover with | Prover.Saver -> answer_saver limit config env config_prover dataset_csv task @@ -202,8 +118,10 @@ let call_prover ~limit config env prover config_prover driver dataset_csv task = answer_generic limit config prover config_prover driver task in Logs.app (fun m -> - m "@[Goal %a:@ %a@]" Pretty.print_pr (Task.task_goal task) - Call_provers.print_prover_answer prover_answer) + m "@[Goal %a:@ %a%a@]" Pretty.print_pr (Task.task_goal task) + Call_provers.print_prover_answer prover_answer + Fmt.(option ~none:nop (any " " ++ string)) + additional_info) let verify ?(debug = false) format loadpath ?memlimit ?timeout prover ?dataset_csv file = diff --git a/stdlib/caisar.mlw b/stdlib/caisar.mlw index 5f5f3d336f05ff25b3d2d344be8692adac6d8e0d..1eb5c6da5c7800551b058762ad9acb7118a9397e 100644 --- a/stdlib/caisar.mlw +++ b/stdlib/caisar.mlw @@ -25,22 +25,66 @@ theory NN type input_type = t end -theory SVM +theory Model use ieee_float.Float64 use int.Int + use array.Array + + type model = { + nb_inputs: int; + nb_outputs: int; + } + + function predict: model -> array t -> int +end + +theory DataSetClassification + use ieee_float.Float64 + use int.Int + use array.Array + use Model + + type features = array t + type class = int - type input_type = int -> t - type output_type = int + type datum = (features, class) - type svm = { - apply : input_type -> output_type; - nb_inputs : int; - nb_classes : int; + type dataset = { + nb_features: int; + nb_classes: int; + data: array datum } - predicate linfty_distance (a: input_type) (b: input_type) (eps:t) (n: int) = - forall i. 0 <= i < n -> .- eps .< a i .- b i .< eps + constant dataset: dataset - predicate robust_to (svm: svm) (a: input_type) (eps: t) = - forall b. linfty_distance a b eps svm.nb_inputs -> svm.apply a = svm.apply b + predicate linfty_distance (a: features) (b: features) (eps: t) = + a.length = b.length /\ + forall i: int. 0 <= i < a.length -> .- eps .< a[i] .- b[i] .< eps + + predicate correct_at (m: model) (d: datum) = + let (x, y) = d in + y = predict m x + + predicate robust_at (m: model) (d: datum) (eps: t) = + forall x': features. + let (x, _) = d in + linfty_distance x x' eps -> + predict m x = predict m x' + + predicate cond_robust_at (m: model) (d: datum) (eps: t) = + correct_at m d /\ robust_at m d eps + + predicate correct (m: model) (d: dataset) = + forall i: int. 0 <= i < d.data.length -> correct_at m d.data[i] + + predicate robust (m: model) (d: dataset) (eps: t) = + forall i: int. 0 <= i < d.data.length -> robust_at m d.data[i] eps + + predicate cond_robust (m: model) (d: dataset) (eps: t) = + correct m d /\ robust m d eps +end + +theory SVM + use Model + type svm = model end diff --git a/tests/bin/saver b/tests/bin/saver index afa49d80857de8f00b617f683155bf1cfb5103fc..012a62a7bb831fc4a90e2eb683195dacb35e1901 100644 --- a/tests/bin/saver +++ b/tests/bin/saver @@ -9,5 +9,5 @@ case $1 in echo "SVM: $1" echo "Goal:" cat $2 - echo "Unknown" + echo "[SUMMARY] 2 8 0.017000 1 2 0" esac diff --git a/tests/simple_ovo.t b/tests/simple_ovo.t index 34a460815ce62de6965ca880ee0d908c055336a1..120aab2d099ffa527f9d8aae7425911ec5eb3855 100644 --- a/tests/simple_ovo.t +++ b/tests/simple_ovo.t @@ -29,9 +29,14 @@ Test verify > theory T > use TestSVM.SVMasArray > use ieee_float.Float64 - > use caisar.SVM + > use int.Int + > use caisar.DataSetClassification > - > goal G: forall a : input_type. robust_to svm_apply a (8.0:t) + > goal G: robust svm_apply dataset (8.0:t) + > goal H: correct svm_apply dataset + > goal I: cond_robust svm_apply dataset (8.0:t) > end > EOF - [caisar] Goal G: High failure + [caisar] Goal G: Valid (2/2) + [caisar] Goal H: Invalid (1/2) + [caisar] Goal I: Unknown () (0/2)