From 558f5d173dbaa5f390594e1759942cce745e4116 Mon Sep 17 00:00:00 2001 From: Michele Alberti <michele.alberti@cea.fr> Date: Wed, 5 Oct 2022 18:42:55 +0200 Subject: [PATCH] [dataset][SAVer] Rework types and SAVer's call prover API. --- src/SAVer.ml | 32 ++++++++++++++++++-------------- src/SAVer.mli | 6 ++---- src/dataset.ml | 29 ++++++++++++++--------------- src/dataset.mli | 12 +++++++++--- src/language.mli | 4 ++-- src/verification.ml | 10 +++++++++- 6 files changed, 54 insertions(+), 39 deletions(-) diff --git a/src/SAVer.ml b/src/SAVer.ml index 64958a1..b636aa9 100644 --- a/src/SAVer.ml +++ b/src/SAVer.ml @@ -75,11 +75,13 @@ let re_group_number_of_predicate = function | Robust _ -> 3 | CondRobust _ -> 4 -let negative_prover_answer_of_predicate = function - | Dataset.Correct -> Call_provers.Invalid +let negative_answer = function + | Dataset.Correct -> + (* SAVer is complete wrt correct predicate. *) + Call_provers.Invalid | Robust _ | CondRobust _ -> Call_provers.Unknown "" -let build_answer predicate prover_result = +let build_answer predicate_kind prover_result = match prover_result.Call_provers.pr_answer with | Call_provers.HighFailure -> ( let pr_output = prover_result.pr_output in @@ -89,13 +91,13 @@ let build_answer predicate prover_result = Int.of_string (Re__Core.Group.get re_group re_group_number_dataset_size) in let nb_proved = - let re_group_number = re_group_number_of_predicate predicate in + let re_group_number = re_group_number_of_predicate predicate_kind in Int.of_string (Re__Core.Group.get re_group re_group_number) in let prover_answer = if nb_total = nb_proved then Call_provers.Valid - else negative_prover_answer_of_predicate predicate + else negative_answer predicate_kind in { prover_answer; nb_total; nb_proved } | None -> failwith "Cannot interpret the output provided by SAVer") @@ -104,15 +106,17 @@ let build_answer predicate prover_result = anything in SAVer's driver. *) assert false -let call limit config env config_prover ~dataset task = - let dataset_task = Dataset.interpret env Language.lookup_loaded_svms task in - let svm = dataset_task.model.filename in - let eps = - match dataset_task.predicate with - | Dataset.Correct -> None - | Robust e | CondRobust e -> Some e +let call_prover limit config config_prover predicate = + let command = + let svm = predicate.Dataset.model.Language.filename in + let dataset = predicate.dataset in + let eps = + match predicate.kind with + | Dataset.Correct -> None + | Robust e | CondRobust e -> Some e + in + build_command config_prover svm dataset eps in - let command = build_command config_prover svm dataset eps in let prover_call = let res_parser = { @@ -130,4 +134,4 @@ let call limit config env config_prover ~dataset task = ~printing_info:Printer.default_printing_info (Buffer.create 10) in let prover_result = Call_provers.wait_on_call prover_call in - build_answer dataset_task.predicate prover_result + build_answer predicate.kind prover_result diff --git a/src/SAVer.mli b/src/SAVer.mli index 2119b52..489d223 100644 --- a/src/SAVer.mli +++ b/src/SAVer.mli @@ -28,11 +28,9 @@ type answer = { nb_proved : int; (** Number of data points verifying the property. *) } -val call : +val call_prover : Call_provers.resource_limit -> Whyconf.main -> - Env.env -> Whyconf.config_prover -> - dataset:string -> - Task.task -> + (Language.svm_shape, string) Dataset.predicate -> answer diff --git a/src/dataset.ml b/src/dataset.ml index f0bd338..3443b41 100644 --- a/src/dataset.ml +++ b/src/dataset.ml @@ -24,16 +24,14 @@ open Why3 open Base type eps = Constant.constant -type predicate = Correct | Robust of eps | CondRobust of eps -type 'a task = { model : 'a; predicate : predicate } +type kind = Correct | Robust of eps | CondRobust of eps +type ('a, 'b) predicate = { model : 'a; dataset : 'b; kind : kind } let string_of_eps eps = Fmt.str "%a" Constant.print_def eps let find_predicate_ls env p = - let dataset_th = - Pmodule.read_module env [ "caisar" ] "DataSetClassification" - in - Theory.ns_find_ls dataset_th.mod_theory.th_export [ p ] + let th = Pmodule.read_module env [ "caisar" ] "DataSetClassification" in + Theory.ns_find_ls th.mod_theory.th_export [ p ] let failwith_unsupported_term t = failwith (Fmt.str "Unsupported term in %a" Pretty.print_term t) @@ -41,11 +39,15 @@ let failwith_unsupported_term t = let failwith_unsupported_ls ls = failwith (Fmt.str "Unsupported logic symbol %a" Pretty.print_ls ls) -let interpret env lookup task = +let interpret_predicate env ~on_model ~on_dataset task = let term = Task.task_goal_fmla task in match term.t_node with - | Term.Tapp (ls, { t_node = Tapp (ls_svm_apply, _); _ } :: _dataset :: tt) -> - let predicate = + | Term.Tapp + ( ls, + { t_node = Tapp (ls_svm_apply, _); _ } (* model *) + :: { t_node = Tapp (dataset, _); _ } (* dataset *) + :: tt ) -> + let kind = match tt with | [] -> let correct_predicate = find_predicate_ls env "correct" in @@ -62,12 +64,9 @@ let interpret env lookup task = else failwith_unsupported_ls ls | _ -> failwith_unsupported_term term in - let model = - match lookup ls_svm_apply with - | Some t -> t - | None -> invalid_arg "No model found in task" - in - { model; predicate } + let dataset = on_dataset dataset in + let model = on_model ls_svm_apply in + { model; dataset; kind } | _ -> (* No other term node is supported. *) failwith_unsupported_term term diff --git a/src/dataset.mli b/src/dataset.mli index e80c775..c05d438 100644 --- a/src/dataset.mli +++ b/src/dataset.mli @@ -23,8 +23,14 @@ open Why3 type eps -type predicate = private Correct | Robust of eps | CondRobust of eps -type 'a task = private { model : 'a; predicate : predicate } +type kind = private Correct | Robust of eps | CondRobust of eps +type ('a, 'b) predicate = private { model : 'a; dataset : 'b; kind : kind } val string_of_eps : eps -> string -val interpret : Env.env -> (Term.lsymbol -> 'a option) -> Task.task -> 'a task + +val interpret_predicate : + Env.env -> + on_model:(Term.lsymbol -> 'a) -> + on_dataset:(Term.lsymbol -> 'b) -> + Task.task -> + ('a, 'b) predicate diff --git a/src/language.mli b/src/language.mli index cd3329a..4d0a1ca 100644 --- a/src/language.mli +++ b/src/language.mli @@ -32,10 +32,10 @@ type nn_shape = { type svm_shape = { nb_inputs : int; nb_classes : int; filename : string } val lookup_loaded_nets : Term.lsymbol -> nn_shape option -(** @return the shape of a nnet Why3 representation. *) +(** @return the shape of a NN given its Why3 representation. *) val lookup_loaded_svms : Term.lsymbol -> svm_shape option -(** @return the shape of a svm Why3 representation. *) +(** @return the shape of a SVM given its Why3 representation. *) val register_nnet_support : unit -> unit (** Register NNet parser. *) diff --git a/src/verification.ml b/src/verification.ml index 85e7b26..8b7fef1 100644 --- a/src/verification.ml +++ b/src/verification.ml @@ -74,7 +74,15 @@ let answer_saver limit config env config_prover dataset_csv task = | None -> invalid_arg "No dataset provided for SAVer" | Some filename -> filename in - let answer = SAVer.call limit config env config_prover ~dataset task in + let dataset_predicate = + let on_model ls = + let message = Fmt.str "No SVM model found as %a" Pretty.print_ls ls in + Option.value_exn ~message (Language.lookup_loaded_svms ls) + in + let on_dataset _ls = dataset in + Dataset.interpret_predicate env ~on_model ~on_dataset task + in + let answer = SAVer.call_prover limit config config_prover dataset_predicate in match answer.prover_answer with | Call_provers.Unknown "" -> let additional_info = Fmt.str "%d/%d" answer.nb_proved answer.nb_total in -- GitLab