From b58430d80c63ad5d8d40a9d9beaa9149d979851b Mon Sep 17 00:00:00 2001 From: Michele Alberti <michele.alberti@cea.fr> Date: Tue, 20 Feb 2024 13:08:30 +0100 Subject: [PATCH] [stdlib] Generalize theory for datasets. --- examples/mnist/mnist.why | 8 ++++---- src/interpretation.ml | 6 +++--- src/language.ml | 6 ++++-- stdlib/dataset.mlw | 35 ++++++++++++++++++++++++----------- stdlib/robust.mlw | 8 ++++---- tests/dataset.t | 8 ++++---- 6 files changed, 43 insertions(+), 28 deletions(-) diff --git a/examples/mnist/mnist.why b/examples/mnist/mnist.why index bb17f9a..707dbfc 100644 --- a/examples/mnist/mnist.why +++ b/examples/mnist/mnist.why @@ -2,18 +2,18 @@ theory MNIST use ieee_float.Float64 use int.Int use nn.NeuralNetwork - use dataset.DatasetCSV - use robust.RobustDatasetCSV + use dataset.CSV + use robust.RobustCSV constant min_label : label_ = 0 constant max_label : label_ = 9 predicate valid_label (l: label_) = - RobustDatasetCSV.valid_label min_label max_label l + RobustCSV.valid_label min_label max_label l goal robustness: let nn = read_neural_network "nets/MNIST_256_2.onnx" ONNX in - let dataset = read_dataset_csv "csv/mnist_test.csv" in + let dataset = read_dataset "csv/mnist_test.csv" in let eps = (0.0100000000000000002081668171172168513294309377670288085937500000:t) in robust valid_label nn dataset eps end diff --git a/src/interpretation.ml b/src/interpretation.ml index 31e2ca9..97f64b4 100644 --- a/src/interpretation.ml +++ b/src/interpretation.ml @@ -317,7 +317,7 @@ let caisar_builtins : caisar_env CRE.built_in_theories list = in (* Dataset *) - let read_dataset_csv : _ CRE.builtin = + let read_dataset : _ CRE.builtin = fun engine ls vl ty -> match vl with | [ Term { t_node = Tconst (ConstStr dataset); _ } ] -> @@ -394,10 +394,10 @@ let caisar_builtins : caisar_env CRE.built_in_theories list = ([ Ident.op_infix "@@" ], None, apply_neural_network); ] ); ( [ "dataset" ], - "DatasetCSV", + "CSV", [], [ - ([ "read_dataset_csv" ], None, read_dataset_csv); + ([ "read_dataset" ], None, read_dataset); ([ "min_label" ], None, min_label); ([ "max_label" ], None, max_label); ] ); diff --git a/src/language.ml b/src/language.ml index 971f59c..240dc15 100644 --- a/src/language.ml +++ b/src/language.ml @@ -270,8 +270,10 @@ let datasets = Term.Hls.create 10 let fresh_dataset_csv_ls env name = let ty = let ty_feature = ty_float64_t env in - let th = Env.read_theory env [ "dataset" ] "DatasetCSV" in - Ty.ty_app (Theory.ns_find_ts th.th_export [ "dataset" ]) [ ty_feature ] + let th = Env.read_theory env [ "dataset" ] "CSV" in + Ty.ty_app + (Theory.ns_find_ts th.th_export [ "dataset" ]) + [ Ty.ty_int; ty_feature ] in let id = Ident.id_fresh name in Term.create_fsymbol id [] ty diff --git a/stdlib/dataset.mlw b/stdlib/dataset.mlw index 050a3a6..acbc0d7 100644 --- a/stdlib/dataset.mlw +++ b/stdlib/dataset.mlw @@ -22,6 +22,26 @@ (** {1 Datasets} *) +(** {2 Generic Datasets} *) + +theory Dataset + + use vector.Vector + + type a 'a + type b 'b + type dataset 'a 'b = vector (a 'a, b 'b) + + function read_dataset (f: string) : dataset 'a 'b + + predicate forall_ (d: dataset 'a 'b) (f: a 'a -> b 'b -> bool) = + Vector.forall_ d (fun e -> let a, b = e in f a b) + + function foreach (d: dataset 'a 'b) (f: a 'a -> b 'b -> 'c) : vector 'c = + Vector.foreach d (fun e -> let a, b = e in f a b) + +end + (** {2 CSV Datasets} A dataset in CSV format is such that each element is given as: @@ -30,24 +50,17 @@ A dataset in CSV format is such that each element is given as: *) -theory DatasetCSV +theory CSV use int.Int use vector.Vector type label_ = int type features 'a = vector 'a - type dataset 'a = vector (label_, features 'a) - - function read_dataset_csv (f: string) : dataset 'a - - function min_label (d: dataset 'a) : label_ - function max_label (d: dataset 'a) : label_ - predicate forall_ (d: dataset 'a) (f: label_ -> features 'a -> bool) = - Vector.forall_ d (fun e -> let i, a = e in f i a) + clone export Dataset with type a 'a = label_, type b 'a = features 'a - function foreach (d: dataset 'a) (f: label_ -> features 'a -> 'b) : vector 'b = - Vector.foreach d (fun e -> let i, a = e in f i a) + function min_label (d: dataset label_ 'a) : label_ + function max_label (d: dataset label_ 'a) : label_ end diff --git a/stdlib/robust.mlw b/stdlib/robust.mlw index 4c7853d..ed49a73 100644 --- a/stdlib/robust.mlw +++ b/stdlib/robust.mlw @@ -24,13 +24,13 @@ (** {2 Robustness of CSV Datasets} *) -theory RobustDatasetCSV +theory RobustCSV use ieee_float.Float64 use int.Int use vector.Vector use nn.NeuralNetwork - use dataset.DatasetCSV + use dataset.CSV type elt = features t @@ -58,7 +58,7 @@ theory RobustDatasetCSV advises valid_label nn perturbed_elt l predicate robust (valid_label: label_ -> bool) - (nn: nn) (d: dataset t) (eps: t) = - DatasetCSV.forall_ d (robust_around valid_label nn eps) + (nn: nn) (d: dataset label_ t) (eps: t) = + CSV.forall_ d (robust_around valid_label nn eps) end \ No newline at end of file diff --git a/tests/dataset.t b/tests/dataset.t index af53aee..95d4ced 100644 --- a/tests/dataset.t +++ b/tests/dataset.t @@ -17,18 +17,18 @@ Test verify on dataset > use ieee_float.Float64 > use int.Int > use nn.NeuralNetwork - > use dataset.DatasetCSV - > use robust.RobustDatasetCSV + > use dataset.CSV + > use robust.RobustCSV > > constant min_label: label_ = 0 > constant max_label: label_ = 4 > > predicate valid_label (l: label_) = - > RobustDatasetCSV.valid_label min_label max_label l + > RobustCSV.valid_label min_label max_label l > > goal H: > let nn = read_neural_network "TestNetwork.nnet" NNet in - > let dataset = read_dataset_csv "dataset.csv" in + > let dataset = read_dataset "dataset.csv" in > let eps = (0.0100000000000000002081668171172168513294309377670288085937500000:t) in > robust valid_label nn dataset eps > end -- GitLab