diff --git a/src/language.ml b/src/language.ml index 5814db488b9e24af90da8dc52644e9ffab8d517c..789d5bc6a2f101bd9e3267504dfd3b562424fa33 100644 --- a/src/language.ml +++ b/src/language.ml @@ -114,29 +114,35 @@ let nnet_parser = |> register_nn_as_tuple env n_inputs n_outputs filename |> register_nn_as_array env n_inputs n_outputs filename)) -let onnx_parser env _ filename _ = - let model = Onnx.parse filename in - match model with - | Error s -> Loc.errorm "%s" s - | Ok { n_inputs; n_outputs; nier } -> - let nier = - match nier with - | Error msg -> - Logs.warn (fun m -> - m "Cannot build network intermediate representation:@ %s" msg); - None - | Ok nier -> Some nier - in - Wstdlib.Mstr.empty - |> register_nn_as_tuple env n_inputs n_outputs filename ?nier - |> register_nn_as_array env n_inputs n_outputs filename ?nier +let onnx_parser = + Env.Wenv.memoize 13 (fun env -> + let h = Hashtbl.create (module String) in + Hashtbl.findi_or_add h ~default:(fun filename -> + let model = Onnx.parse filename in + match model with + | Error s -> Loc.errorm "%s" s + | Ok { n_inputs; n_outputs; nier } -> + let nier = + match nier with + | Error msg -> + Logs.warn (fun m -> + m "Cannot build network intermediate representation:@ %s" msg); + None + | Ok nier -> Some nier + in + Wstdlib.Mstr.empty + |> register_nn_as_tuple env n_inputs n_outputs filename ?nier + |> register_nn_as_array env n_inputs n_outputs filename ?nier)) -let ovo_parser env _ filename _ = - let model = Ovo.parse filename in - match model with - | Error s -> Loc.errorm "%s" s - | Ok { n_inputs; n_outputs } -> - register_svm_as_array env n_inputs n_outputs filename Wstdlib.Mstr.empty +let ovo_parser = + Env.Wenv.memoize 13 (fun env -> + let h = Hashtbl.create (module String) in + Hashtbl.findi_or_add h ~default:(fun filename -> + let model = Ovo.parse filename in + match model with + | Error s -> Loc.errorm "%s" s + | Ok { n_inputs; n_outputs } -> + register_svm_as_array env n_inputs n_outputs filename Wstdlib.Mstr.empty)) let register_nnet_support () = Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language @@ -144,8 +150,8 @@ let register_nnet_support () = let register_onnx_support () = Env.register_format ~desc:"ONNX format" Pmodule.mlw_language "ONNX" [ "onnx" ] - onnx_parser + (fun env _ filename _ -> onnx_parser env filename) let register_ovo_support () = Env.register_format ~desc:"OVO format" Pmodule.mlw_language "OVO" [ "ovo" ] - ovo_parser + (fun env _ filename _ -> ovo_parser env filename) diff --git a/src/language.mli b/src/language.mli index f46f1887346aa2cb49350164080f1cb9f83343fc..68c71572b844dc0aff4a1578be3d999950cad000 100644 --- a/src/language.mli +++ b/src/language.mli @@ -54,3 +54,11 @@ val register_ovo_support : unit -> unit val nnet_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t (* [nnet_parser env filename] parses and creates the theories corresponding to the given nnet [filename]. The result is memoized. *) + +val onnx_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t +(* [onnx_parser env filename] parses and creates the theories corresponding to + the given onnx [filename]. The result is memoized. *) + +val ovo_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t +(* [nnet_parser env filename] parses and creates the theories corresponding to + the given ovo [filename]. The result is memoized. *)