Skip to content
Snippets Groups Projects
Commit 5cd87294 authored by Julien Girard-Satabin's avatar Julien Girard-Satabin
Browse files

Added a NIER to the nnshape type.

parent 94019e40
No related branches found
No related tags found
No related merge requests found
......@@ -32,6 +32,7 @@ type nn_shape = {
nb_outputs : int;
ty_data : Ty.ty;
filename : string;
nier : Onnx.G.t option;
}
type svm_shape = { nb_inputs : int; nb_classes : int; filename : string }
......@@ -41,7 +42,7 @@ let loaded_svms = Term.Hls.create 10
let lookup_loaded_nets = Term.Hls.find_opt loaded_nets
let lookup_loaded_svms = Term.Hls.find_opt loaded_svms
let register_nn_as_tuple nb_inputs nb_outputs filename env =
let register_nn_as_tuple nb_inputs nb_outputs filename nier env =
let net = Pmodule.read_module env [ "caisar" ] "NN" in
let input_type =
Ty.ty_app Theory.(ns_find_ts net.mod_theory.th_export [ "input_type" ]) []
......@@ -57,7 +58,7 @@ let register_nn_as_tuple nb_inputs nb_outputs filename env =
(Ty.ty_tuple (List.init nb_outputs ~f))
in
Term.Hls.add loaded_nets ls_net_apply
{ filename; nb_inputs; nb_outputs; ty_data = input_type };
{ filename; nb_inputs; nb_outputs; ty_data = input_type; nier };
let th_uc =
Pmodule.add_pdecl ~vc:false th_uc
(Pdecl.create_pure_decl (Decl.create_param_decl ls_net_apply))
......@@ -86,13 +87,15 @@ let nnet_parser env _ filename _ =
let model = Nnet.parse filename in
match model with
| Error s -> Loc.errorm "%s" s
| Ok model -> register_nn_as_tuple model.n_inputs model.n_outputs filename env
| Ok model ->
register_nn_as_tuple model.n_inputs model.n_outputs filename None env
let onnx_parser env _ filename _ =
let model = Onnx.parse filename in
match model with
| Error s -> Loc.errorm "%s" s
| Ok (model,_nier) -> register_nn_as_tuple model.n_inputs model.n_outputs filename env
| Ok (model, nier) ->
register_nn_as_tuple model.n_inputs model.n_outputs filename (Some nier) env
let ovo_parser env _ filename _ =
let model = Ovo.parse filename in
......
......@@ -27,6 +27,7 @@ type nn_shape = {
nb_outputs : int;
ty_data : Ty.ty;
filename : string;
nier : Onnx.G.t option;
}
type svm_shape = { nb_inputs : int; nb_classes : int; filename : string }
......
......@@ -346,14 +346,12 @@ let actual_nn_flow env =
match Language.lookup_loaded_nets ls with
| None -> Term.t_map aux term
| Some nn ->
let nn_file = Unix.realpath nn.filename in
let ty_inputs = nn.ty_data in
let g =
let p = Onnx.parse nn_file in
match p with
| Error s -> Loc.errorm "%s" s
| Ok (_model, nier) -> nier
match nn.nier with
| Some g -> g
| None -> failwith "Error, call this transform only on an ONNX NN."
in
let ty_inputs = nn.ty_data in
let cfg_term =
terms_of_nier g ty_inputs env
(Term.t_var @@ create_var "dummy" 0 ty_inputs vars)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment