From 80a03afad1c1483c3eba67f6b23a09c9ee177531 Mon Sep 17 00:00:00 2001 From: Michele Alberti <michele.alberti@cea.fr> Date: Fri, 2 Sep 2022 15:03:43 +0200 Subject: [PATCH] [stdlib] Rework stdlib to have a DataSet theory. --- src/verification.ml | 92 +++++++++++++++++++++------------------------ stdlib/caisar.mlw | 50 ++++++++++++++++++------ tests/simple_ovo.t | 4 +- 3 files changed, 83 insertions(+), 63 deletions(-) diff --git a/src/verification.ml b/src/verification.ml index b08b2d1..c455308 100644 --- a/src/verification.ml +++ b/src/verification.ml @@ -80,38 +80,33 @@ let answer_saver limit config env config_prover dataset_csv task = | None -> invalid_arg "No dataset provided for SAVer" | Some s -> s in + let robust_predicate = + let dataset_th = Pmodule.read_module env [ "caisar" ] "DataSet" in + Theory.ns_find_ls dataset_th.mod_theory.th_export [ "robust" ] + in let goal = Task.task_goal_fmla task in let eps, svm_filename = match goal.t_node with - | Tquant (Tforall, b) -> ( - let _, _, pred = Term.t_open_quant b in - let svm_t = Pmodule.read_module env [ "caisar" ] "SVM" in - let robust_to_predicate = - let open Theory in - ns_find_ls svm_t.mod_theory.th_export [ "robust_to" ] - in - match pred.t_node with - | Term.Tapp - ( ls, - [ - { t_node = Tapp (svm_app_sym, _); _ }; - _; - { t_node = Tconst e; _ }; - ] ) -> - if Term.ls_equal ls robust_to_predicate - then - let eps = Fmt.str "%a" Constant.print_def e in - let svm_filename = - match Language.lookup_loaded_svms svm_app_sym with - | Some t -> t.filename - | None -> invalid_arg "No SVM model found in task" - in - (eps, svm_filename) - else failwith "Wrong predicate found" - | _ -> - (* no other predicate than robust_to is supported *) - failwith "Unsupported predicate by SAVer") - | _ -> failwith "Unsupported predicate by SAVer" + | Term.Tapp + ( ls, + [ + { t_node = Tapp (ls_svm_apply, _); _ }; + _dataset; + { t_node = Tconst e; _ }; + ] ) -> + if Term.ls_equal ls robust_predicate + then + let eps = Fmt.str "%a" Constant.print_def e in + let svm_filename = + match Language.lookup_loaded_svms ls_svm_apply with + | Some t -> t.filename + | None -> invalid_arg "No SVM model found in task" + in + (eps, svm_filename) + else failwith "Wrong predicate found" + | _ -> + (* no other predicate than robust_to is supported *) + failwith "Unsupported predicate by SAVer" in let svm_file = Unix.realpath svm_filename in let dataset_file = Unix.realpath dataset_filename in @@ -141,27 +136,24 @@ let answer_saver limit config env config_prover dataset_csv task = ~printing_info:Printer.default_printing_info (Buffer.create 10) in let prover_result = Call_provers.wait_on_call prover_call in - let answer = - match prover_result.pr_answer with - | Call_provers.HighFailure -> ( - let pr_output = prover_result.pr_output in - let matcher = - Re__Pcre.regexp - "\\[SUMMARY\\]\\s*(\\d+)\\s*[0-9.]+\\s*[0-9.]+\\s*\\d+\\s*(\\d+)\\s*\\d" - in - match Re__Core.exec_opt matcher pr_output with - | Some g -> - if Int.of_string (Re__Core.Group.get g 1) - = Int.of_string (Re__Core.Group.get g 2) - then Call_provers.Valid - else Call_provers.Invalid - | None -> Call_provers.HighFailure) - | _ -> - (* Any other answer than HighFailure should never happen as we do not - define anything in SAVer's driver. *) - assert false - in - answer + match prover_result.pr_answer with + | Call_provers.HighFailure -> ( + let pr_output = prover_result.pr_output in + let matcher = + Re__Pcre.regexp + "\\[SUMMARY\\]\\s*(\\d+)\\s*[0-9.]+\\s*[0-9.]+\\s*\\d+\\s*(\\d+)\\s*\\d" + in + match Re__Core.exec_opt matcher pr_output with + | Some g -> + if Int.of_string (Re__Core.Group.get g 1) + = Int.of_string (Re__Core.Group.get g 2) + then Call_provers.Valid + else Call_provers.Invalid + | None -> Call_provers.HighFailure) + | _ -> + (* Any other answer than HighFailure should never happen as we do not define + anything in SAVer's driver. *) + assert false let answer_generic limit config prover config_prover driver task = let task_prepared = Driver.prepare_task driver task in diff --git a/stdlib/caisar.mlw b/stdlib/caisar.mlw index 5f5f3d3..675b1ab 100644 --- a/stdlib/caisar.mlw +++ b/stdlib/caisar.mlw @@ -25,22 +25,50 @@ theory NN type input_type = t end -theory SVM +theory Model use ieee_float.Float64 use int.Int + use array.Array + + type model = { + nb_inputs: int; + nb_outputs: int; + } - type input_type = int -> t - type output_type = int + function predict: model -> array t -> int +end + +theory DataSet + use ieee_float.Float64 + use int.Int + use array.Array + use Model - type svm = { - apply : input_type -> output_type; - nb_inputs : int; - nb_classes : int; + type features = array t + type class = int + + type datum = (features, class) + + type dataset = { + nb_features: int; + nb_classes: int; + data: array datum } - predicate linfty_distance (a: input_type) (b: input_type) (eps:t) (n: int) = - forall i. 0 <= i < n -> .- eps .< a i .- b i .< eps + constant dataset: dataset + + predicate linfty_distance (n: int) (a: features) (b: features) (eps: t) = + forall i: int. 0 <= i < n -> .- eps .< a[i] .- b[i] .< eps - predicate robust_to (svm: svm) (a: input_type) (eps: t) = - forall b. linfty_distance a b eps svm.nb_inputs -> svm.apply a = svm.apply b + predicate robust (m: model) (d: dataset) (eps: t) = + forall i: int. 0 <= i < d.data.length -> + forall x': features. + let (x, _) = d.data[i] in + linfty_distance d.nb_features x x' eps -> + predict m x = predict m x' +end + +theory SVM + use Model + type svm = model end diff --git a/tests/simple_ovo.t b/tests/simple_ovo.t index 34a4608..1e02da7 100644 --- a/tests/simple_ovo.t +++ b/tests/simple_ovo.t @@ -29,9 +29,9 @@ Test verify > theory T > use TestSVM.SVMasArray > use ieee_float.Float64 - > use caisar.SVM + > use caisar.DataSet > - > goal G: forall a : input_type. robust_to svm_apply a (8.0:t) + > goal G: robust svm_apply dataset (8.0:t) > end > EOF [caisar] Goal G: High failure -- GitLab