Skip to content
Snippets Groups Projects
Commit b912fcff authored by Michele Alberti's avatar Michele Alberti
Browse files

[interpretation] Modify dataset testcase to work on vectors instead.

parent 63227409
No related branches found
No related tags found
No related merge requests found
...@@ -150,11 +150,58 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = ...@@ -150,11 +150,58 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
match caisar_op_of_ls engine ls with match caisar_op_of_ls engine ls with
| Dataset (DS_csv csv) -> int (BigInt.of_int (Csv.lines csv)) | Dataset (DS_csv csv) -> int (BigInt.of_int (Csv.lines csv))
| Data _ | Classifier _ | Tensor _ | Vector _ | Index _ -> assert false) | Data _ | Classifier _ | Tensor _ | Vector _ | Index _ -> assert false)
| [ Term { t_node = Tapp (ls, tl); _ } ] -> (
match caisar_op_of_ls engine ls with
| Vector n ->
assert (List.length tl = n);
int (BigInt.of_int n)
| Dataset _ | Data _ | Classifier _ | Tensor _ | Index _ -> assert false)
| [ Term t ] -> | [ Term t ] ->
(* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *) (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
term (Term.t_app_infer ls [ t ]) term (Term.t_app_infer ls [ t ])
| _ -> invalid_arg (error_message ls) | _ -> invalid_arg (error_message ls)
in in
let vminus : _ CRE.builtin =
fun engine ls vl ty ->
(* Fmt.pr "--@.vminus: ls:%a , ty:%a@." Pretty.print_ls ls *)
(* Fmt.(option ~none:nop Pretty.print_ty) *)
(* ty; *)
match vl with
| [
Term ({ t_node = Tapp (ls1, tl1); _ } as _t1);
Term ({ t_node = Tapp (ls2, _); _ } as _t2);
] -> (
(* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
match (caisar_op_of_ls engine ls1, caisar_op_of_ls engine ls2) with
| Vector n, Data (D_csv data) ->
assert (n = List.length data);
let ty_cst =
match ty with
| Some { ty_node = Tyapp (_, [ ty ]); _ } -> ty
| _ -> assert false
in
let csts =
List.map data ~f:(fun d ->
let cst = const_real_of_float (Float.of_string d) in
Term.t_const cst ty_cst)
in
let minus =
(* TODO: generalize wrt the type of constants [csts]. *)
let { env; _ } = CRE.user_env engine in
let th = Env.read_theory env [ "ieee_float" ] "Float64" in
Theory.(ns_find_ls th.th_export [ Ident.op_infix ".-" ])
in
let args =
List.map2_exn tl1 csts ~f:(fun tl c ->
(Term.t_app_infer minus [ tl; c ], ty_cst))
in
term (term_of_caisar_op ~args engine (Vector n) ty)
| _ -> assert false)
| [ Term t1; Term t2 ] ->
(* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
term (Term.t_app_infer ls [ t1; t2 ])
| _ -> invalid_arg (error_message ls)
in
let mapi : _ CRE.builtin = let mapi : _ CRE.builtin =
fun engine ls vl ty -> fun engine ls vl ty ->
(* Fmt.pr "--@.mapi: ls:%a , ty:%a@." Pretty.print_ls ls *) (* Fmt.pr "--@.mapi: ls:%a , ty:%a@." Pretty.print_ls ls *)
...@@ -305,6 +352,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = ...@@ -305,6 +352,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
[], [],
[ [
([ Ident.op_get "" ] (* ([]) *), None, vget); ([ Ident.op_get "" ] (* ([]) *), None, vget);
([ Ident.op_infix "-" ], None, vminus);
([ "length" ], None, length); ([ "length" ], None, length);
([ "L"; "mapi" ], None, mapi); ([ "L"; "mapi" ], None, mapi);
] ); ] );
...@@ -346,19 +394,27 @@ let bounded_quant engine vs ~cond : CRE.bounded_quant_result option = ...@@ -346,19 +394,27 @@ let bounded_quant engine vs ~cond : CRE.bounded_quant_result option =
| Data (D_csv d) -> List.length d | Data (D_csv d) -> List.length d
| _ -> assert false | _ -> assert false
in in
let ty = let ty, caisar_op, id =
match vs.vs_ty with match vs.vs_ty with
| { ty_node = Tyapp (_, ty :: _); _ } -> ty | { ty_node = Tyapp ({ ts_name; _ }, ty :: _); _ } ->
let caisar_op, id =
if String.equal ts_name.id_string "vector"
then (Vector n, "caisar_v")
else if String.equal ts_name.id_string "tensor"
then (Tensor n, "caisar_t")
else assert false
in
(ty, caisar_op, id)
| _ -> assert false | _ -> assert false
in in
let new_quant = let new_quant =
List.init n ~f:(fun _ -> List.init n ~f:(fun _ ->
let preid = Ident.id_fresh "caisar_t" in let preid = Ident.id_fresh id in
Term.create_vsymbol preid ty) Term.create_vsymbol preid ty)
in in
let args = List.map new_quant ~f:(fun vs -> (Term.t_var vs, ty)) in let args = List.map new_quant ~f:(fun vs -> (Term.t_var vs, ty)) in
let substitutions = let substitutions =
[ term_of_caisar_op ~args engine (Tensor n) (Some vs.vs_ty) ] [ term_of_caisar_op ~args engine caisar_op (Some vs.vs_ty) ]
in in
Some { new_quant; substitutions } Some { new_quant; substitutions }
| Tapp | Tapp
......
...@@ -24,22 +24,26 @@ theory Vector ...@@ -24,22 +24,26 @@ theory Vector
use int.Int use int.Int
type vector 'a type vector 'a
type index = int
function ([]) (v: vector 'a) (i: int) : 'a function ([]) (v: vector 'a) (i: index) : 'a
function length (v: vector 'a) : int function length (v: vector 'a) : int
function (-) (v1: vector 'a) (v2: vector 'a) : vector 'a
predicate has_length (v: vector 'a) (i: int) predicate has_length (v: vector 'a) (i: int)
predicate equal_shape (v1: vector 'a) (v2: vector 'b)
predicate valid_index (v: vector 'a) (i: index) = 0 <= i < length v
scope L scope L
function mapi (v: vector 'a) (f: int -> 'a -> 'b) : vector 'b function mapi (v: vector 'a) (f: index -> 'a -> 'b) : vector 'b
function map (v: vector 'a) (f: 'a -> 'b) : vector 'b function map (v: vector 'a) (f: 'a -> 'b) : vector 'b
function map2 (v1: vector 'a) (v2: vector 'b) (f: 'a -> 'b -> 'c) : vector 'c function map2 (v1: vector 'a) (v2: vector 'b) (f: 'a -> 'b -> 'c) : vector 'c
predicate forall_ (v: vector 'a) (f: 'a -> bool) = predicate forall_ (v: vector 'a) (f: 'a -> bool) =
forall i: int. 0 <= i < length v -> f v[i] forall i: index. valid_index v i -> f v[i]
predicate forall2 (v1: vector 'a) (v2: vector 'b) (f: 'a -> 'b -> bool) = predicate forall2 (v1: vector 'a) (v2: vector 'b) (f: 'a -> 'b -> bool) =
length(v1) = length(v2) -> forall i: int. 0 <= i < length v1 -> f v1[i] v2[i] length(v1) = length(v2) -> forall i: index. valid_index v1 i -> f v1[i] v2[i]
function foreach (v: vector 'a) (f: 'a -> 'b) : vector 'b = function foreach (v: vector 'a) (f: 'a -> 'b) : vector 'b =
map v f map v f
......
...@@ -10,24 +10,23 @@ Test interpret on dataset ...@@ -10,24 +10,23 @@ Test interpret on dataset
> use bool.Bool > use bool.Bool
> use int.Int > use int.Int
> use interpretation.Vector > use interpretation.Vector
> use interpretation.Tensor
> use interpretation.Classifier > use interpretation.Classifier
> use interpretation.Dataset > use interpretation.Dataset
> >
> type image = tensor t > type image = vector t
> type label_ = int > type label_ = int
> >
> predicate valid_image (i: image) = > predicate valid_image (i: image) =
> forall v: index. valid_index i v -> (0.0: t) .<= i#v .<= (1.0: t) > forall v: index. valid_index i v -> (0.0: t) .<= i[v] .<= (1.0: t)
> >
> predicate valid_label (l: label_) = 0 <= l <= 2 > predicate valid_label (l: label_) = 0 <= l <= 2
> >
> predicate advises (c: classifier) (i: image) (l: label_) = > predicate advises (c: classifier) (i: image) (l: label_) =
> valid_label l -> > valid_label l ->
> forall j: int. valid_label j -> j <> l -> (c@@i)[l] .> (c@@i)[j] > forall j: int. valid_label j -> j <> l -> (c%%i)[l] .> (c%%i)[j]
> >
> predicate bounded_by_epsilon (i: image) (eps: t) = > predicate bounded_by_epsilon (i: image) (eps: t) =
> forall v: index. valid_index i v -> .- eps .<= i#v .<= eps > forall v: index. valid_index i v -> .- eps .<= i[v] .<= eps
> >
> predicate robust_around (c: classifier) (eps: t) (i: image) (l: label_) = > predicate robust_around (c: classifier) (eps: t) (i: image) (l: label_) =
> forall perturbed_image: image. > forall perturbed_image: image.
...@@ -47,111 +46,107 @@ Test interpret on dataset ...@@ -47,111 +46,107 @@ Test interpret on dataset
> robust classifier dataset eps > robust classifier dataset eps
> end > end
> EOF > EOF
G : (forall caisar_t:t, caisar_t1:t, caisar_t2:t, caisar_t3:t, caisar_t4:t. G : (forall caisar_v:t, caisar_v1:t, caisar_v2:t, caisar_v3:t, caisar_v4:t.
((((le (0.0:t) caisar_t4 /\ le caisar_t4 (1.0:t)) /\ ((((le (0.0:t) caisar_v /\ le caisar_v (1.0:t)) /\
le (0.0:t) caisar_t3 /\ le caisar_t3 (1.0:t)) /\ le (0.0:t) caisar_v1 /\ le caisar_v1 (1.0:t)) /\
le (0.0:t) caisar_t2 /\ le caisar_t2 (1.0:t)) /\ le (0.0:t) caisar_v2 /\ le caisar_v2 (1.0:t)) /\
le (0.0:t) caisar_t1 /\ le caisar_t1 (1.0:t)) /\ le (0.0:t) caisar_v3 /\ le caisar_v3 (1.0:t)) /\
le (0.0:t) caisar_t /\ le caisar_t (1.0:t) -> le (0.0:t) caisar_v4 /\ le caisar_v4 (1.0:t) ->
((((le (neg (0.375:t)) ((((le (neg (0.375:t)) (sub RNE caisar_v (0.0:t)) /\
(sub RNE caisar_t4 le (sub RNE caisar_v (0.0:t)) (0.375:t)) /\
(0.776470588000000017103729987866245210170745849609375:t)) /\ le (neg (0.375:t)) (sub RNE caisar_v1 (1.0:t)) /\
le le (sub RNE caisar_v1 (1.0:t)) (0.375:t)) /\
(sub RNE caisar_t4
(0.776470588000000017103729987866245210170745849609375:t))
(0.375:t)) /\
le (neg (0.375:t))
(sub RNE caisar_t3
(0.01960784299999999980013143385804141871631145477294921875:t)) /\
le
(sub RNE caisar_t3
(0.01960784299999999980013143385804141871631145477294921875:t))
(0.375:t)) /\
le (neg (0.375:t)) le (neg (0.375:t))
(sub RNE caisar_t2 (sub RNE caisar_v2
(0.78431372499999996161790249971090815961360931396484375:t)) /\ (0.78431372499999996161790249971090815961360931396484375:t)) /\
le le
(sub RNE caisar_t2 (sub RNE caisar_v2
(0.78431372499999996161790249971090815961360931396484375:t)) (0.78431372499999996161790249971090815961360931396484375:t))
(0.375:t)) /\ (0.375:t)) /\
le (neg (0.375:t)) (sub RNE caisar_t1 (1.0:t)) /\ le (neg (0.375:t))
le (sub RNE caisar_t1 (1.0:t)) (0.375:t)) /\ (sub RNE caisar_v3
le (neg (0.375:t)) (sub RNE caisar_t (0.0:t)) /\ (0.01960784299999999980013143385804141871631145477294921875:t)) /\
le (sub RNE caisar_t (0.0:t)) (0.375:t) -> le
(sub RNE caisar_v3
(0.01960784299999999980013143385804141871631145477294921875:t))
(0.375:t)) /\
le (neg (0.375:t))
(sub RNE caisar_v4
(0.776470588000000017103729987866245210170745849609375:t)) /\
le
(sub RNE caisar_v4
(0.776470588000000017103729987866245210170745849609375:t))
(0.375:t) ->
lt lt
(caisar_op (caisar_op
@@ caisar_op1 caisar_t caisar_t1 caisar_t2 caisar_t3 caisar_t4) %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
[0] [0]
(caisar_op (caisar_op
@@ caisar_op1 caisar_t caisar_t1 caisar_t2 caisar_t3 caisar_t4) %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
[1] /\ [1] /\
lt lt
(caisar_op (caisar_op
@@ caisar_op1 caisar_t caisar_t1 caisar_t2 caisar_t3 caisar_t4) %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
[2] [2]
(caisar_op (caisar_op
@@ caisar_op1 caisar_t caisar_t1 caisar_t2 caisar_t3 caisar_t4) %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
[1]) /\ [1]) /\
(forall caisar_t:t, caisar_t1:t, caisar_t2:t, caisar_t3:t, caisar_t4:t. (forall caisar_v:t, caisar_v1:t, caisar_v2:t, caisar_v3:t, caisar_v4:t.
((((le (0.0:t) caisar_t4 /\ le caisar_t4 (1.0:t)) /\ ((((le (0.0:t) caisar_v /\ le caisar_v (1.0:t)) /\
le (0.0:t) caisar_t3 /\ le caisar_t3 (1.0:t)) /\ le (0.0:t) caisar_v1 /\ le caisar_v1 (1.0:t)) /\
le (0.0:t) caisar_t2 /\ le caisar_t2 (1.0:t)) /\ le (0.0:t) caisar_v2 /\ le caisar_v2 (1.0:t)) /\
le (0.0:t) caisar_t1 /\ le caisar_t1 (1.0:t)) /\ le (0.0:t) caisar_v3 /\ le caisar_v3 (1.0:t)) /\
le (0.0:t) caisar_t /\ le caisar_t (1.0:t) -> le (0.0:t) caisar_v4 /\ le caisar_v4 (1.0:t) ->
((((le (neg (0.375:t)) ((((le (neg (0.375:t)) (sub RNE caisar_v (1.0:t)) /\
(sub RNE caisar_t4 le (sub RNE caisar_v (1.0:t)) (0.375:t)) /\
(0.78431372499999996161790249971090815961360931396484375:t)) /\ le (neg (0.375:t)) (sub RNE caisar_v1 (0.0:t)) /\
le le (sub RNE caisar_v1 (0.0:t)) (0.375:t)) /\
(sub RNE caisar_t4
(0.78431372499999996161790249971090815961360931396484375:t))
(0.375:t)) /\
le (neg (0.375:t))
(sub RNE caisar_t3
(0.776470588000000017103729987866245210170745849609375:t)) /\
le
(sub RNE caisar_t3
(0.776470588000000017103729987866245210170745849609375:t))
(0.375:t)) /\
le (neg (0.375:t)) le (neg (0.375:t))
(sub RNE caisar_t2 (sub RNE caisar_v2
(0.01960784299999999980013143385804141871631145477294921875:t)) /\ (0.01960784299999999980013143385804141871631145477294921875:t)) /\
le le
(sub RNE caisar_t2 (sub RNE caisar_v2
(0.01960784299999999980013143385804141871631145477294921875:t)) (0.01960784299999999980013143385804141871631145477294921875:t))
(0.375:t)) /\ (0.375:t)) /\
le (neg (0.375:t)) (sub RNE caisar_t1 (0.0:t)) /\ le (neg (0.375:t))
le (sub RNE caisar_t1 (0.0:t)) (0.375:t)) /\ (sub RNE caisar_v3
le (neg (0.375:t)) (sub RNE caisar_t (1.0:t)) /\ (0.776470588000000017103729987866245210170745849609375:t)) /\
le (sub RNE caisar_t (1.0:t)) (0.375:t) -> le
(sub RNE caisar_v3
(0.776470588000000017103729987866245210170745849609375:t))
(0.375:t)) /\
le (neg (0.375:t))
(sub RNE caisar_v4
(0.78431372499999996161790249971090815961360931396484375:t)) /\
le
(sub RNE caisar_v4
(0.78431372499999996161790249971090815961360931396484375:t))
(0.375:t) ->
lt lt
(caisar_op (caisar_op
@@ caisar_op1 caisar_t caisar_t1 caisar_t2 caisar_t3 caisar_t4) %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
[1] [1]
(caisar_op (caisar_op
@@ caisar_op1 caisar_t caisar_t1 caisar_t2 caisar_t3 caisar_t4) %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
[0] /\ [0] /\
lt lt
(caisar_op (caisar_op
@@ caisar_op1 caisar_t caisar_t1 caisar_t2 caisar_t3 caisar_t4) %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
[2] [2]
(caisar_op (caisar_op
@@ caisar_op1 caisar_t caisar_t1 caisar_t2 caisar_t3 caisar_t4) %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
[0]) [0])
caisar_op, caisar_op2,
(Interpretation.Classifier
"$TESTCASE_ROOT/TestNetwork.nnet")
caisar_op1, (Interpretation.Tensor 5)
caisar_op2, (Interpretation.Dataset <csv>)
caisar_op3, (Interpretation.Index (Interpretation.I_csv 4))
caisar_op4, (Interpretation.Index (Interpretation.I_csv 3))
caisar_op5,
(Interpretation.Data (Interpretation.Data
(Interpretation.D_csv (Interpretation.D_csv
["1.0"; "0.0"; "0.019607843"; "0.776470588"; "0.784313725"])) ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"]))
caisar_op6, (Interpretation.Index (Interpretation.I_csv 2)) caisar_op3,
caisar_op7, (Interpretation.Index (Interpretation.I_csv 1))
caisar_op8, (Interpretation.Index (Interpretation.I_csv 0))
caisar_op9,
(Interpretation.Data (Interpretation.Data
(Interpretation.D_csv (Interpretation.D_csv
["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"])) ["1.0"; "0.0"; "0.019607843"; "0.776470588"; "0.784313725"]))
caisar_op1, (Interpretation.Vector 5)
caisar_op,
(Interpretation.Classifier
"$TESTCASE_ROOT/TestNetwork.nnet")
caisar_op4,
(Interpretation.Dataset <csv>)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment