From 47b4c5b36da46ec5c2648f1bea9b499b5a1ca6d6 Mon Sep 17 00:00:00 2001 From: Michele Alberti <michele.alberti@cea.fr> Date: Wed, 9 Nov 2022 14:39:13 +0100 Subject: [PATCH] [language] Export onnx_parser and ovo_parset for standalone use. --- src/language.ml | 54 +++++++++++++++++++++++++++--------------------- src/language.mli | 8 +++++++ 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/src/language.ml b/src/language.ml index 5814db4..789d5bc 100644 --- a/src/language.ml +++ b/src/language.ml @@ -114,29 +114,35 @@ let nnet_parser = |> register_nn_as_tuple env n_inputs n_outputs filename |> register_nn_as_array env n_inputs n_outputs filename)) -let onnx_parser env _ filename _ = - let model = Onnx.parse filename in - match model with - | Error s -> Loc.errorm "%s" s - | Ok { n_inputs; n_outputs; nier } -> - let nier = - match nier with - | Error msg -> - Logs.warn (fun m -> - m "Cannot build network intermediate representation:@ %s" msg); - None - | Ok nier -> Some nier - in - Wstdlib.Mstr.empty - |> register_nn_as_tuple env n_inputs n_outputs filename ?nier - |> register_nn_as_array env n_inputs n_outputs filename ?nier +let onnx_parser = + Env.Wenv.memoize 13 (fun env -> + let h = Hashtbl.create (module String) in + Hashtbl.findi_or_add h ~default:(fun filename -> + let model = Onnx.parse filename in + match model with + | Error s -> Loc.errorm "%s" s + | Ok { n_inputs; n_outputs; nier } -> + let nier = + match nier with + | Error msg -> + Logs.warn (fun m -> + m "Cannot build network intermediate representation:@ %s" msg); + None + | Ok nier -> Some nier + in + Wstdlib.Mstr.empty + |> register_nn_as_tuple env n_inputs n_outputs filename ?nier + |> register_nn_as_array env n_inputs n_outputs filename ?nier)) -let ovo_parser env _ filename _ = - let model = Ovo.parse filename in - match model with - | Error s -> Loc.errorm "%s" s - | Ok { n_inputs; n_outputs } -> - register_svm_as_array env n_inputs n_outputs filename Wstdlib.Mstr.empty +let ovo_parser = + Env.Wenv.memoize 13 (fun env -> + let h = Hashtbl.create (module String) in + Hashtbl.findi_or_add h ~default:(fun filename -> + let model = Ovo.parse filename in + match model with + | Error s -> Loc.errorm "%s" s + | Ok { n_inputs; n_outputs } -> + register_svm_as_array env n_inputs n_outputs filename Wstdlib.Mstr.empty)) let register_nnet_support () = Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language @@ -144,8 +150,8 @@ let register_nnet_support () = let register_onnx_support () = Env.register_format ~desc:"ONNX format" Pmodule.mlw_language "ONNX" [ "onnx" ] - onnx_parser + (fun env _ filename _ -> onnx_parser env filename) let register_ovo_support () = Env.register_format ~desc:"OVO format" Pmodule.mlw_language "OVO" [ "ovo" ] - ovo_parser + (fun env _ filename _ -> ovo_parser env filename) diff --git a/src/language.mli b/src/language.mli index f46f188..68c7157 100644 --- a/src/language.mli +++ b/src/language.mli @@ -54,3 +54,11 @@ val register_ovo_support : unit -> unit val nnet_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t (* [nnet_parser env filename] parses and creates the theories corresponding to the given nnet [filename]. The result is memoized. *) + +val onnx_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t +(* [onnx_parser env filename] parses and creates the theories corresponding to + the given onnx [filename]. The result is memoized. *) + +val ovo_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t +(* [nnet_parser env filename] parses and creates the theories corresponding to + the given ovo [filename]. The result is memoized. *) -- GitLab