From b58430d80c63ad5d8d40a9d9beaa9149d979851b Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Tue, 20 Feb 2024 13:08:30 +0100
Subject: [PATCH] [stdlib] Generalize theory for datasets.

---
 examples/mnist/mnist.why |  8 ++++----
 src/interpretation.ml    |  6 +++---
 src/language.ml          |  6 ++++--
 stdlib/dataset.mlw       | 35 ++++++++++++++++++++++++-----------
 stdlib/robust.mlw        |  8 ++++----
 tests/dataset.t          |  8 ++++----
 6 files changed, 43 insertions(+), 28 deletions(-)

diff --git a/examples/mnist/mnist.why b/examples/mnist/mnist.why
index bb17f9a..707dbfc 100644
--- a/examples/mnist/mnist.why
+++ b/examples/mnist/mnist.why
@@ -2,18 +2,18 @@ theory MNIST
   use ieee_float.Float64
   use int.Int
   use nn.NeuralNetwork
-  use dataset.DatasetCSV
-  use robust.RobustDatasetCSV
+  use dataset.CSV
+  use robust.RobustCSV
 
   constant min_label : label_ = 0
   constant max_label : label_ = 9
 
   predicate valid_label (l: label_) =
-    RobustDatasetCSV.valid_label min_label max_label l
+    RobustCSV.valid_label min_label max_label l
 
   goal robustness:
     let nn = read_neural_network "nets/MNIST_256_2.onnx" ONNX in
-    let dataset = read_dataset_csv "csv/mnist_test.csv" in
+    let dataset = read_dataset "csv/mnist_test.csv" in
     let eps = (0.0100000000000000002081668171172168513294309377670288085937500000:t) in
     robust valid_label nn dataset eps
 end
diff --git a/src/interpretation.ml b/src/interpretation.ml
index 31e2ca9..97f64b4 100644
--- a/src/interpretation.ml
+++ b/src/interpretation.ml
@@ -317,7 +317,7 @@ let caisar_builtins : caisar_env CRE.built_in_theories list =
   in
 
   (* Dataset *)
-  let read_dataset_csv : _ CRE.builtin =
+  let read_dataset : _ CRE.builtin =
    fun engine ls vl ty ->
     match vl with
     | [ Term { t_node = Tconst (ConstStr dataset); _ } ] ->
@@ -394,10 +394,10 @@ let caisar_builtins : caisar_env CRE.built_in_theories list =
         ([ Ident.op_infix "@@" ], None, apply_neural_network);
       ] );
     ( [ "dataset" ],
-      "DatasetCSV",
+      "CSV",
       [],
       [
-        ([ "read_dataset_csv" ], None, read_dataset_csv);
+        ([ "read_dataset" ], None, read_dataset);
         ([ "min_label" ], None, min_label);
         ([ "max_label" ], None, max_label);
       ] );
diff --git a/src/language.ml b/src/language.ml
index 971f59c..240dc15 100644
--- a/src/language.ml
+++ b/src/language.ml
@@ -270,8 +270,10 @@ let datasets = Term.Hls.create 10
 let fresh_dataset_csv_ls env name =
   let ty =
     let ty_feature = ty_float64_t env in
-    let th = Env.read_theory env [ "dataset" ] "DatasetCSV" in
-    Ty.ty_app (Theory.ns_find_ts th.th_export [ "dataset" ]) [ ty_feature ]
+    let th = Env.read_theory env [ "dataset" ] "CSV" in
+    Ty.ty_app
+      (Theory.ns_find_ts th.th_export [ "dataset" ])
+      [ Ty.ty_int; ty_feature ]
   in
   let id = Ident.id_fresh name in
   Term.create_fsymbol id [] ty
diff --git a/stdlib/dataset.mlw b/stdlib/dataset.mlw
index 050a3a6..acbc0d7 100644
--- a/stdlib/dataset.mlw
+++ b/stdlib/dataset.mlw
@@ -22,6 +22,26 @@
 
 (** {1 Datasets} *)
 
+(** {2 Generic Datasets} *)
+
+theory Dataset
+
+  use vector.Vector
+
+  type a 'a
+  type b 'b
+  type dataset 'a 'b = vector (a 'a, b 'b)
+
+  function read_dataset (f: string) : dataset 'a 'b
+
+  predicate forall_ (d: dataset 'a 'b) (f: a 'a -> b 'b -> bool) =
+    Vector.forall_ d (fun e -> let a, b = e in f a b)
+
+  function foreach (d: dataset 'a 'b) (f: a 'a -> b 'b -> 'c) : vector 'c =
+    Vector.foreach d (fun e -> let a, b = e in f a b)
+
+end
+
 (** {2 CSV Datasets}
 
 A dataset in CSV format is such that each element is given as:
@@ -30,24 +50,17 @@ A dataset in CSV format is such that each element is given as:
 
 *)
 
-theory DatasetCSV
+theory CSV
 
   use int.Int
   use vector.Vector
 
   type label_ = int
   type features 'a = vector 'a
-  type dataset 'a = vector (label_, features 'a)
-
-  function read_dataset_csv (f: string) : dataset 'a
-
-  function min_label (d: dataset 'a) : label_
-  function max_label (d: dataset 'a) : label_
 
-  predicate forall_ (d: dataset 'a) (f: label_ -> features 'a -> bool) =
-    Vector.forall_ d (fun e -> let i, a = e in f i a)
+  clone export Dataset with type a 'a = label_, type b 'a = features 'a
 
-  function foreach (d: dataset 'a) (f: label_ -> features 'a -> 'b) : vector 'b =
-    Vector.foreach d (fun e -> let i, a = e in f i a)
+  function min_label (d: dataset label_ 'a) : label_
+  function max_label (d: dataset label_ 'a) : label_
 
 end
diff --git a/stdlib/robust.mlw b/stdlib/robust.mlw
index 4c7853d..ed49a73 100644
--- a/stdlib/robust.mlw
+++ b/stdlib/robust.mlw
@@ -24,13 +24,13 @@
 
 (** {2 Robustness of CSV Datasets} *)
 
-theory RobustDatasetCSV
+theory RobustCSV
 
    use ieee_float.Float64
    use int.Int
    use vector.Vector
    use nn.NeuralNetwork
-   use dataset.DatasetCSV
+   use dataset.CSV
 
    type elt = features t
 
@@ -58,7 +58,7 @@ theory RobustDatasetCSV
        advises valid_label nn perturbed_elt l
 
    predicate robust (valid_label: label_ -> bool)
-                    (nn: nn) (d: dataset t) (eps: t) =
-     DatasetCSV.forall_ d (robust_around valid_label nn eps)
+                    (nn: nn) (d: dataset label_ t) (eps: t) =
+     CSV.forall_ d (robust_around valid_label nn eps)
 
 end
\ No newline at end of file
diff --git a/tests/dataset.t b/tests/dataset.t
index af53aee..95d4ced 100644
--- a/tests/dataset.t
+++ b/tests/dataset.t
@@ -17,18 +17,18 @@ Test verify on dataset
   >   use ieee_float.Float64
   >   use int.Int
   >   use nn.NeuralNetwork
-  >   use dataset.DatasetCSV
-  >   use robust.RobustDatasetCSV
+  >   use dataset.CSV
+  >   use robust.RobustCSV
   > 
   >   constant min_label: label_ = 0
   >   constant max_label: label_ = 4
   > 
   >   predicate valid_label (l: label_) =
-  >     RobustDatasetCSV.valid_label min_label max_label l
+  >     RobustCSV.valid_label min_label max_label l
   > 
   >   goal H:
   >     let nn = read_neural_network "TestNetwork.nnet" NNet in
-  >     let dataset = read_dataset_csv "dataset.csv" in
+  >     let dataset = read_dataset "dataset.csv" in
   >     let eps = (0.0100000000000000002081668171172168513294309377670288085937500000:t) in
   >     robust valid_label nn dataset eps
   > end
-- 
GitLab