diff --git a/stdlib/caisar.mlw b/stdlib/caisar.mlw
index f14bf99496fcbce5dac573fcf33fd47132bb03be..fd0eacf17c2c6837767a315e6c2c1d2ab21f0cf8 100644
--- a/stdlib/caisar.mlw
+++ b/stdlib/caisar.mlw
@@ -27,24 +27,15 @@ theory DatasetClassification
 
   type features = array t
   type label_ = int
-
-  type datum = (features, label_)
-
-  type dataset = {
-    nb_features: int;
-    nb_labels: int;
-    data: array datum
-  }
+  type record = (features, label_)
+  type dataset = array record
 
   constant dataset: dataset
 
   function min_max_scale (clip: bool) (min: t) (max: t) (d: dataset): dataset
   function z_norm (mean: t) (std_dev: t) (d: dataset): dataset
 
-  type model = {
-    nb_inputs: int;
-    nb_outputs: int;
-  }
+  type model
 
   function predict: model -> features -> label_
 end
@@ -59,24 +50,24 @@ theory DatasetClassificationProps
     a.length = b.length /\
     forall i: int. 0 <= i < a.length -> .- eps .< a[i] .- b[i] .< eps
 
-  predicate correct_at (m: model) (d: datum) =
+  predicate correct_at (m: model) (d: record) =
     let (x, y) = d in
-    y = predict m x
+    predict m x = y
 
-  predicate robust_at (m: model) (d: datum) (eps: t) =
+  predicate robust_at (m: model) (d: record) (eps: t) =
     forall x': features.
       let (x, _) = d in
       linfty_distance x x' eps ->
       predict m x = predict m x'
 
-  predicate cond_robust_at (m: model) (d: datum) (eps: t) =
+  predicate cond_robust_at (m: model) (d: record) (eps: t) =
     correct_at m d /\ robust_at m d eps
 
   predicate correct (m: model) (d: dataset) =
-    forall i: int. 0 <= i < d.data.length -> correct_at m d.data[i]
+    forall i: int. 0 <= i < d.length -> correct_at m d[i]
 
   predicate robust (m: model) (d: dataset) (eps: t) =
-    forall i: int. 0 <= i < d.data.length -> robust_at m d.data[i] eps
+    forall i: int. 0 <= i < d.length -> robust_at m d[i] eps
 
   predicate cond_robust (m: model) (d: dataset) (eps: t) =
     correct m d /\ robust m d eps