Skip to content
Snippets Groups Projects
Commit 498610f2 authored by Michele Alberti's avatar Michele Alberti
Browse files

Update language to onnx lib new API.

parent 3d4ccfa1
No related branches found
No related tags found
No related merge requests found
...@@ -58,7 +58,7 @@ let register_nn_as_tuple nb_inputs nb_outputs filename nier env = ...@@ -58,7 +58,7 @@ let register_nn_as_tuple nb_inputs nb_outputs filename nier env =
(Ty.ty_tuple (List.init nb_outputs ~f)) (Ty.ty_tuple (List.init nb_outputs ~f))
in in
Term.Hls.add loaded_nets ls_net_apply Term.Hls.add loaded_nets ls_net_apply
{ filename; nb_inputs; nb_outputs; ty_data = input_type; nier }; { filename; nb_inputs; nb_outputs; ty_data = input_type; nier };
let th_uc = let th_uc =
Pmodule.add_pdecl ~vc:false th_uc Pmodule.add_pdecl ~vc:false th_uc
(Pdecl.create_pure_decl (Decl.create_param_decl ls_net_apply)) (Pdecl.create_pure_decl (Decl.create_param_decl ls_net_apply))
...@@ -87,22 +87,23 @@ let nnet_parser env _ filename _ = ...@@ -87,22 +87,23 @@ let nnet_parser env _ filename _ =
let model = Nnet.parse ~permissive:true filename in let model = Nnet.parse ~permissive:true filename in
match model with match model with
| Error s -> Loc.errorm "%s" s | Error s -> Loc.errorm "%s" s
| Ok model -> | Ok { n_inputs; n_outputs; _ } ->
register_nn_as_tuple model.n_inputs model.n_outputs filename None env register_nn_as_tuple n_inputs n_outputs filename None env
let onnx_parser env _ filename _ = let onnx_parser env _ filename _ =
let model = Onnx.parse filename in let model = Onnx.parse filename in
match model with match model with
| Error s -> Loc.errorm "%s" s | Error s -> Loc.errorm "%s" s
| Ok (model, nier) -> | Ok { n_inputs; n_outputs; nier } ->
register_nn_as_tuple model.n_inputs model.n_outputs filename (Some nier) env let nier = Result.ok nier (* TODO: Warn about parsing errors? *) in
register_nn_as_tuple n_inputs n_outputs filename nier env
let ovo_parser env _ filename _ = let ovo_parser env _ filename _ =
let model = Ovo.parse filename in let model = Ovo.parse filename in
match model with match model with
| Error s -> Loc.errorm "%s" s | Error s -> Loc.errorm "%s" s
| Ok model -> | Ok { n_inputs; n_outputs } ->
register_svm_as_array model.n_inputs model.n_outputs filename env register_svm_as_array n_inputs n_outputs filename env
let register_nnet_support () = let register_nnet_support () =
Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment