From 47b4c5b36da46ec5c2648f1bea9b499b5a1ca6d6 Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Wed, 9 Nov 2022 14:39:13 +0100
Subject: [PATCH] [language] Export onnx_parser and ovo_parset for standalone
 use.

---
 src/language.ml  | 54 +++++++++++++++++++++++++++---------------------
 src/language.mli |  8 +++++++
 2 files changed, 38 insertions(+), 24 deletions(-)

diff --git a/src/language.ml b/src/language.ml
index 5814db4..789d5bc 100644
--- a/src/language.ml
+++ b/src/language.ml
@@ -114,29 +114,35 @@ let nnet_parser =
         |> register_nn_as_tuple env n_inputs n_outputs filename
         |> register_nn_as_array env n_inputs n_outputs filename))
 
-let onnx_parser env _ filename _ =
-  let model = Onnx.parse filename in
-  match model with
-  | Error s -> Loc.errorm "%s" s
-  | Ok { n_inputs; n_outputs; nier } ->
-    let nier =
-      match nier with
-      | Error msg ->
-        Logs.warn (fun m ->
-          m "Cannot build network intermediate representation:@ %s" msg);
-        None
-      | Ok nier -> Some nier
-    in
-    Wstdlib.Mstr.empty
-    |> register_nn_as_tuple env n_inputs n_outputs filename ?nier
-    |> register_nn_as_array env n_inputs n_outputs filename ?nier
+let onnx_parser =
+  Env.Wenv.memoize 13 (fun env ->
+    let h = Hashtbl.create (module String) in
+    Hashtbl.findi_or_add h ~default:(fun filename ->
+      let model = Onnx.parse filename in
+      match model with
+      | Error s -> Loc.errorm "%s" s
+      | Ok { n_inputs; n_outputs; nier } ->
+        let nier =
+          match nier with
+          | Error msg ->
+            Logs.warn (fun m ->
+              m "Cannot build network intermediate representation:@ %s" msg);
+            None
+          | Ok nier -> Some nier
+        in
+        Wstdlib.Mstr.empty
+        |> register_nn_as_tuple env n_inputs n_outputs filename ?nier
+        |> register_nn_as_array env n_inputs n_outputs filename ?nier))
 
-let ovo_parser env _ filename _ =
-  let model = Ovo.parse filename in
-  match model with
-  | Error s -> Loc.errorm "%s" s
-  | Ok { n_inputs; n_outputs } ->
-    register_svm_as_array env n_inputs n_outputs filename Wstdlib.Mstr.empty
+let ovo_parser =
+  Env.Wenv.memoize 13 (fun env ->
+    let h = Hashtbl.create (module String) in
+    Hashtbl.findi_or_add h ~default:(fun filename ->
+      let model = Ovo.parse filename in
+      match model with
+      | Error s -> Loc.errorm "%s" s
+      | Ok { n_inputs; n_outputs } ->
+        register_svm_as_array env n_inputs n_outputs filename Wstdlib.Mstr.empty))
 
 let register_nnet_support () =
   Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language
@@ -144,8 +150,8 @@ let register_nnet_support () =
 
 let register_onnx_support () =
   Env.register_format ~desc:"ONNX format" Pmodule.mlw_language "ONNX" [ "onnx" ]
-    onnx_parser
+    (fun env _ filename _ -> onnx_parser env filename)
 
 let register_ovo_support () =
   Env.register_format ~desc:"OVO format" Pmodule.mlw_language "OVO" [ "ovo" ]
-    ovo_parser
+    (fun env _ filename _ -> ovo_parser env filename)
diff --git a/src/language.mli b/src/language.mli
index f46f188..68c7157 100644
--- a/src/language.mli
+++ b/src/language.mli
@@ -54,3 +54,11 @@ val register_ovo_support : unit -> unit
 val nnet_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t
 (* [nnet_parser env filename] parses and creates the theories corresponding to
    the given nnet [filename]. The result is memoized. *)
+
+val onnx_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t
+(* [onnx_parser env filename] parses and creates the theories corresponding to
+   the given onnx [filename]. The result is memoized. *)
+
+val ovo_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t
+(* [nnet_parser env filename] parses and creates the theories corresponding to
+   the given ovo [filename]. The result is memoized. *)
-- 
GitLab