From 8b31794cf197dcc8faf9fa959cbd899adfdebadb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr> Date: Wed, 9 Nov 2022 13:28:40 +0100 Subject: [PATCH] [language] Export nnet_parser for standalone use. --- src/language.ml | 21 ++++++++++++--------- src/language.mli | 4 ++++ src/verification.ml | 5 ++++- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/language.ml b/src/language.ml index b89e65b..5814db4 100644 --- a/src/language.ml +++ b/src/language.ml @@ -102,14 +102,17 @@ let register_svm_as_array env nb_inputs nb_classes filename mstr = in Wstdlib.Mstr.add name (Pmodule.close_module th_uc) mstr -let nnet_parser env _ filename _ = - let model = Nnet.parse ~permissive:true filename in - match model with - | Error s -> Loc.errorm "%s" s - | Ok { n_inputs; n_outputs; _ } -> - Wstdlib.Mstr.empty - |> register_nn_as_tuple env n_inputs n_outputs filename - |> register_nn_as_array env n_inputs n_outputs filename +let nnet_parser = + Env.Wenv.memoize 13 (fun env -> + let h = Hashtbl.create (module String) in + Hashtbl.findi_or_add h ~default:(fun filename -> + let model = Nnet.parse ~permissive:true filename in + match model with + | Error s -> Loc.errorm "%s" s + | Ok { n_inputs; n_outputs; _ } -> + Wstdlib.Mstr.empty + |> register_nn_as_tuple env n_inputs n_outputs filename + |> register_nn_as_array env n_inputs n_outputs filename)) let onnx_parser env _ filename _ = let model = Onnx.parse filename in @@ -137,7 +140,7 @@ let ovo_parser env _ filename _ = let register_nnet_support () = Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language - "NNet" [ "nnet" ] nnet_parser + "NNet" [ "nnet" ] (fun env _ filename _ -> nnet_parser env filename) let register_onnx_support () = Env.register_format ~desc:"ONNX format" Pmodule.mlw_language "ONNX" [ "onnx" ] diff --git a/src/language.mli b/src/language.mli index 39453b4..f46f188 100644 --- a/src/language.mli +++ b/src/language.mli @@ -50,3 +50,7 @@ val register_onnx_support : unit -> unit val register_ovo_support : unit -> unit (** Register OVO parser. *) + +val nnet_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t +(* [nnet_parser env filename] parses and creates the theories corresponding to + the given nnet [filename]. The result is memoized. *) diff --git a/src/verification.ml b/src/verification.ml index eeef52d..cbe3b35 100644 --- a/src/verification.ml +++ b/src/verification.ml @@ -191,11 +191,14 @@ let call_prover ~limit config env prover config_prover driver dataset_csv task = Fmt.(option ~none:nop (any " " ++ string)) additional_info) -let mstr_theory_of_json_config env _json_config = +let mstr_theory_of_json_config env (json_config : File.json_config) = + let pmod = Language.nnet_parser env json_config.model in + let th_pmod = (Wstdlib.Mstr.find "NNasArray" pmod).mod_theory in let name = "T" in let th_uc = Theory.create_theory (Ident.id_fresh name) in let th_real = Env.read_theory env [ "real" ] "Real" in let th_uc = Theory.use_export th_uc th_real in + let th_uc = Theory.use_export th_uc th_pmod in Wstdlib.Mstr.singleton name (Theory.close_theory th_uc) let verify ?(debug = false) ?format ~loadpath ?memlimit ?timelimit prover -- GitLab