diff --git a/src/language.ml b/src/language.ml index 425458d6ce5a5b9ed334a5a9075a2f5798af27ad..e6b50ab64b1f058eb135812491abbf343e16ed5b 100644 --- a/src/language.ml +++ b/src/language.ml @@ -17,13 +17,19 @@ type nnshape = { filename : string; } +type svmshape = { + nb_inputs : int; + nb_classes : int; + filename : string; +} + let loaded_nets = Why3.Term.Hls.create 10 -let loaded_svms = Why3.Term.Hls.create 10 +let loaded_svms = Why3.Term.Hls.create 10 let lookup_loaded_nets = Why3.Term.Hls.find_opt loaded_nets let lookup_loaded_svms = Why3.Term.Hls.find_opt loaded_svms let register_nn_as_tuple nb_inputs nb_outputs filename env = - let net = Pmodule.read_module env [ "caisar" ] "NN" in +let net = Pmodule.read_module env [ "caisar" ] "NN" in let ioshape_input_type = Ty.ty_app Theory.(ns_find_ts net.mod_theory.th_export [ "input_type" ]) [] in @@ -45,28 +51,27 @@ let register_nn_as_tuple nb_inputs nb_outputs filename env = in Wstdlib.Mstr.singleton "NNasTuple" (Pmodule.close_module th_uc) -let register_svm nb_inputs nb_outputs filename env = - let open Why3 in - let net = Pmodule.read_module env [ "caisar" ] "SVM" in - let asarray_input_type = - Ty.ty_app Theory.(ns_find_ts net.mod_theory.th_export [ "input_type" ]) [] +let register_svm_as_array nb_inputs nb_classes filename env = + let svm = Pmodule.read_module env [ "caisar" ] "SVM" in + let svm_type = + Ty.ty_app Theory.(ns_find_ts svm.mod_theory.th_export [ "svm" ]) [] in - let id_as_array = Ident.id_fresh "SVM" in + let id_as_array = Ident.id_fresh "SVMAsArray" in let th_uc = Pmodule.create_module env id_as_array in - let th_uc = Pmodule.use_export th_uc net in + let th_uc = Pmodule.use_export th_uc svm in let ls_svm_apply = Term.create_fsymbol (Ident.id_fresh "svm_apply") [] - (Ty.ty_func asarray_input_type asarray_input_type) + svm_type in Why3.Term.Hls.add loaded_svms ls_svm_apply - { filename; nb_inputs; nb_outputs; ty_data = asarray_input_type }; + { filename; nb_inputs; nb_classes }; let th_uc = Pmodule.add_pdecl ~vc:false th_uc (Pdecl.create_pure_decl (Decl.create_param_decl ls_svm_apply)) in - Wstdlib.Mstr.singleton "SVM" (Pmodule.close_module th_uc) + Wstdlib.Mstr.singleton "SVMAsArray" (Pmodule.close_module th_uc) let nnet_parser env _ filename _ = let model = Nnet.parse filename in @@ -81,11 +86,10 @@ let onnx_parser env _ filename _ = | Ok model -> register_nn_as_tuple model.n_inputs model.n_outputs filename env let ovo_parser env _ filename _ = - let open Why3 in let model = Ovo.parse filename in match model with | Error s -> Loc.errorm "%s" s - | Ok model -> register_svm model.n_inputs model.n_outputs filename env + | Ok model -> register_svm_as_array model.n_inputs model.n_outputs filename env let register_nnet_support () = Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language diff --git a/src/language.mli b/src/language.mli index 55d5051d7ec8a4fcdf592c43b8e9f8999bc7c2e1..515e94ca2b1718b2922b7620260ac5f0c3f52198 100644 --- a/src/language.mli +++ b/src/language.mli @@ -13,11 +13,13 @@ type nnshape = { filename : string; } +type svmshape = { nb_inputs : int; nb_classes : int; filename : string } + val lookup_loaded_nets : Term.lsymbol -> nnshape option (** @return the filename of a nnet Why3 representation. *) -val lookup_loaded_svms : Why3.Term.lsymbol -> nnshape option -(** @return the ioshape of a svm Why3 representation. *) +val lookup_loaded_svms : Why3.Term.lsymbol -> svmshape option +(** @return the svmshape of a svm Why3 representation. *) val register_nnet_support : unit -> unit (** Register NNet parser. *) diff --git a/src/verification.ml b/src/verification.ml index 0aec2d2563375b3f37b79da8e1cb9a0760e76a41..0f2c6041345b7037fb77f43b8c21bd0a9ca1e8c6 100644 --- a/src/verification.ml +++ b/src/verification.ml @@ -56,7 +56,6 @@ let combine_prover_answers answers = | _ -> acc) let answer_saver limit config task env prover dataset_csv = - let open Why3 in let handle_task_saver task env dataset_csv command = let dataset_filename = match dataset_csv with @@ -144,7 +143,6 @@ let answer_saver limit config task env prover dataset_csv = let answer_generic limit config task driver (prover : Why3.Whyconf.config_prover) = - let open Why3 in let task_prepared = Driver.prepare_task driver task in let tasks = (* We make [tasks] as a list (ie, conjunction) of disjunctions. *) @@ -175,7 +173,6 @@ let answer_generic limit config task driver let call_prover ~limit config (prover : Why3.Whyconf.config_prover) driver env dataset_csv task = - let open Why3 in let prover_answer = if String.equal prover.prover.prover_name "SAVer" then answer_saver limit config task env prover dataset_csv @@ -186,7 +183,6 @@ let call_prover ~limit config (prover : Why3.Whyconf.config_prover) driver env let verify ?(debug = false) format loadpath ?memlimit ?timeout prover ?dataset_csv file = - let open Why3 in if debug then Debug.set_flag (Debug.lookup_flag "call_prover"); let env, config = create_env loadpath in let steplimit = None in diff --git a/stdlib/caisar.mlw b/stdlib/caisar.mlw index 892bd676efdfa5b6602b68e8636b341865220e97..3f3d1f7f62201aae8c36bb41e124a26b056a24c3 100644 --- a/stdlib/caisar.mlw +++ b/stdlib/caisar.mlw @@ -5,7 +5,16 @@ end theory SVM use ieee_float.Float64 - use array.Array - type input_type = array t - predicate robust_to (input_type -> input_type) (input_type) (t) + use int.Int + type input_type = int -> t + type output_type = int + type svm = {app : input_type -> output_type; num_input: int; num_classes: int} + + predicate dist_linf (a: input_type) (b: input_type) (eps:t) (n: int)= + forall i. 0 <= i < n -> + .- eps .< a i .- b i .< eps + + predicate robust_to (svm: svm) (a: input_type) (eps: t) = + forall b. dist_linf a b eps svm.num_input -> svm.app a = svm.app b + end diff --git a/tests/simple_ovo.t b/tests/simple_ovo.t index ad3aaafbe2d82f895c32c8d1cf7d729baf7f098e..f9ba617809da6980579172c44d60c112d6d23048 100644 --- a/tests/simple_ovo.t +++ b/tests/simple_ovo.t @@ -22,13 +22,13 @@ Test verify $ caisar verify -L . --format whyml --prover=SAVer --dataset-csv=test_data.csv - 2>&1 <<EOF | sed 's/\/tmp\/[a-z0-9./]*/$TMPFILE/' > theory T - > use TestSVM.SVM + > use TestSVM.SVMAsArray > use ieee_float.Float64 > use caisar.IOShape > use caisar.SVM > > goal G: forall a : input_type. - > robust_to SVM.svm_apply a (8.0:t) + > robust_to svm_apply a (8.0:t) > end > EOF <autodetect>0 prover(s) added