diff --git a/stdlib/caisar.mlw b/stdlib/caisar.mlw index f14bf99496fcbce5dac573fcf33fd47132bb03be..fd0eacf17c2c6837767a315e6c2c1d2ab21f0cf8 100644 --- a/stdlib/caisar.mlw +++ b/stdlib/caisar.mlw @@ -27,24 +27,15 @@ theory DatasetClassification type features = array t type label_ = int - - type datum = (features, label_) - - type dataset = { - nb_features: int; - nb_labels: int; - data: array datum - } + type record = (features, label_) + type dataset = array record constant dataset: dataset function min_max_scale (clip: bool) (min: t) (max: t) (d: dataset): dataset function z_norm (mean: t) (std_dev: t) (d: dataset): dataset - type model = { - nb_inputs: int; - nb_outputs: int; - } + type model function predict: model -> features -> label_ end @@ -59,24 +50,24 @@ theory DatasetClassificationProps a.length = b.length /\ forall i: int. 0 <= i < a.length -> .- eps .< a[i] .- b[i] .< eps - predicate correct_at (m: model) (d: datum) = + predicate correct_at (m: model) (d: record) = let (x, y) = d in - y = predict m x + predict m x = y - predicate robust_at (m: model) (d: datum) (eps: t) = + predicate robust_at (m: model) (d: record) (eps: t) = forall x': features. let (x, _) = d in linfty_distance x x' eps -> predict m x = predict m x' - predicate cond_robust_at (m: model) (d: datum) (eps: t) = + predicate cond_robust_at (m: model) (d: record) (eps: t) = correct_at m d /\ robust_at m d eps predicate correct (m: model) (d: dataset) = - forall i: int. 0 <= i < d.data.length -> correct_at m d.data[i] + forall i: int. 0 <= i < d.length -> correct_at m d[i] predicate robust (m: model) (d: dataset) (eps: t) = - forall i: int. 0 <= i < d.data.length -> robust_at m d.data[i] eps + forall i: int. 0 <= i < d.length -> robust_at m d[i] eps predicate cond_robust (m: model) (d: dataset) (eps: t) = correct m d /\ robust m d eps