diff --git a/stdlib/interpretation.mlw b/stdlib/interpretation.mlw new file mode 100644 index 0000000000000000000000000000000000000000..9a27a91670a4041ac3b406487161654f168bd325 --- /dev/null +++ b/stdlib/interpretation.mlw @@ -0,0 +1,93 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +theory Vector + use int.Int + + type vector 'a + + function ([]) (v: vector 'a) (i: int) : 'a + function length (v: vector 'a) : int + + function map (v: vector 'a) (f: 'a -> 'b) : vector 'b + function map2 (v1: vector 'a) (v2: vector 'b) (f: 'a -> 'b -> 'c) : vector 'c + + function fold (v: vector 'a) (acc: 'b) (f: 'b -> 'a -> 'b) : 'b + function fold2 (v1: vector 'a) (v2: vector 'b) (acc: 'c) (f: 'c -> 'a -> 'b -> 'c) : 'c + + scope L + predicate forall_ (v: vector 'a) (f: 'a -> bool) = + fold v True (fun acc e -> acc /\ f e) + + predicate forall2 (v1: vector 'a) (v2: vector 'b) (f: 'a -> 'b -> bool) = + length(v1) = length(v2) -> fold2 v1 v2 True (fun acc e1 e2 -> acc /\ f e1 e2) + + function foreach (v: vector 'a) (f: 'a -> 'b) : vector 'b = + map v f + + function foreach2 (v1: vector 'a) (v2: vector 'b) (f: 'a -> 'b -> 'c) : vector 'c = + map2 v1 v2 f + end +end + +theory Tensor + use int.Int + use Vector + + type tensor 'a + type index = vector int + + function (#) (t: tensor 'a) (v: vector int) : 'a + function (-) (t1: tensor 'a) (t2: tensor 'a) : tensor 'a + + predicate equal_shape (t1: tensor 'a) (t2: tensor 'b) + predicate valid_index (t: tensor 'a) (v: index) +end + +theory Classifier + use Vector + use Tensor + + type classifier + type kind = ONNX | NNet | OVO + + function read_classifier (f: string) (k: kind) : classifier + function (@@) (c: classifier) (t: tensor 'a) : vector 'a +end + +theory Dataset + use Vector + use Tensor + + type dataset 'a 'b = vector ('a, 'b) + type kind = CSV + + function read_dataset (f: string) (k: kind) : dataset 'a 'b + + scope L + predicate forall_ (d: dataset 'a 'b) (f: 'a -> 'b -> bool) = + Vector.L.forall_ d (fun e -> let a, b = e in f a b) + + function foreach (d: dataset 'a 'b) (f: 'a -> 'b -> 'c) : vector 'c = + Vector.L.foreach d (fun e -> let a, b = e in f a b) + end +end