diff --git a/src/interpretation.ml b/src/interpretation.ml index 69d488bd1257a89e6457d1256cdec0d60bb3a688..238359d40992c6637f25fe249fb9fe4fffb16145 100644 --- a/src/interpretation.ml +++ b/src/interpretation.ml @@ -157,7 +157,9 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = | [ Term { t_node = Tapp (ls, []); _ } ] -> ( match caisar_op_of_ls engine ls with | Dataset (DS_csv csv) -> int (BigInt.of_int (Csv.lines csv)) - | Data _ | Classifier _ | Tensor _ | Vector _ | Index _ -> assert false) + | Vector v -> int (BigInt.of_int (Language.lookup_vector v)) + | Data (D_csv data) -> int (BigInt.of_int (List.length data)) + | Classifier _ | Tensor _ | Index _ -> assert false) | [ Term { t_node = Tapp (ls, tl); _ } ] -> ( match caisar_op_of_ls engine ls with | Vector v -> @@ -410,29 +412,19 @@ let bounded_quant engine vs ~cond : CRE.bounded_quant_result option = | Data (D_csv d) -> List.length d | _ -> assert false in - let ty, caisar_op, id = + let ty = match vs.vs_ty with - | { ty_node = Tyapp ({ ts_name; _ }, ty :: _); _ } -> - let caisar_op, id = - if String.equal ts_name.id_string "vector" - then - let { env; _ } = CRE.user_env engine in - (Vector (Language.create_vector env n), "caisar_v") - else if String.equal ts_name.id_string "tensor" - then (Tensor n, "caisar_t") - else assert false - in - (ty, caisar_op, id) + | { ty_node = Tyapp (_, ty :: _); _ } -> ty | _ -> assert false in let new_quant = List.init n ~f:(fun _ -> - let preid = Ident.id_fresh id in + let preid = Ident.id_fresh "caisar_t" in Term.create_vsymbol preid ty) in let args = List.map new_quant ~f:(fun vs -> (Term.t_var vs, ty)) in let substitutions = - [ term_of_caisar_op ~args engine caisar_op (Some vs.vs_ty) ] + [ term_of_caisar_op ~args engine (Tensor n) (Some vs.vs_ty) ] in Some { new_quant; substitutions } | Tapp diff --git a/stdlib/interpretation.mlw b/stdlib/interpretation.mlw index a6b53eff4d6e41d22667f01368ed1800b5a90d52..ab3b2779686f518d50f8e11ab549bc9d6f8dafec 100644 --- a/stdlib/interpretation.mlw +++ b/stdlib/interpretation.mlw @@ -31,7 +31,6 @@ theory Vector function (-) (v1: vector 'a) (v2: vector 'a) : vector 'a predicate has_length (v: vector 'a) (i: int) - predicate equal_shape (v1: vector 'a) (v2: vector 'b) predicate valid_index (v: vector 'a) (i: index) = 0 <= i < length v scope L diff --git a/tests/interpretation_dataset.t b/tests/interpretation_dataset.t index 4336ff254541537047f71404c605bc1c8a7e681e..2fd6c14138e287a38590cab657fc2c46ae4d6da1 100644 --- a/tests/interpretation_dataset.t +++ b/tests/interpretation_dataset.t @@ -30,7 +30,7 @@ Test interpret on dataset > > predicate robust_around (c: classifier) (eps: t) (i: image) (l: label_) = > forall perturbed_image: image. - > equal_shape i perturbed_image -> + > has_length perturbed_image (length i) -> > valid_image perturbed_image -> > let perturbation = perturbed_image - i in > bounded_by_epsilon perturbation eps -> @@ -129,16 +129,15 @@ Test interpret on dataset (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0]) caisar_op1, - (Interpretation.Data - (Interpretation.D_csv - ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"])) - vector, (Interpretation.Vector 5) - caisar_op2, (Interpretation.Data (Interpretation.D_csv ["1.0"; "0.0"; "0.019607843"; "0.776470588"; "0.784313725"])) + vector, (Interpretation.Vector 5) caisar_op, (Interpretation.Classifier "$TESTCASE_ROOT/TestNetwork.nnet") + caisar_op2, (Interpretation.Dataset <csv>) caisar_op3, - (Interpretation.Dataset <csv>) + (Interpretation.Data + (Interpretation.D_csv + ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"]))