diff --git a/src/interpretation.ml b/src/interpretation.ml index 21e6736f1739f1453a319d81863ffc1bfb10cfe8..69d488bd1257a89e6457d1256cdec0d60bb3a688 100644 --- a/src/interpretation.ml +++ b/src/interpretation.ml @@ -31,12 +31,17 @@ type classifier = string [@@deriving show] type data = D_csv of string list [@@deriving show] type index = I_csv of int [@@deriving show] +type vector = + (Language.vector + [@printer fun fmt v -> Fmt.pf fmt "%d" (Language.lookup_vector v)]) +[@@deriving show] + type caisar_op = | Classifier of classifier | Dataset of dataset | Data of data | Index of index - | Vector of int + | Vector of vector | Tensor of int [@@deriving show] @@ -53,7 +58,9 @@ let ls_of_caisar_op engine op ty_args ty = (* Option.iter ty ~f:(Fmt.pr "ty: %a@." Pretty.print_ty); *) Hashtbl.find_or_add caisar_env.ls_of_caisar_op op ~default:(fun () -> let id = Ident.id_fresh "caisar_op" in - let ls = Term.create_lsymbol id ty_args ty in + let ls = + match op with Vector v -> v | _ -> Term.create_lsymbol id ty_args ty + in (* Fmt.pr "ls: %a@." Pretty.print_ls ls; *) Hashtbl.Poly.add_exn caisar_env.ls_of_caisar_op ~key:op ~data:ls; Term.Hls.add caisar_env.caisar_op_of_ls ls op; @@ -131,7 +138,8 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = Term.t_int_const (BigInt.of_int (Int.of_string label)) ) in term (Term.t_tuple [ t_features; t_label ]) - | Vector n -> + | Vector v -> + let n = Language.lookup_vector v in assert (List.length tl1 = n && i <= n); term (List.nth_exn tl1 i) | Data _ | Classifier _ | Tensor _ | Index _ -> assert false) @@ -152,7 +160,8 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = | Data _ | Classifier _ | Tensor _ | Vector _ | Index _ -> assert false) | [ Term { t_node = Tapp (ls, tl); _ } ] -> ( match caisar_op_of_ls engine ls with - | Vector n -> + | Vector v -> + let n = Language.lookup_vector v in assert (List.length tl = n); int (BigInt.of_int n) | Dataset _ | Data _ | Classifier _ | Tensor _ | Index _ -> assert false) @@ -173,7 +182,8 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = ] -> ( (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *) match (caisar_op_of_ls engine ls1, caisar_op_of_ls engine ls2) with - | Vector n, Data (D_csv data) -> + | Vector v, Data (D_csv data) -> + let n = Language.lookup_vector v in assert (n = List.length data); let ty_cst = match ty with @@ -185,17 +195,18 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = let cst = const_real_of_float (Float.of_string d) in Term.t_const cst ty_cst) in - let minus = - (* TODO: generalize wrt the type of constants [csts]. *) - let { env; _ } = CRE.user_env engine in - let th = Env.read_theory env [ "ieee_float" ] "Float64" in - Theory.(ns_find_ls th.th_export [ Ident.op_infix ".-" ]) - in + let { env; _ } = CRE.user_env engine in let args = + let minus = + (* TODO: generalize wrt the type of constants [csts]. *) + let th = Env.read_theory env [ "ieee_float" ] "Float64" in + Theory.(ns_find_ls th.th_export [ Ident.op_infix ".-" ]) + in List.map2_exn tl1 csts ~f:(fun tl c -> (Term.t_app_infer minus [ tl; c ], ty_cst)) in - term (term_of_caisar_op ~args engine (Vector n) ty) + let caisar_op = Vector (Language.create_vector env n) in + term (term_of_caisar_op ~args engine caisar_op ty) | _ -> assert false) | [ Term t1; Term t2 ] -> (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *) @@ -215,14 +226,19 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = assert (Term.t_is_lambda t2); (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *) match caisar_op_of_ls engine ls1 with - | Vector n -> + | Vector v -> + let n = Language.lookup_vector v in assert (List.length tl1 = n); let args = List.mapi tl1 ~f:(fun idx t -> let idx = Term.t_int_const (BigInt.of_int idx) in (Term.t_func_app_beta_l t2 [ idx; t ], Option.value_exn t.t_ty)) in - Eval (term_of_caisar_op ~args engine (Vector n) ty) + let caisar_op = + let { env; _ } = CRE.user_env engine in + Vector (Language.create_vector env n) + in + Eval (term_of_caisar_op ~args engine caisar_op ty) | Dataset (DS_csv csv) -> int (BigInt.of_int (Csv.lines csv)) | Data _ | Classifier _ | Tensor _ | Index _ -> assert false) | [ Term t1; Term t2 ] -> @@ -399,7 +415,9 @@ let bounded_quant engine vs ~cond : CRE.bounded_quant_result option = | { ty_node = Tyapp ({ ts_name; _ }, ty :: _); _ } -> let caisar_op, id = if String.equal ts_name.id_string "vector" - then (Vector n, "caisar_v") + then + let { env; _ } = CRE.user_env engine in + (Vector (Language.create_vector env n), "caisar_v") else if String.equal ts_name.id_string "tensor" then (Tensor n, "caisar_t") else assert false @@ -440,8 +458,12 @@ let bounded_quant engine vs ~cond : CRE.bounded_quant_result option = Term.create_vsymbol preid ty) in let args = List.map new_quant ~f:(fun vs -> (Term.t_var vs, ty)) in + let caisar_op = + let { env; _ } = CRE.user_env engine in + Vector (Language.create_vector env n) + in let substitutions = - [ term_of_caisar_op ~args engine (Vector n) (Some vs.vs_ty) ] + [ term_of_caisar_op ~args engine caisar_op (Some vs.vs_ty) ] in Some { new_quant; substitutions } | Tapp diff --git a/src/language.ml b/src/language.ml index 88d81e50d98bdeeac4cb7c73a2a16dcde055753e..88990976f94e728f8db55752ab310cdc4c06b680 100644 --- a/src/language.ml +++ b/src/language.ml @@ -155,3 +155,28 @@ let register_onnx_support () = let register_ovo_support () = Env.register_format ~desc:"OVO format" Pmodule.mlw_language "OVO" [ "ovo" ] (fun env _ filename _ -> ovo_parser env filename) + +type vector = Term.lsymbol + +let vectors = Term.Hls.create 10 + +let create_vector = + Env.Wenv.memoize 13 (fun env -> + let h = Hashtbl.create (module Int) in + let ty_elt = + let th = Env.read_theory env [ "ieee_float" ] "Float64" in + Ty.ty_app (Theory.ns_find_ts th.th_export [ "t" ]) [] + in + let ty = + let th = Env.read_theory env [ "interpretation" ] "Vector" in + Ty.ty_app (Theory.ns_find_ts th.th_export [ "vector" ]) [ ty_elt ] + in + Hashtbl.findi_or_add h ~default:(fun length -> + let id = Ident.id_fresh "vector" in + let ls = + Term.create_fsymbol id (List.init length ~f:(fun _ -> ty_elt)) ty + in + Term.Hls.add vectors ls length; + ls)) + +let lookup_vector = Term.Hls.find vectors diff --git a/src/language.mli b/src/language.mli index eb025a8a9322cfe5d50dc6dbf8663a92956fac8a..a361fd462b39bafded58cd15101cb90d9856db1e 100644 --- a/src/language.mli +++ b/src/language.mli @@ -62,3 +62,8 @@ val onnx_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t val ovo_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t (* [ovo_parser env filename] parses and creates the theories corresponding to the given ovo [filename]. The result is memoized. *) + +type vector = Term.lsymbol + +val create_vector : Env.env -> int -> vector +val lookup_vector : vector -> int diff --git a/tests/interpretation_acasxu.t b/tests/interpretation_acasxu.t index 4439abbc84b1aa3403d9ef2778c56446fb26e159..253cde00d894e81124efa1c57577d404c62c1993 100644 --- a/tests/interpretation_acasxu.t +++ b/tests/interpretation_acasxu.t @@ -126,15 +126,14 @@ Test interpret on acasxu le (add RNE (mul RNE - (caisar_op - %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] (373.9499200000000200816430151462554931640625:t)) (7.51888402010059753166615337249822914600372314453125:t)) (1500.0:t) caisar_op, (Interpretation.Classifier "$TESTCASE_ROOT/TestNetwork.nnet") - caisar_op1, + vector, (Interpretation.Vector 5) P2 : forall caisar_v:t, caisar_v1:t, caisar_v2:t, caisar_v3:t, caisar_v4:t. (le (0.0:t) (add RNE (mul RNE caisar_v (60261.0:t)) (19791.0:t)) /\ @@ -186,35 +185,35 @@ Test interpret on acasxu le (900.0:t) (add RNE (mul RNE caisar_v3 (1100.0:t)) (650.0:t)) -> le (960.0:t) (add RNE (mul RNE caisar_v4 (1200.0:t)) (600.0:t))) -> not (((lt - (caisar_op2 - %% caisar_op3 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op1 + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] - (caisar_op2 - %% caisar_op3 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op1 + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [1] /\ lt - (caisar_op2 - %% caisar_op3 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op1 + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] - (caisar_op2 - %% caisar_op3 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op1 + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [2]) /\ lt - (caisar_op2 - %% caisar_op3 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op1 + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] - (caisar_op2 - %% caisar_op3 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op1 + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [3]) /\ lt - (caisar_op2 - %% caisar_op3 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op1 + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] - (caisar_op2 - %% caisar_op3 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op1 + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [4]) - caisar_op2, + caisar_op1, (Interpretation.Classifier "$TESTCASE_ROOT/TestNetwork.nnet") - caisar_op3, + vector, (Interpretation.Vector 5) diff --git a/tests/interpretation_dataset.t b/tests/interpretation_dataset.t index 75b13c7c4e8ccbc2125bd8b7c99b4ae559454c6d..4336ff254541537047f71404c605bc1c8a7e681e 100644 --- a/tests/interpretation_dataset.t +++ b/tests/interpretation_dataset.t @@ -78,18 +78,14 @@ Test interpret on dataset (0.776470588000000017103729987866245210170745849609375:t)) (0.375:t) -> lt - (caisar_op - %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] - (caisar_op - %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [1] /\ lt - (caisar_op - %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [2] - (caisar_op - %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [1]) /\ (forall caisar_v:t, caisar_v1:t, caisar_v2:t, caisar_v3:t, caisar_v4:t. ((((le (0.0:t) caisar_v /\ le caisar_v (1.0:t)) /\ @@ -123,30 +119,26 @@ Test interpret on dataset (0.78431372499999996161790249971090815961360931396484375:t)) (0.375:t) -> lt - (caisar_op - %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [1] - (caisar_op - %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] /\ lt - (caisar_op - %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [2] - (caisar_op - %% caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0]) - caisar_op2, + caisar_op1, (Interpretation.Data (Interpretation.D_csv ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"])) - caisar_op3, + vector, (Interpretation.Vector 5) + caisar_op2, (Interpretation.Data (Interpretation.D_csv ["1.0"; "0.0"; "0.019607843"; "0.776470588"; "0.784313725"])) - caisar_op1, (Interpretation.Vector 5) caisar_op, (Interpretation.Classifier "$TESTCASE_ROOT/TestNetwork.nnet") - caisar_op4, + caisar_op3, (Interpretation.Dataset <csv>)