From 5cd8729430f35fd2e05c92911654a7645f62deb5 Mon Sep 17 00:00:00 2001 From: Julien Girard <julien.girard2@cea.fr> Date: Wed, 21 Sep 2022 16:48:20 +0200 Subject: [PATCH] Added a NIER to the nnshape type. --- src/language.ml | 11 +++++++---- src/language.mli | 1 + src/transformations/actual_net_apply.ml | 10 ++++------ 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/language.ml b/src/language.ml index ca792ee..6cae58d 100644 --- a/src/language.ml +++ b/src/language.ml @@ -32,6 +32,7 @@ type nn_shape = { nb_outputs : int; ty_data : Ty.ty; filename : string; + nier : Onnx.G.t option; } type svm_shape = { nb_inputs : int; nb_classes : int; filename : string } @@ -41,7 +42,7 @@ let loaded_svms = Term.Hls.create 10 let lookup_loaded_nets = Term.Hls.find_opt loaded_nets let lookup_loaded_svms = Term.Hls.find_opt loaded_svms -let register_nn_as_tuple nb_inputs nb_outputs filename env = +let register_nn_as_tuple nb_inputs nb_outputs filename nier env = let net = Pmodule.read_module env [ "caisar" ] "NN" in let input_type = Ty.ty_app Theory.(ns_find_ts net.mod_theory.th_export [ "input_type" ]) [] @@ -57,7 +58,7 @@ let register_nn_as_tuple nb_inputs nb_outputs filename env = (Ty.ty_tuple (List.init nb_outputs ~f)) in Term.Hls.add loaded_nets ls_net_apply - { filename; nb_inputs; nb_outputs; ty_data = input_type }; + { filename; nb_inputs; nb_outputs; ty_data = input_type; nier }; let th_uc = Pmodule.add_pdecl ~vc:false th_uc (Pdecl.create_pure_decl (Decl.create_param_decl ls_net_apply)) @@ -86,13 +87,15 @@ let nnet_parser env _ filename _ = let model = Nnet.parse filename in match model with | Error s -> Loc.errorm "%s" s - | Ok model -> register_nn_as_tuple model.n_inputs model.n_outputs filename env + | Ok model -> + register_nn_as_tuple model.n_inputs model.n_outputs filename None env let onnx_parser env _ filename _ = let model = Onnx.parse filename in match model with | Error s -> Loc.errorm "%s" s - | Ok (model,_nier) -> register_nn_as_tuple model.n_inputs model.n_outputs filename env + | Ok (model, nier) -> + register_nn_as_tuple model.n_inputs model.n_outputs filename (Some nier) env let ovo_parser env _ filename _ = let model = Ovo.parse filename in diff --git a/src/language.mli b/src/language.mli index 4d0a1ca..90c8cc3 100644 --- a/src/language.mli +++ b/src/language.mli @@ -27,6 +27,7 @@ type nn_shape = { nb_outputs : int; ty_data : Ty.ty; filename : string; + nier : Onnx.G.t option; } type svm_shape = { nb_inputs : int; nb_classes : int; filename : string } diff --git a/src/transformations/actual_net_apply.ml b/src/transformations/actual_net_apply.ml index d1acce8..d188b5c 100644 --- a/src/transformations/actual_net_apply.ml +++ b/src/transformations/actual_net_apply.ml @@ -346,14 +346,12 @@ let actual_nn_flow env = match Language.lookup_loaded_nets ls with | None -> Term.t_map aux term | Some nn -> - let nn_file = Unix.realpath nn.filename in - let ty_inputs = nn.ty_data in let g = - let p = Onnx.parse nn_file in - match p with - | Error s -> Loc.errorm "%s" s - | Ok (_model, nier) -> nier + match nn.nier with + | Some g -> g + | None -> failwith "Error, call this transform only on an ONNX NN." in + let ty_inputs = nn.ty_data in let cfg_term = terms_of_nier g ty_inputs env (Term.t_var @@ create_var "dummy" 0 ty_inputs vars) -- GitLab