diff --git a/src/SAVer.ml b/src/SAVer.ml index 79fd9e2cd0eb24ddb6bf35639785e9c393c591ae..64958a1bbb925d55da085aca7279d917059f47c9 100644 --- a/src/SAVer.ml +++ b/src/SAVer.ml @@ -23,67 +23,19 @@ open Why3 open Base -type predicate_kind = Correct | Robust | CondRobust - -type interpreted_task = { - svm_filename : string; - predicate_kind : predicate_kind; - eps : string; -} - -let find_predicate_ls env p = - let dataset_th = - Pmodule.read_module env [ "caisar" ] "DataSetClassification" - in - Theory.ns_find_ls dataset_th.mod_theory.th_export [ p ] - -let interpret_predicate env ls = - let correct_predicate = find_predicate_ls env "correct" in - let robust_predicate = find_predicate_ls env "robust" in - let cond_robust_predicate = find_predicate_ls env "cond_robust" in - if Term.ls_equal ls correct_predicate - then Correct - else if Term.ls_equal ls robust_predicate - then Robust - else if Term.ls_equal ls cond_robust_predicate - then CondRobust - else failwith (Fmt.str "Unsupported by SAVer: %a" Pretty.print_ls ls) - -let interpret_task env task = - let goal = Task.task_goal_fmla task in - match goal.t_node with - | Term.Tapp - ( ls, - [ - { t_node = Tapp (ls_svm_apply, _); _ }; - _dataset; - { t_node = Tconst e; _ }; - ] ) -> - let predicate_kind = interpret_predicate env ls in - let eps = Fmt.str "%a" Constant.print_def e in - let svm_filename = - match Language.lookup_loaded_svms ls_svm_apply with - | Some t -> Unix.realpath t.filename - | None -> invalid_arg "No SVM model found in task" - in - { svm_filename; predicate_kind; eps } - | _ -> - (* No other term is supported. *) - failwith "Unsupported term by SAVer" - let svm = Re__Core.(compile (str "%{svm}")) let dataset = Re__Core.(compile (str "%{dataset}")) let epsilon = Re__Core.(compile (str "%{epsilon}")) let abstraction = Re__Core.(compile (str "%{abstraction}")) let distance = Re__Core.(compile (str "%{distance}")) -let build_command config_prover dataset_filename interpreted_task = +let build_command config_prover svm_filename dataset_filename eps = let command = Whyconf.get_complete_command ~with_steps:false config_prover in let params = [ - (svm, interpreted_task.svm_filename); + (svm, Unix.realpath svm_filename); (dataset, Unix.realpath dataset_filename); - (epsilon, interpreted_task.eps); + (epsilon, Option.(value (map ~f:Dataset.string_of_eps eps)) ~default:"0"); (distance, "l_inf"); (abstraction, "hybrid"); ] @@ -117,12 +69,17 @@ let re_saver_output = [re_saver_output]. *) let re_group_number_dataset_size = 1 -(* Regexp group number as matched by [re_saver_output] for each kind of - predicate. *) -let re_group_number_of_predicate_kind predicate_kind = - match predicate_kind with Correct -> 2 | Robust -> 3 | CondRobust -> 4 +(* Regexp group number as matched by [re_saver_output] for each predicate. *) +let re_group_number_of_predicate = function + | Dataset.Correct -> 2 + | Robust _ -> 3 + | CondRobust _ -> 4 -let build_answer pred_kind prover_result = +let negative_prover_answer_of_predicate = function + | Dataset.Correct -> Call_provers.Invalid + | Robust _ | CondRobust _ -> Call_provers.Unknown "" + +let build_answer predicate prover_result = match prover_result.Call_provers.pr_answer with | Call_provers.HighFailure -> ( let pr_output = prover_result.pr_output in @@ -132,13 +89,13 @@ let build_answer pred_kind prover_result = Int.of_string (Re__Core.Group.get re_group re_group_number_dataset_size) in let nb_proved = - let re_group_number = re_group_number_of_predicate_kind pred_kind in + let re_group_number = re_group_number_of_predicate predicate in Int.of_string (Re__Core.Group.get re_group re_group_number) in let prover_answer = if nb_total = nb_proved then Call_provers.Valid - else Call_provers.Unknown "" + else negative_prover_answer_of_predicate predicate in { prover_answer; nb_total; nb_proved } | None -> failwith "Cannot interpret the output provided by SAVer") @@ -148,8 +105,14 @@ let build_answer pred_kind prover_result = assert false let call limit config env config_prover ~dataset task = - let interpreted_task = interpret_task env task in - let command = build_command config_prover dataset interpreted_task in + let dataset_task = Dataset.interpret env Language.lookup_loaded_svms task in + let svm = dataset_task.model.filename in + let eps = + match dataset_task.predicate with + | Dataset.Correct -> None + | Robust e | CondRobust e -> Some e + in + let command = build_command config_prover svm dataset eps in let prover_call = let res_parser = { @@ -167,4 +130,4 @@ let call limit config env config_prover ~dataset task = ~printing_info:Printer.default_printing_info (Buffer.create 10) in let prover_result = Call_provers.wait_on_call prover_call in - build_answer interpreted_task.predicate_kind prover_result + build_answer dataset_task.predicate prover_result diff --git a/src/dataset.ml b/src/dataset.ml new file mode 100644 index 0000000000000000000000000000000000000000..f0bd3383501d9dcc5dfdc94135c2b172635c525e --- /dev/null +++ b/src/dataset.ml @@ -0,0 +1,73 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +open Why3 +open Base + +type eps = Constant.constant +type predicate = Correct | Robust of eps | CondRobust of eps +type 'a task = { model : 'a; predicate : predicate } + +let string_of_eps eps = Fmt.str "%a" Constant.print_def eps + +let find_predicate_ls env p = + let dataset_th = + Pmodule.read_module env [ "caisar" ] "DataSetClassification" + in + Theory.ns_find_ls dataset_th.mod_theory.th_export [ p ] + +let failwith_unsupported_term t = + failwith (Fmt.str "Unsupported term in %a" Pretty.print_term t) + +let failwith_unsupported_ls ls = + failwith (Fmt.str "Unsupported logic symbol %a" Pretty.print_ls ls) + +let interpret env lookup task = + let term = Task.task_goal_fmla task in + match term.t_node with + | Term.Tapp (ls, { t_node = Tapp (ls_svm_apply, _); _ } :: _dataset :: tt) -> + let predicate = + match tt with + | [] -> + let correct_predicate = find_predicate_ls env "correct" in + if Term.ls_equal ls correct_predicate + then Correct + else failwith_unsupported_ls ls + | [ { t_node = Tconst e; _ } ] -> + let robust_predicate = find_predicate_ls env "robust" in + let cond_robust_predicate = find_predicate_ls env "cond_robust" in + if Term.ls_equal ls robust_predicate + then Robust e + else if Term.ls_equal ls cond_robust_predicate + then CondRobust e + else failwith_unsupported_ls ls + | _ -> failwith_unsupported_term term + in + let model = + match lookup ls_svm_apply with + | Some t -> t + | None -> invalid_arg "No model found in task" + in + { model; predicate } + | _ -> + (* No other term node is supported. *) + failwith_unsupported_term term diff --git a/src/dataset.mli b/src/dataset.mli new file mode 100644 index 0000000000000000000000000000000000000000..e80c775cc044b33e3fde7c1a682e2dea25e6f00a --- /dev/null +++ b/src/dataset.mli @@ -0,0 +1,30 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +open Why3 + +type eps +type predicate = private Correct | Robust of eps | CondRobust of eps +type 'a task = private { model : 'a; predicate : predicate } + +val string_of_eps : eps -> string +val interpret : Env.env -> (Term.lsymbol -> 'a option) -> Task.task -> 'a task diff --git a/stdlib/caisar.mlw b/stdlib/caisar.mlw index 302a893232ee5a50a230a7aeb536cd0055636c0a..1eb5c6da5c7800551b058762ad9acb7118a9397e 100644 --- a/stdlib/caisar.mlw +++ b/stdlib/caisar.mlw @@ -61,11 +61,9 @@ theory DataSetClassification a.length = b.length /\ forall i: int. 0 <= i < a.length -> .- eps .< a[i] .- b[i] .< eps - predicate correct_at (m: model) (d: datum) (eps: t) = - forall x': features. - let (x, y) = d in - linfty_distance x x' eps -> - y = predict m x' + predicate correct_at (m: model) (d: datum) = + let (x, y) = d in + y = predict m x predicate robust_at (m: model) (d: datum) (eps: t) = forall x': features. @@ -74,16 +72,16 @@ theory DataSetClassification predict m x = predict m x' predicate cond_robust_at (m: model) (d: datum) (eps: t) = - correct_at m d eps /\ robust_at m d eps + correct_at m d /\ robust_at m d eps - predicate correct (m: model) (d: dataset) (eps: t) = - forall i: int. 0 <= i < d.data.length -> correct_at m d.data[i] eps + predicate correct (m: model) (d: dataset) = + forall i: int. 0 <= i < d.data.length -> correct_at m d.data[i] predicate robust (m: model) (d: dataset) (eps: t) = forall i: int. 0 <= i < d.data.length -> robust_at m d.data[i] eps predicate cond_robust (m: model) (d: dataset) (eps: t) = - correct m d eps /\ robust m d eps + correct m d /\ robust m d eps end theory SVM diff --git a/tests/simple_ovo.t b/tests/simple_ovo.t index dba33cb4121b6639d475be22ce2d2f45456d63a9..120aab2d099ffa527f9d8aae7425911ec5eb3855 100644 --- a/tests/simple_ovo.t +++ b/tests/simple_ovo.t @@ -33,10 +33,10 @@ Test verify > use caisar.DataSetClassification > > goal G: robust svm_apply dataset (8.0:t) - > goal H: correct svm_apply dataset (8.0:t) + > goal H: correct svm_apply dataset > goal I: cond_robust svm_apply dataset (8.0:t) > end > EOF [caisar] Goal G: Valid (2/2) - [caisar] Goal H: Unknown () (1/2) + [caisar] Goal H: Invalid (1/2) [caisar] Goal I: Unknown () (0/2)