diff --git a/src/interpretation.ml b/src/interpretation.ml index df78bcaa27fc48de7a3ad041626107ace606bdb7..21e6736f1739f1453a319d81863ffc1bfb10cfe8 100644 --- a/src/interpretation.ml +++ b/src/interpretation.ml @@ -150,11 +150,58 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = match caisar_op_of_ls engine ls with | Dataset (DS_csv csv) -> int (BigInt.of_int (Csv.lines csv)) | Data _ | Classifier _ | Tensor _ | Vector _ | Index _ -> assert false) + | [ Term { t_node = Tapp (ls, tl); _ } ] -> ( + match caisar_op_of_ls engine ls with + | Vector n -> + assert (List.length tl = n); + int (BigInt.of_int n) + | Dataset _ | Data _ | Classifier _ | Tensor _ | Index _ -> assert false) | [ Term t ] -> (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *) term (Term.t_app_infer ls [ t ]) | _ -> invalid_arg (error_message ls) in + let vminus : _ CRE.builtin = + fun engine ls vl ty -> + (* Fmt.pr "--@.vminus: ls:%a , ty:%a@." Pretty.print_ls ls *) + (* Fmt.(option ~none:nop Pretty.print_ty) *) + (* ty; *) + match vl with + | [ + Term ({ t_node = Tapp (ls1, tl1); _ } as _t1); + Term ({ t_node = Tapp (ls2, _); _ } as _t2); + ] -> ( + (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *) + match (caisar_op_of_ls engine ls1, caisar_op_of_ls engine ls2) with + | Vector n, Data (D_csv data) -> + assert (n = List.length data); + let ty_cst = + match ty with + | Some { ty_node = Tyapp (_, [ ty ]); _ } -> ty + | _ -> assert false + in + let csts = + List.map data ~f:(fun d -> + let cst = const_real_of_float (Float.of_string d) in + Term.t_const cst ty_cst) + in + let minus = + (* TODO: generalize wrt the type of constants [csts]. *) + let { env; _ } = CRE.user_env engine in + let th = Env.read_theory env [ "ieee_float" ] "Float64" in + Theory.(ns_find_ls th.th_export [ Ident.op_infix ".-" ]) + in + let args = + List.map2_exn tl1 csts ~f:(fun tl c -> + (Term.t_app_infer minus [ tl; c ], ty_cst)) + in + term (term_of_caisar_op ~args engine (Vector n) ty) + | _ -> assert false) + | [ Term t1; Term t2 ] -> + (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *) + term (Term.t_app_infer ls [ t1; t2 ]) + | _ -> invalid_arg (error_message ls) + in let mapi : _ CRE.builtin = fun engine ls vl ty -> (* Fmt.pr "--@.mapi: ls:%a , ty:%a@." Pretty.print_ls ls *) @@ -305,6 +352,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = [], [ ([ Ident.op_get "" ] (* ([]) *), None, vget); + ([ Ident.op_infix "-" ], None, vminus); ([ "length" ], None, length); ([ "L"; "mapi" ], None, mapi); ] ); @@ -346,19 +394,27 @@ let bounded_quant engine vs ~cond : CRE.bounded_quant_result option = | Data (D_csv d) -> List.length d | _ -> assert false in - let ty = + let ty, caisar_op, id = match vs.vs_ty with - | { ty_node = Tyapp (_, ty :: _); _ } -> ty + | { ty_node = Tyapp ({ ts_name; _ }, ty :: _); _ } -> + let caisar_op, id = + if String.equal ts_name.id_string "vector" + then (Vector n, "caisar_v") + else if String.equal ts_name.id_string "tensor" + then (Tensor n, "caisar_t") + else assert false + in + (ty, caisar_op, id) | _ -> assert false in let new_quant = List.init n ~f:(fun _ -> - let preid = Ident.id_fresh "caisar_t" in + let preid = Ident.id_fresh id in Term.create_vsymbol preid ty) in let args = List.map new_quant ~f:(fun vs -> (Term.t_var vs, ty)) in let substitutions = - [ term_of_caisar_op ~args engine (Tensor n) (Some vs.vs_ty) ] + [ term_of_caisar_op ~args engine caisar_op (Some vs.vs_ty) ] in Some { new_quant; substitutions } | Tapp diff --git a/stdlib/interpretation.mlw b/stdlib/interpretation.mlw index f22eec4063ad5a1e9f9d53c152300dc85f374890..a6b53eff4d6e41d22667f01368ed1800b5a90d52 100644 --- a/stdlib/interpretation.mlw +++ b/stdlib/interpretation.mlw @@ -24,22 +24,26 @@ theory Vector use int.Int type vector 'a + type index = int - function ([]) (v: vector 'a) (i: int) : 'a + function ([]) (v: vector 'a) (i: index) : 'a function length (v: vector 'a) : int + function (-) (v1: vector 'a) (v2: vector 'a) : vector 'a predicate has_length (v: vector 'a) (i: int) + predicate equal_shape (v1: vector 'a) (v2: vector 'b) + predicate valid_index (v: vector 'a) (i: index) = 0 <= i < length v scope L - function mapi (v: vector 'a) (f: int -> 'a -> 'b) : vector 'b + function mapi (v: vector 'a) (f: index -> 'a -> 'b) : vector 'b function map (v: vector 'a) (f: 'a -> 'b) : vector 'b function map2 (v1: vector 'a) (v2: vector 'b) (f: 'a -> 'b -> 'c) : vector 'c predicate forall_ (v: vector 'a) (f: 'a -> bool) = - forall i: int. 0 <= i < length v -> f v[i] + forall i: index. valid_index v i -> f v[i] predicate forall2 (v1: vector 'a) (v2: vector 'b) (f: 'a -> 'b -> bool) = - length(v1) = length(v2) -> forall i: int. 0 <= i < length v1 -> f v1[i] v2[i] + length(v1) = length(v2) -> forall i: index. valid_index v1 i -> f v1[i] v2[i] function foreach (v: vector 'a) (f: 'a -> 'b) : vector 'b = map v f diff --git a/tests/interpretation_dataset.t b/tests/interpretation_dataset.t index 5303d3ba8087aded7f1d83d36a10e920d707c955..75b13c7c4e8ccbc2125bd8b7c99b4ae559454c6d 100644 --- a/tests/interpretation_dataset.t +++ b/tests/interpretation_dataset.t @@ -10,24 +10,23 @@ Test interpret on dataset > use bool.Bool > use int.Int > use interpretation.Vector - > use interpretation.Tensor > use interpretation.Classifier > use interpretation.Dataset > - > type image = tensor t + > type image = vector t > type label_ = int > > predicate valid_image (i: image) = - > forall v: index. valid_index i v -> (0.0: t) .<= i#v .<= (1.0: t) + > forall v: index. valid_index i v -> (0.0: t) .<= i[v] .<= (1.0: t) > > predicate valid_label (l: label_) = 0 <= l <= 2 > > predicate advises (c: classifier) (i: image) (l: label_) = > valid_label l -> - > forall j: int. valid_label j -> j <> l -> (c@@i)[l] .> (c@@i)[j] + > forall j: int. valid_label j -> j <> l -> (c%%i)[l] .> (c%%i)[j] > > predicate bounded_by_epsilon (i: image) (eps: t) = - > forall v: index. valid_index i v -> .- eps .<= i#v .<= eps + > forall v: index. valid_index i v -> .- eps .<= i[v] .<= eps > > predicate robust_around (c: classifier) (eps: t) (i: image) (l: label_) = > forall perturbed_image: image. @@ -47,111 +46,107 @@ Test interpret on dataset > robust classifier dataset eps > end > EOF - G : (forall caisar_t:t, caisar_t1:t, caisar_t2:t, caisar_t3:t, caisar_t4:t. - ((((le (0.0:t) caisar_t4 /\ le caisar_t4 (1.0:t)) /\ - le (0.0:t) caisar_t3 /\ le caisar_t3 (1.0:t)) /\ - le (0.0:t) caisar_t2 /\ le caisar_t2 (1.0:t)) /\ - le (0.0:t) caisar_t1 /\ le caisar_t1 (1.0:t)) /\ - le (0.0:t) caisar_t /\ le caisar_t (1.0:t) -> - ((((le (neg (0.375:t)) - (sub RNE caisar_t4 - (0.776470588000000017103729987866245210170745849609375:t)) /\ - le - (sub RNE caisar_t4 - (0.776470588000000017103729987866245210170745849609375:t)) - (0.375:t)) /\ - le (neg (0.375:t)) - (sub RNE caisar_t3 - (0.01960784299999999980013143385804141871631145477294921875:t)) /\ - le - (sub RNE caisar_t3 - (0.01960784299999999980013143385804141871631145477294921875:t)) - (0.375:t)) /\ + G : (forall caisar_v:t, caisar_v1:t, caisar_v2:t, caisar_v3:t, caisar_v4:t. + ((((le (0.0:t) caisar_v /\ le caisar_v (1.0:t)) /\ + le (0.0:t) caisar_v1 /\ le caisar_v1 (1.0:t)) /\ + le (0.0:t) caisar_v2 /\ le caisar_v2 (1.0:t)) /\ + le (0.0:t) caisar_v3 /\ le caisar_v3 (1.0:t)) /\ + le (0.0:t) caisar_v4 /\ le caisar_v4 (1.0:t) -> + ((((le (neg (0.375:t)) (sub RNE caisar_v (0.0:t)) /\ + le (sub RNE caisar_v (0.0:t)) (0.375:t)) /\ + le (neg (0.375:t)) (sub RNE caisar_v1 (1.0:t)) /\ + le (sub RNE caisar_v1 (1.0:t)) (0.375:t)) /\ le (neg (0.375:t)) - (sub RNE caisar_t2 + (sub RNE caisar_v2 (0.78431372499999996161790249971090815961360931396484375:t)) /\ le - (sub RNE caisar_t2 + (sub RNE caisar_v2 (0.78431372499999996161790249971090815961360931396484375:t)) (0.375:t)) /\ - le (neg (0.375:t)) (sub RNE caisar_t1 (1.0:t)) /\ - le (sub RNE caisar_t1 (1.0:t)) (0.375:t)) /\ - le (neg (0.375:t)) (sub RNE caisar_t (0.0:t)) /\ - le (sub RNE caisar_t (0.0:t)) (0.375:t) -> + le (neg (0.375:t)) + (sub RNE caisar_v3 + (0.01960784299999999980013143385804141871631145477294921875:t)) /\ + le + (sub RNE caisar_v3 + (0.01960784299999999980013143385804141871631145477294921875:t)) + (0.375:t)) /\ + le (neg (0.375:t)) + (sub RNE caisar_v4 + (0.776470588000000017103729987866245210170745849609375:t)) /\ + le + (sub RNE caisar_v4 + (0.776470588000000017103729987866245210170745849609375:t)) + (0.375:t) -> lt (caisar_op - @@ caisar_op1 caisar_t caisar_t1 caisar_t2 caisar_t3 caisar_t4) + %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] (caisar_op - @@ caisar_op1 caisar_t caisar_t1 caisar_t2 caisar_t3 caisar_t4) + %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [1] /\ lt (caisar_op - @@ caisar_op1 caisar_t caisar_t1 caisar_t2 caisar_t3 caisar_t4) + %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [2] (caisar_op - @@ caisar_op1 caisar_t caisar_t1 caisar_t2 caisar_t3 caisar_t4) + %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [1]) /\ - (forall caisar_t:t, caisar_t1:t, caisar_t2:t, caisar_t3:t, caisar_t4:t. - ((((le (0.0:t) caisar_t4 /\ le caisar_t4 (1.0:t)) /\ - le (0.0:t) caisar_t3 /\ le caisar_t3 (1.0:t)) /\ - le (0.0:t) caisar_t2 /\ le caisar_t2 (1.0:t)) /\ - le (0.0:t) caisar_t1 /\ le caisar_t1 (1.0:t)) /\ - le (0.0:t) caisar_t /\ le caisar_t (1.0:t) -> - ((((le (neg (0.375:t)) - (sub RNE caisar_t4 - (0.78431372499999996161790249971090815961360931396484375:t)) /\ - le - (sub RNE caisar_t4 - (0.78431372499999996161790249971090815961360931396484375:t)) - (0.375:t)) /\ - le (neg (0.375:t)) - (sub RNE caisar_t3 - (0.776470588000000017103729987866245210170745849609375:t)) /\ - le - (sub RNE caisar_t3 - (0.776470588000000017103729987866245210170745849609375:t)) - (0.375:t)) /\ + (forall caisar_v:t, caisar_v1:t, caisar_v2:t, caisar_v3:t, caisar_v4:t. + ((((le (0.0:t) caisar_v /\ le caisar_v (1.0:t)) /\ + le (0.0:t) caisar_v1 /\ le caisar_v1 (1.0:t)) /\ + le (0.0:t) caisar_v2 /\ le caisar_v2 (1.0:t)) /\ + le (0.0:t) caisar_v3 /\ le caisar_v3 (1.0:t)) /\ + le (0.0:t) caisar_v4 /\ le caisar_v4 (1.0:t) -> + ((((le (neg (0.375:t)) (sub RNE caisar_v (1.0:t)) /\ + le (sub RNE caisar_v (1.0:t)) (0.375:t)) /\ + le (neg (0.375:t)) (sub RNE caisar_v1 (0.0:t)) /\ + le (sub RNE caisar_v1 (0.0:t)) (0.375:t)) /\ le (neg (0.375:t)) - (sub RNE caisar_t2 + (sub RNE caisar_v2 (0.01960784299999999980013143385804141871631145477294921875:t)) /\ le - (sub RNE caisar_t2 + (sub RNE caisar_v2 (0.01960784299999999980013143385804141871631145477294921875:t)) (0.375:t)) /\ - le (neg (0.375:t)) (sub RNE caisar_t1 (0.0:t)) /\ - le (sub RNE caisar_t1 (0.0:t)) (0.375:t)) /\ - le (neg (0.375:t)) (sub RNE caisar_t (1.0:t)) /\ - le (sub RNE caisar_t (1.0:t)) (0.375:t) -> + le (neg (0.375:t)) + (sub RNE caisar_v3 + (0.776470588000000017103729987866245210170745849609375:t)) /\ + le + (sub RNE caisar_v3 + (0.776470588000000017103729987866245210170745849609375:t)) + (0.375:t)) /\ + le (neg (0.375:t)) + (sub RNE caisar_v4 + (0.78431372499999996161790249971090815961360931396484375:t)) /\ + le + (sub RNE caisar_v4 + (0.78431372499999996161790249971090815961360931396484375:t)) + (0.375:t) -> lt (caisar_op - @@ caisar_op1 caisar_t caisar_t1 caisar_t2 caisar_t3 caisar_t4) + %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [1] (caisar_op - @@ caisar_op1 caisar_t caisar_t1 caisar_t2 caisar_t3 caisar_t4) + %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] /\ lt (caisar_op - @@ caisar_op1 caisar_t caisar_t1 caisar_t2 caisar_t3 caisar_t4) + %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [2] (caisar_op - @@ caisar_op1 caisar_t caisar_t1 caisar_t2 caisar_t3 caisar_t4) + %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0]) - caisar_op, - (Interpretation.Classifier - "$TESTCASE_ROOT/TestNetwork.nnet") - caisar_op1, (Interpretation.Tensor 5) - caisar_op2, (Interpretation.Dataset <csv>) - caisar_op3, (Interpretation.Index (Interpretation.I_csv 4)) - caisar_op4, (Interpretation.Index (Interpretation.I_csv 3)) - caisar_op5, + caisar_op2, (Interpretation.Data (Interpretation.D_csv - ["1.0"; "0.0"; "0.019607843"; "0.776470588"; "0.784313725"])) - caisar_op6, (Interpretation.Index (Interpretation.I_csv 2)) - caisar_op7, (Interpretation.Index (Interpretation.I_csv 1)) - caisar_op8, (Interpretation.Index (Interpretation.I_csv 0)) - caisar_op9, + ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"])) + caisar_op3, (Interpretation.Data (Interpretation.D_csv - ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"])) + ["1.0"; "0.0"; "0.019607843"; "0.776470588"; "0.784313725"])) + caisar_op1, (Interpretation.Vector 5) + caisar_op, + (Interpretation.Classifier + "$TESTCASE_ROOT/TestNetwork.nnet") + caisar_op4, + (Interpretation.Dataset <csv>)