Skip to content
Snippets Groups Projects
Commit 20feb4da authored by Michele Alberti
Browse files

[interpretation] wip.

parent 8d2fb49e
No related branches found
No related tags found
No related merge requests found
......@@ -24,22 +24,39 @@ module CRE = Reduction_engine (* Caisar Reduction Engine *)
open Why3
open Base
type dataset = { dataset : string } [@@deriving show]
type caisar_op = Dataset of dataset [@@deriving show]
(** In-memory dataset. Its single constructor wraps a [Csv.t]. The derived
    [show] printer emits the fixed placeholder "<csv>" instead of the CSV
    contents. *)
type dataset = CSV of Csv.t [@printer fun fmt _ -> Fmt.pf fmt "<csv>"]
[@@deriving show]
(** A classifier, represented as a plain string. The [read_classifier] builtin
    in this file stores here the path obtained by joining the working directory
    with the name given in the goal. *)
type classifier = string [@@deriving show]
(** Operations that the builtins of this file associate with logic symbols
    (through the [caisar_op_of_ls] and [ls_of_caisar_op] tables of the user
    environment). *)
type caisar_op =
(* A dataset, as produced by the [read_dataset] builtin. *)
| Dataset of dataset
(* A list of strings; the [vget] builtin builds it from the feature columns
   of one CSV row. *)
| Data of string list
(* A classifier, as produced by the [read_classifier] builtin. *)
| Classifier of classifier
(* Application of one term to another, built by the [apply_classifier]
   builtin. Printed as t1@t2: the doubled @ in the format string is the
   escape for a literal @. *)
| ClassifierApp of Term.term * Term.term
[@printer
fun fmt (t1, t2) ->
Fmt.pf fmt "%a@@%a" Pretty.print_term t1 Pretty.print_term t2]
(* Indexing of one term by another, built by the [vget] builtin when the
   index is a variable. Printed as t1[t2]. *)
| VGet of Term.term * Term.term
[@printer
fun fmt (t1, t2) ->
Fmt.pf fmt "%a[%a]" Pretty.print_term t1 Pretty.print_term t2]
[@@deriving show]
type caisar_env = {
dataset_ty : Ty.ty;
caisar_op_of_ls : caisar_op Term.Hls.t;
ls_of_caisar_op : (caisar_op, Term.lsymbol) Hashtbl.t;
cwd : string;
}
let ls_of_caisar_op engine op =
let ls_of_caisar_op engine op ty =
let caisar_env = CRE.user_env engine in
Fmt.pr "ls_of_caisar_op: %a@." pp_caisar_op op;
Option.iter ty ~f:(Fmt.pr "ty: %a@." Pretty.print_ty);
Hashtbl.find_or_add caisar_env.ls_of_caisar_op op ~default:(fun () ->
let id = Ident.id_fresh "caisar_op" in
let ty = match op with Dataset _ -> caisar_env.dataset_ty in
let ls = Term.create_fsymbol id [] ty in
let ls = Term.create_lsymbol id [] ty in
Fmt.pr "ls: %a@." Pretty.print_ls ls;
Hashtbl.Poly.add_exn caisar_env.ls_of_caisar_op ~key:op ~data:ls;
Term.Hls.add caisar_env.caisar_op_of_ls ls op;
ls)
......@@ -48,14 +65,11 @@ let caisar_op_of_ls engine ls =
let caisar_env = CRE.user_env engine in
Term.Hls.find caisar_env.caisar_op_of_ls ls
let term_of_caisar_op engine op =
Term.t_app_infer (ls_of_caisar_op engine op) []
let term_of_caisar_op engine op ty =
Term.t_app_infer (ls_of_caisar_op engine op ty) []
let caisar_env env cwd =
let th = Env.read_theory env [ "caisar" ] "Interpretation" in
let ts_dataset = Theory.ns_find_ts th.Theory.th_export [ "dataset" ] in
let caisar_env _env cwd =
{
dataset_ty = Ty.ty_app ts_dataset [];
ls_of_caisar_op = Hashtbl.Poly.create ();
caisar_op_of_ls = Term.Hls.create 10;
cwd;
......@@ -65,33 +79,150 @@ let print_caisar_op fmt caisar_env =
Pp.print_iter2 Term.Hls.iter Pp.newline Pp.comma Pretty.print_ls pp_caisar_op
fmt caisar_env.caisar_op_of_ls
let compute_size_of_dataset ~cwd s =
let d = Caml.Filename.concat cwd s in
Array.length (Caml.Sys.readdir d)
let builtin_caisar : caisar_env CRE.built_in_theories list =
let open_dataset : _ CRE.builtin =
fun engine _ l _ ->
match l with
| [ Term { t_node = Tconst (ConstStr dataset); _ } ] ->
Term (term_of_caisar_op engine (Dataset { dataset }))
| _ -> invalid_arg "We want a string! ;)"
(* Text carried by the [Invalid_argument] exception that the builtins below
   raise when applied to arguments they do not handle; [symbol] is the logic
   symbol being interpreted. *)
let error_message symbol =
  Fmt.str "Invalid arguments for '%a'" Pretty.print_ls symbol
in
let size : _ CRE.builtin =
fun engine _ l _ ->
match l with
(* Vector *)
(* Builtin for vector indexing (registered further down under
   [Ident.op_get ""] in the Vector theory).

   Two argument shapes are handled:
   - any term indexed by a variable: the access stays symbolic and is recorded
     as a [VGet] operation;
   - an application of a logic symbol indexed by an integer constant: the
     symbol must map to a CSV dataset, and the result is the pair
     (features, label) taken from the selected row.

   Raises [Invalid_argument] for every other argument shape. The second case
   can also fail through [assert false], [List.nth_exn] (index outside the CSV
   rows) and [Int.of_string] (first column is not an integer).
   NOTE(review): [caisar_op_of_ls] relies on [Term.Hls.find]; presumably it
   raises [Not_found] when the symbol was never registered - confirm. *)
let vget : _ CRE.builtin =
fun engine ls vl ty ->
(* Debug trace on stdout: interpreted symbol and expected type, when given. *)
Fmt.pr "--@.vget: ls:%a , ty:%a@." Pretty.print_ls ls
Fmt.(option ~none:nop Pretty.print_ty)
ty;
match vl with
(* Symbolic index: keep the access abstract behind a [VGet] operation. *)
| [ Term t1; Term ({ t_node = Tvar _; _ } as t2) ] ->
Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
Term (term_of_caisar_op engine (VGet (t1, t2)) ty)
(* Constant index. The [ls] bound by this pattern shadows the builtin
   parameter [ls] for the rest of this branch only. *)
| [
Term ({ t_node = Tapp (ls, _); _ } as t1);
Term ({ t_node = Tconst (ConstInt i); _ } as t2);
] ->
Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2;
let t_features, t_label =
let row =
(* Only a CSV dataset can be indexed here; any other operation aborts. *)
match caisar_op_of_ls engine ls with
| Dataset (CSV csv) -> List.nth_exn csv (Number.to_small_integer i)
| Data _ | Classifier _ | ClassifierApp (_, _) | VGet (_, _) ->
assert false
in
(* The first column is the label, the remaining columns are the features; a
   row with fewer than two columns is rejected. *)
let label, features =
match row with
| [] | [ _ ] -> assert false
| label :: features -> (label, features)
in
(* The expected type must be a type application with exactly two arguments;
   the first argument becomes the type of the features term. *)
let ty_features =
match ty with
| Some { ty_node = Tyapp (_, [ a; _ ]); _ } -> Some a
| _ -> assert false
in
(* NOTE(review): OCaml does not specify the evaluation order of tuple
   components, and [term_of_caisar_op] has side effects (it goes through
   [ls_of_caisar_op]); keep that in mind before reshaping this pair. *)
( term_of_caisar_op engine (Data features) ty_features,
Term.t_int_const (BigInt.of_int (Int.of_string label)) )
in
Term (Term.t_tuple [ t_features; t_label ])
| _ -> invalid_arg (error_message ls)
in
let length : _ CRE.builtin =
fun engine ls vl ty ->
Fmt.pr "--@.length: ls:%a , ty:%a@." Pretty.print_ls ls
Fmt.(option ~none:nop Pretty.print_ty)
ty;
match vl with
| [ Term { t_node = Tapp (ls, []); _ } ] -> (
match caisar_op_of_ls engine ls with
| Dataset { dataset } ->
let cwd = (CRE.user_env engine).cwd in
Int (BigInt.of_int (compute_size_of_dataset ~cwd dataset)))
| _ -> invalid_arg "We want a string! ;)"
| Dataset (CSV csv) -> Int (BigInt.of_int (Csv.lines csv))
| Data _ | Classifier _ | ClassifierApp _ | VGet _ -> assert false)
| _ -> invalid_arg (error_message ls)
in
(* Tensor *)
(* let valid_index : _ CRE.builtin = *)
(* fun _engine ls _vl ty -> *)
(* Fmt.pr "--@.valid_index: ls:%a , ty:%a@." Pretty.print_ls ls *)
(* Fmt.(option ~none:nop Pretty.print_ty) *)
(* ty; *)
(* Term Term.t_true *)
(* in *)
(* let equal_shape : _ CRE.builtin = *)
(* fun _engine ls _vl ty -> *)
(* Fmt.pr "--@.equal_shape: ls:%a , ty:%a@." Pretty.print_ls ls *)
(* Fmt.(option ~none:nop Pretty.print_ty) *)
(* ty; *)
(* Term Term.t_true *)
(* in *)
(* Classifier *)
(* Builtin for [read_classifier]. It expects a string constant followed by the
   constant symbol named "NNet", and yields the term standing for a
   [Classifier] operation whose payload is the name joined to the working
   directory of the user environment. Any other argument shape raises
   [Invalid_argument]. A debug trace is printed on stdout first. *)
let read_classifier : _ CRE.builtin =
 fun engine ls vl ty ->
  let pp_ty_opt = Fmt.option ~none:Fmt.nop Pretty.print_ty in
  Fmt.pr "--@.read_classifier: ls:%a , ty:%a@." Pretty.print_ls ls pp_ty_opt ty;
  match vl with
  | [
   Term { t_node = Tconst (ConstStr name); _ };
   Term { t_node = Tapp ({ ls_name = { id_string = "NNet"; _ }; _ }, []); _ };
  ] ->
    let { cwd; _ } = CRE.user_env engine in
    let path = Caml.Filename.concat cwd name in
    Term (term_of_caisar_op engine (Classifier path) ty)
  | _ -> invalid_arg (error_message ls)
in
(* Builtin for the infix classifier application operator. When the first
   argument is a variable, the application is kept symbolic as a
   [ClassifierApp] operation over both argument terms. Every other argument
   shape raises [Invalid_argument]. Debug traces are printed on stdout. *)
let apply_classifier : _ CRE.builtin =
 fun engine ls vl ty ->
  let pp_ty_opt = Fmt.option ~none:Fmt.nop Pretty.print_ty in
  Fmt.pr "--@.apply_classifier: ls:%a , ty:%a@." Pretty.print_ls ls pp_ty_opt
    ty;
  match vl with
  | [ Term ({ t_node = Tvar _; _ } as clf); Term input ] ->
    Fmt.pr "Terms: %a , %a@." Pretty.print_term clf Pretty.print_term input;
    let op = ClassifierApp (clf, input) in
    Term (term_of_caisar_op engine op ty)
  | _ -> invalid_arg (error_message ls)
in
(* Dataset *)
(* Builtin for [read_dataset]. It expects a string constant followed by the
   constant symbol named "CSV". The name is joined to the working directory of
   the user environment, the file is loaded with [Csv.load], and the result is
   the term standing for the resulting [Dataset] operation. Any other argument
   shape raises [Invalid_argument]. A debug trace is printed on stdout first. *)
let read_dataset : _ CRE.builtin =
 fun engine ls vl ty ->
  let pp_ty_opt = Fmt.option ~none:Fmt.nop Pretty.print_ty in
  Fmt.pr "--@.read_dataset: ls:%a , ty:%a@." Pretty.print_ls ls pp_ty_opt ty;
  match vl with
  | [
   Term { t_node = Tconst (ConstStr name); _ };
   Term { t_node = Tapp ({ ls_name = { id_string = "CSV"; _ }; _ }, []); _ };
  ] ->
    let cwd = (CRE.user_env engine).cwd in
    let path = Caml.Filename.concat cwd name in
    (* The file is read before the operation is registered. *)
    let csv = Csv.load path in
    Term (term_of_caisar_op engine (Dataset (CSV csv)) ty)
  | _ -> invalid_arg (error_message ls)
in
[
( [ "caisar" ],
"Interpretation",
( [ "interpretation" ],
"Vector",
[],
[ (Ident.op_get "" (* ([]) *), None, vget); ("length", None, length) ] );
(* ( [ "interpretation" ], *)
(* "Tensor", *)
(* [], *)
(* [ ("valid_index", None, valid_index); ("equal_shape", None, equal_shape)
] *)
(* ); *)
( [ "interpretation" ],
"Classifier",
[],
[
("read_classifier", None, read_classifier);
(Ident.op_infix "@@", None, apply_classifier);
] );
( [ "interpretation" ],
"Dataset",
[],
[ ("open_dataset", None, open_dataset); ("size", None, size) ] );
[ ("read_dataset", None, read_dataset) ] );
]
let interpret_task ~cwd env task =
......@@ -107,6 +238,7 @@ let interpret_task ~cwd env task =
in
let engine = CRE.create params env known_map caisar_env builtin_caisar in
let f = Task.task_goal_fmla task in
Fmt.pr "TERM: %a@." Pretty.print_term f;
let f = CRE.normalize ~limit:1000 engine Term.Mvs.empty f in
Fmt.pr "%a : %a@.%a@." Pretty.print_pr (Task.task_goal task) Pretty.print_term
f print_caisar_op caisar_env
......
......@@ -803,6 +803,7 @@ let rec reduce engine c =
| [], (Kcase _, _) :: _ -> assert false
| (Int _ | Real _) :: _, (Kcase _, _) :: _ -> assert false
| Term t1 :: st, (Kcase (tbl, sigma), orig) :: rem ->
Fmt.pr "reduce_match@%a@." Pretty.print_term t1;
reduce_match st t1 ~orig tbl sigma rem
| ( ([] | [ _ ] | (Int _ | Real _) :: _ | Term _ :: (Int _ | Real _) :: _),
(Kbinop _, _) :: _ ) ->
......
......@@ -2,5 +2,5 @@
(section
(site
(caisar stdlib)))
(files caisar.mlw)
(files caisar.mlw interpretation.mlw)
(package caisar))
......@@ -31,15 +31,12 @@ theory Vector
function map (v: vector 'a) (f: 'a -> 'b) : vector 'b
function map2 (v1: vector 'a) (v2: vector 'b) (f: 'a -> 'b -> 'c) : vector 'c
function fold (v: vector 'a) (acc: 'b) (f: 'b -> 'a -> 'b) : 'b
function fold2 (v1: vector 'a) (v2: vector 'b) (acc: 'c) (f: 'c -> 'a -> 'b -> 'c) : 'c
scope L
predicate forall_ (v: vector 'a) (f: 'a -> bool) =
fold v True (fun acc e -> acc /\ f e)
forall i: int. 0 <= i < length v -> f v[i]
predicate forall2 (v1: vector 'a) (v2: vector 'b) (f: 'a -> 'b -> bool) =
length(v1) = length(v2) -> fold2 v1 v2 True (fun acc e1 e2 -> acc /\ f e1 e2)
length(v1) = length(v2) -> forall i: int. 0 <= i < length v1 -> f v1[i] v2[i]
function foreach (v: vector 'a) (f: 'a -> 'b) : vector 'b =
map v f
......@@ -79,9 +76,9 @@ theory Dataset
use Tensor
type dataset 'a 'b = vector ('a, 'b)
type kind = CSV
type format = CSV
function read_dataset (f: string) (k: kind) : dataset 'a 'b
function read_dataset (f: string) (k: format) : dataset 'a 'b
scope L
predicate forall_ (d: dataset 'a 'b) (f: 'a -> 'b -> bool) =
......
Test interpret
$ caisar interpret -L . --format whyml - 2>&1 <<EOF | ./filter_tmpdir.sh
> theory T
> use caisar.Interpretation
> use int.Int
>
> goal G1: 1+1=2
>
> goal G2: 1+1=3
>
> goal G3: size (open_dataset "datasets/a") = 2
>
> goal G4:
> let dataset = open_dataset "datasets/a" in
> size dataset = 2
>
> predicate robust (i: input)
>
> goal G5:
> let dataset = open_dataset "datasets/a" in
> forall_ dataset (fun i -> robust i)
>
> goal G6:
> let dataset = open_dataset "datasets/a" in
> forall i:int. i=1+(size dataset) -> i < 4
> end
> EOF
G1 : true
G2 : false
G3 : true
caisar_op,
(Interpretation.Dataset { Interpretation.dataset = "datasets/a" })
G4 : true
caisar_op1,
(Interpretation.Dataset { Interpretation.dataset = "datasets/a" })
G5 : robust (get 0 caisar_op2) /\ robust (get 1 caisar_op2)
caisar_op2,
(Interpretation.Dataset { Interpretation.dataset = "datasets/a" })
G6 : forall i:int. i = 3 -> i < 4
caisar_op3,
(Interpretation.Dataset { Interpretation.dataset = "datasets/a" })
Test interpret on mnist
Test interpret on dataset
$ cat - > dataset.csv << EOF
> 1,0.0,1.0,0.784313725,0.019607843,0.776470588
> 0,1.0,0.0,0.019607843,0.776470588,0.784313725
......@@ -9,26 +9,26 @@ Test interpret on mnist
> use ieee_float.Float64
> use bool.Bool
> use int.Int
> use Vector
> use Tensor
> use Classifier
> use Dataset
>
> use interpretation.Vector
> use interpretation.Tensor
> use interpretation.Classifier
> use interpretation.Dataset
>
> type image = tensor t
> type label_ = int
>
>
> predicate valid_image (i: image) =
> forall v: index. valid_index i v -> (0.0: t) .<= i#v .<= (1.0: t)
>
>
> predicate valid_label (l: label_) = 0 <= l <= 1
>
>
> predicate advises (c: classifier) (i: image) (l: label_) =
> valid_label l ->
> forall j: int. valid_label j /\ j <> l -> (c@@i)[l] .> (c@@i)[j]
>
>
> predicate bounded_by_epsilon (i: image) (eps: t) =
> forall v: index. valid_index i v -> .- eps .<= i#v .<= eps
>
>
> predicate robust_around (c: classifier) (eps: t) (i: image) (l: label_) =
> forall perturbed_image: image.
> valid_image perturbed_image ->
......@@ -36,10 +36,10 @@ Test interpret on mnist
> let p = perturbed_image - i in
> bounded_by_epsilon p eps ->
> advises c perturbed_image l
>
>
> predicate robust (c: classifier) (d: dataset image label_) (eps: t) =
> Dataset.L.forall_ d (robust_around c eps)
>
>
> goal G:
> let classifier = read_classifier "TestNetwork.nnet" NNet in
> let dataset = read_dataset "dataset.csv" CSV in
......@@ -47,3 +47,57 @@ Test interpret on mnist
> robust classifier dataset eps
> end
> EOF
G : match caisar_op[0] with
| a, b ->
((fun (y2:tensor t) (y3:int) ->
forall perturbed_image:tensor t.
(forall v:vector int.
valid_index perturbed_image v ->
le (0.0:t) (perturbed_image # v) /\
le (perturbed_image # v) (1.0:t)) ->
equal_shape y2 perturbed_image ->
(forall v:vector int.
valid_index (perturbed_image - y2) v ->
le (neg (0.375:t)) ((perturbed_image - y2) # v) /\
le ((perturbed_image - y2) # v) (0.375:t)) ->
(0 < y3 \/ 0 = y3) /\ (y3 < 1 \/ y3 = 1) ->
(forall j:int.
((0 < j \/ 0 = j) /\ (j < 1 \/ j = 1)) /\ not j = y3 ->
lt
(read_classifier ("TestNetwork.nnet":string) NNet
@@ perturbed_image)
[j]
(read_classifier ("TestNetwork.nnet":string) NNet
@@ perturbed_image)
[y3]))
@ a)
@ b
end = True /\
match caisar_op[1] with
| a, b ->
((fun (y2:tensor t) (y3:int) ->
forall perturbed_image:tensor t.
(forall v:vector int.
valid_index perturbed_image v ->
le (0.0:t) (perturbed_image # v) /\
le (perturbed_image # v) (1.0:t)) ->
equal_shape y2 perturbed_image ->
(forall v:vector int.
valid_index (perturbed_image - y2) v ->
le (neg (0.375:t)) ((perturbed_image - y2) # v) /\
le ((perturbed_image - y2) # v) (0.375:t)) ->
(0 < y3 \/ 0 = y3) /\ (y3 < 1 \/ y3 = 1) ->
(forall j:int.
((0 < j \/ 0 = j) /\ (j < 1 \/ j = 1)) /\ not j = y3 ->
lt
(read_classifier ("TestNetwork.nnet":string) NNet
@@ perturbed_image)
[j]
(read_classifier ("TestNetwork.nnet":string) NNet
@@ perturbed_image)
[y3]))
@ a)
@ b
end = True
caisar_op,
(Interpretation.Dataset <csv>)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment