From 8b31794cf197dcc8faf9fa959cbd899adfdebadb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr>
Date: Wed, 9 Nov 2022 13:28:40 +0100
Subject: [PATCH] [language] Export nnet_parser for standalone use.

---
 src/language.ml     | 21 ++++++++++++---------
 src/language.mli    |  4 ++++
 src/verification.ml |  5 ++++-
 3 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/src/language.ml b/src/language.ml
index b89e65b..5814db4 100644
--- a/src/language.ml
+++ b/src/language.ml
@@ -102,14 +102,17 @@ let register_svm_as_array env nb_inputs nb_classes filename mstr =
   in
   Wstdlib.Mstr.add name (Pmodule.close_module th_uc) mstr
 
-let nnet_parser env _ filename _ =
-  let model = Nnet.parse ~permissive:true filename in
-  match model with
-  | Error s -> Loc.errorm "%s" s
-  | Ok { n_inputs; n_outputs; _ } ->
-    Wstdlib.Mstr.empty
-    |> register_nn_as_tuple env n_inputs n_outputs filename
-    |> register_nn_as_array env n_inputs n_outputs filename
+let nnet_parser =
+  Env.Wenv.memoize 13 (fun env ->
+    let h = Hashtbl.create (module String) in
+    Hashtbl.findi_or_add h ~default:(fun filename ->
+      let model = Nnet.parse ~permissive:true filename in
+      match model with
+      | Error s -> Loc.errorm "%s" s
+      | Ok { n_inputs; n_outputs; _ } ->
+        Wstdlib.Mstr.empty
+        |> register_nn_as_tuple env n_inputs n_outputs filename
+        |> register_nn_as_array env n_inputs n_outputs filename))
 
 let onnx_parser env _ filename _ =
   let model = Onnx.parse filename in
@@ -137,7 +140,7 @@ let ovo_parser env _ filename _ =
 
 let register_nnet_support () =
   Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language
-    "NNet" [ "nnet" ] nnet_parser
+    "NNet" [ "nnet" ] (fun env _ filename _ -> nnet_parser env filename)
 
 let register_onnx_support () =
   Env.register_format ~desc:"ONNX format" Pmodule.mlw_language "ONNX" [ "onnx" ]
diff --git a/src/language.mli b/src/language.mli
index 39453b4..f46f188 100644
--- a/src/language.mli
+++ b/src/language.mli
@@ -50,3 +50,7 @@ val register_onnx_support : unit -> unit
 
 val register_ovo_support : unit -> unit
 (** Register OVO parser. *)
+
+val nnet_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t
+(* [nnet_parser env filename] parses and creates the theories corresponding to
+   the given nnet [filename]. The result is memoized. *)
diff --git a/src/verification.ml b/src/verification.ml
index eeef52d..cbe3b35 100644
--- a/src/verification.ml
+++ b/src/verification.ml
@@ -191,11 +191,14 @@ let call_prover ~limit config env prover config_prover driver dataset_csv task =
       Fmt.(option ~none:nop (any " " ++ string))
       additional_info)
 
-let mstr_theory_of_json_config env _json_config =
+let mstr_theory_of_json_config env (json_config : File.json_config) =
+  let pmod = Language.nnet_parser env json_config.model in
+  let th_pmod = (Wstdlib.Mstr.find "NNasArray" pmod).mod_theory in
   let name = "T" in
   let th_uc = Theory.create_theory (Ident.id_fresh name) in
   let th_real = Env.read_theory env [ "real" ] "Real" in
   let th_uc = Theory.use_export th_uc th_real in
+  let th_uc = Theory.use_export th_uc th_pmod in
   Wstdlib.Mstr.singleton name (Theory.close_theory th_uc)
 
 let verify ?(debug = false) ?format ~loadpath ?memlimit ?timelimit prover
-- 
GitLab