From 4e558db3d07ffc6d2aec8b57463ec7145aac123d Mon Sep 17 00:00:00 2001 From: Michele Alberti <michele.alberti@cea.fr> Date: Wed, 5 Apr 2023 16:30:02 +0200 Subject: [PATCH] [interpretation] Add notion of classifier in language and use it in interpretation. --- src/interpretation.ml | 42 +++++++++++---- src/language.ml | 95 +++++++++++++++++++++++++++++++--- src/language.mli | 21 +++++++- tests/interpretation_acasxu.t | 16 +++--- tests/interpretation_dataset.t | 14 +++-- 5 files changed, 160 insertions(+), 28 deletions(-) diff --git a/src/interpretation.ml b/src/interpretation.ml index 238359d..d812197 100644 --- a/src/interpretation.ml +++ b/src/interpretation.ml @@ -24,16 +24,32 @@ module CRE = Reduction_engine (* Caisar Reduction Engine *) open Why3 open Base +type classifier = + | NNet of Language.nn_classifier + [@printer + fun fmt nn -> + Fmt.pf fmt "NNet: %a" + Fmt.(option Language.pp_nn) + (Language.lookup_nn_classifier nn)] + | ONNX of Language.nn_classifier + [@printer + fun fmt nn -> + Fmt.pf fmt "ONNX: %a" + Fmt.(option Language.pp_nn) + (Language.lookup_nn_classifier nn)] +[@@deriving show] + type dataset = DS_csv of Csv.t [@printer fun fmt _ -> Fmt.pf fmt "<csv>"] [@@deriving show] -type classifier = string [@@deriving show] type data = D_csv of string list [@@deriving show] type index = I_csv of int [@@deriving show] type vector = (Language.vector - [@printer fun fmt v -> Fmt.pf fmt "%d" (Language.lookup_vector v)]) + [@printer + fun fmt v -> + Fmt.pf fmt "%a" Fmt.(option ~none:nop int) (Language.lookup_vector v)]) [@@deriving show] type caisar_op = @@ -139,7 +155,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = in term (Term.t_tuple [ t_features; t_label ]) | Vector v -> - let n = Language.lookup_vector v in + let n = Option.value_exn (Language.lookup_vector v) in assert (List.length tl1 = n && i <= n); term (List.nth_exn tl1 i) | Data _ | Classifier _ | Tensor _ | Index _ -> assert false) @@ -157,13 +173,14 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = | [ Term { t_node = Tapp (ls, []); _ } ] -> ( match caisar_op_of_ls engine ls with | Dataset (DS_csv csv) -> int (BigInt.of_int (Csv.lines csv)) - | Vector v -> int (BigInt.of_int (Language.lookup_vector v)) + | Vector v -> + int (BigInt.of_int (Option.value_exn (Language.lookup_vector v))) | Data (D_csv data) -> int (BigInt.of_int (List.length data)) | Classifier _ | Tensor _ | Index _ -> assert false) | [ Term { t_node = Tapp (ls, tl); _ } ] -> ( match caisar_op_of_ls engine ls with | Vector v -> - let n = Language.lookup_vector v in + let n = Option.value_exn (Language.lookup_vector v) in assert (List.length tl = n); int (BigInt.of_int n) | Dataset _ | Data _ | Classifier _ | Tensor _ | Index _ -> assert false) @@ -185,7 +202,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *) match (caisar_op_of_ls engine ls1, caisar_op_of_ls engine ls2) with | Vector v, Data (D_csv data) -> - let n = Language.lookup_vector v in + let n = Option.value_exn (Language.lookup_vector v) in assert (n = List.length data); let ty_cst = match ty with @@ -229,7 +246,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *) match caisar_op_of_ls engine ls1 with | Vector v -> - let n = Language.lookup_vector v in + let n = Option.value_exn (Language.lookup_vector v) in assert (List.length tl1 = n); let args = List.mapi tl1 ~f:(fun idx t -> @@ -322,12 +339,17 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = match vl with | [ Term { t_node = Tconst (ConstStr classifier); _ }; - Term { t_node = Tapp ({ ls_name = { id_string = "NNet"; _ }; _ }, []); _ }; + Term { t_node = Tapp ({ ls_name = { id_string; _ }; _ }, []); _ }; ] -> - let cwd = (CRE.user_env engine).cwd in + let { env; cwd; _ } = CRE.user_env engine in let caisar_op = let filename = Caml.Filename.concat cwd classifier in - Classifier filename + let classifier = + if String.equal id_string "NNet" + then NNet (Language.create_nnet_classifier env filename) + else ONNX (Language.create_onnx_classifier env filename) + in + Classifier classifier in term (term_of_caisar_op engine caisar_op ty) | _ -> invalid_arg (error_message ls) diff --git a/src/language.ml b/src/language.ml index 8899097..3499f00 100644 --- a/src/language.ml +++ b/src/language.ml @@ -156,27 +156,110 @@ let register_ovo_support () = Env.register_format ~desc:"OVO format" Pmodule.mlw_language "OVO" [ "ovo" ] (fun env _ filename _ -> ovo_parser env filename) +(* -- Vector *) + type vector = Term.lsymbol let vectors = Term.Hls.create 10 +let vector_elt_ty env = + let th = Env.read_theory env [ "ieee_float" ] "Float64" in + Ty.ty_app (Theory.ns_find_ts th.th_export [ "t" ]) [] + let create_vector = Env.Wenv.memoize 13 (fun env -> let h = Hashtbl.create (module Int) in - let ty_elt = - let th = Env.read_theory env [ "ieee_float" ] "Float64" in - Ty.ty_app (Theory.ns_find_ts th.th_export [ "t" ]) [] - in + let ty_elt = vector_elt_ty env in let ty = let th = Env.read_theory env [ "interpretation" ] "Vector" in Ty.ty_app (Theory.ns_find_ts th.th_export [ "vector" ]) [ ty_elt ] in Hashtbl.findi_or_add h ~default:(fun length -> - let id = Ident.id_fresh "vector" in let ls = + let id = Ident.id_fresh "vector" in Term.create_fsymbol id (List.init length ~f:(fun _ -> ty_elt)) ty in Term.Hls.add vectors ls length; ls)) -let lookup_vector = Term.Hls.find vectors +let lookup_vector = Term.Hls.find_opt vectors + +(* -- Classifier *) + +type nn = { + nn_inputs : int; + nn_outputs : int; + nn_ty_elt : Ty.ty; [@printer fun fmt ty -> Fmt.pf fmt "%a" Pretty.print_ty ty] + nn_filename : string; + nn_nier : Onnx.G.t option; [@opaque] +} +[@@deriving show] + +type nn_classifier = Term.lsymbol + +let nn_classifiers = Term.Hls.create 10 + +let fresh_classifier_ls env name = + let ty = + let th = Env.read_theory env [ "interpretation" ] "Classifier" in + Ty.ty_app (Theory.ns_find_ts th.th_export [ "classifier" ]) [] + in + let id = Ident.id_fresh name in + Term.create_fsymbol id [] ty + +let create_nnet_classifier = + Env.Wenv.memoize 13 (fun env -> + let h = Hashtbl.create (module String) in + let ty_elt = + let th = Env.read_theory env [ "ieee_float" ] "Float64" in + Ty.ty_app (Theory.ns_find_ts th.th_export [ "t" ]) [] + in + Hashtbl.findi_or_add h ~default:(fun filename -> + let ls = fresh_classifier_ls env "nnet_classifier" in + let nn = + let model = Nnet.parse ~permissive:true filename in + match model with + | Error s -> Loc.errorm "%s" s + | Ok { n_inputs; n_outputs; _ } -> + { + nn_inputs = n_inputs; + nn_outputs = n_outputs; + nn_ty_elt = ty_elt; + nn_filename = filename; + nn_nier = None; + } + in + Term.Hls.add nn_classifiers ls nn; + ls)) + +let create_onnx_classifier = + Env.Wenv.memoize 13 (fun env -> + let h = Hashtbl.create (module String) in + let ty_elt = vector_elt_ty env in + Hashtbl.findi_or_add h ~default:(fun filename -> + let ls = fresh_classifier_ls env "onnx_classifier" in + let onnx = + let model = Onnx.parse filename in + match model with + | Error s -> Loc.errorm "%s" s + | Ok { n_inputs; n_outputs; nier } -> + let nier = + match nier with + | Error msg -> + Logs.warn (fun m -> + m "Cannot build network intermediate representation:@ %s" msg); + None + | Ok nier -> Some nier + in + { + nn_inputs = n_inputs; + nn_outputs = n_outputs; + nn_ty_elt = ty_elt; + nn_filename = filename; + nn_nier = nier; + } + in + Term.Hls.add nn_classifiers ls onnx; + ls)) + +let lookup_nn_classifier = Term.Hls.find_opt nn_classifiers diff --git a/src/language.mli b/src/language.mli index a361fd4..afa4a11 100644 --- a/src/language.mli +++ b/src/language.mli @@ -63,7 +63,26 @@ val ovo_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t (* [ovo_parser env filename] parses and creates the theories corresponding to the given ovo [filename]. The result is memoized. *) +(** -- Vector *) + type vector = Term.lsymbol val create_vector : Env.env -> int -> vector -val lookup_vector : vector -> int +val lookup_vector : vector -> int option + +(** -- Classifier *) + +type nn = private { + nn_inputs : int; + nn_outputs : int; + nn_ty_elt : Ty.ty; + nn_filename : string; + nn_nier : Onnx.G.t option; +} +[@@deriving show] + +type nn_classifier = Term.lsymbol + +val create_nnet_classifier : Env.env -> string -> nn_classifier +val create_onnx_classifier : Env.env -> string -> nn_classifier +val lookup_nn_classifier : nn_classifier -> nn option diff --git a/tests/interpretation_acasxu.t b/tests/interpretation_acasxu.t index 253cde0..9d1adea 100644 --- a/tests/interpretation_acasxu.t +++ b/tests/interpretation_acasxu.t @@ -130,11 +130,13 @@ Test interpret on acasxu [0] (373.9499200000000200816430151462554931640625:t)) (7.51888402010059753166615337249822914600372314453125:t)) (1500.0:t) + vector, (Interpretation.Vector 5) caisar_op, (Interpretation.Classifier - "$TESTCASE_ROOT/TestNetwork.nnet") - vector, - (Interpretation.Vector 5) + NNet: { Language.nn_inputs = 5; nn_outputs = 5; nn_ty_elt = t; + nn_filename = + "$TESTCASE_ROOT/TestNetwork.nnet"; + nn_nier = <opaque> }) P2 : forall caisar_v:t, caisar_v1:t, caisar_v2:t, caisar_v3:t, caisar_v4:t. (le (0.0:t) (add RNE (mul RNE caisar_v (60261.0:t)) (19791.0:t)) /\ le (add RNE (mul RNE caisar_v (60261.0:t)) (19791.0:t)) (60760.0:t) -> @@ -212,8 +214,10 @@ Test interpret on acasxu (caisar_op1 %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [4]) + vector, (Interpretation.Vector 5) caisar_op1, (Interpretation.Classifier - "$TESTCASE_ROOT/TestNetwork.nnet") - vector, - (Interpretation.Vector 5) + NNet: { Language.nn_inputs = 5; nn_outputs = 5; nn_ty_elt = t; + nn_filename = + "$TESTCASE_ROOT/TestNetwork.nnet"; + nn_nier = <opaque> }) diff --git a/tests/interpretation_dataset.t b/tests/interpretation_dataset.t index 2fd6c14..3ff6693 100644 --- a/tests/interpretation_dataset.t +++ b/tests/interpretation_dataset.t @@ -129,15 +129,19 @@ Test interpret on dataset (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0]) caisar_op1, + (Interpretation.Data + (Interpretation.D_csv + ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"])) + caisar_op2, (Interpretation.Data (Interpretation.D_csv ["1.0"; "0.0"; "0.019607843"; "0.776470588"; "0.784313725"])) vector, (Interpretation.Vector 5) caisar_op, (Interpretation.Classifier - "$TESTCASE_ROOT/TestNetwork.nnet") - caisar_op2, (Interpretation.Dataset <csv>) + NNet: { Language.nn_inputs = 5; nn_outputs = 5; nn_ty_elt = t; + nn_filename = + "$TESTCASE_ROOT/TestNetwork.nnet"; + nn_nier = <opaque> }) caisar_op3, - (Interpretation.Data - (Interpretation.D_csv - ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"])) + (Interpretation.Dataset <csv>) -- GitLab