Skip to content
Snippets Groups Projects
Commit 4a14fc9d authored by François Bobot's avatar François Bobot
Browse files

[NN_native] create only one node by identical nn application

parent 42314c90
No related branches found
No related tags found
No related merge requests found
...@@ -57,10 +57,22 @@ let create_new_nn env input_vars outputs : string = ...@@ -57,10 +57,22 @@ let create_new_nn env input_vars outputs : string =
Ir.Nier_simple.Node.gather_int input i) Ir.Nier_simple.Node.gather_int input i)
in in
let cache = Why3.Term.Hterm.create 17 in let cache = Why3.Term.Hterm.create 17 in
let nn_cache = Stdlib.Hashtbl.create 17 in
(* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed (* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed
into nodes. *) into nodes. *)
let rec convert_old_nn old_nn old_index old_nn_args = let rec convert_old_nn old_nn old_nn_args =
let old_index = Why3.Number.to_small_integer old_index in let converted_args = List.map ~f:convert_term old_nn_args in
let id =
( old_nn.Language.nn_filename,
List.map converted_args ~f:(fun n -> n.Ir.Nier_simple.id) )
in
match Stdlib.Hashtbl.find_opt nn_cache id with
| None ->
let n = convert_old_nn_aux old_nn converted_args in
Stdlib.Hashtbl.add nn_cache id n;
n
| Some n -> n
and convert_old_nn_aux old_nn converted_args =
let old_nn_nier = let old_nn_nier =
match Onnx.Simple.parse old_nn.Language.nn_filename with match Onnx.Simple.parse old_nn.Language.nn_filename with
| Error s -> | Error s ->
...@@ -75,8 +87,7 @@ let create_new_nn env input_vars outputs : string = ...@@ -75,8 +87,7 @@ let create_new_nn env input_vars outputs : string =
let input () = let input () =
(* Regroup the terms into one node *) (* Regroup the terms into one node *)
let node = let node =
IR.Node.create IR.Node.create (Concat { inputs = converted_args; axis = 0 })
(Concat { inputs = List.map ~f:convert_term old_nn_args; axis = 0 })
in in
Ir.Nier_simple.Node.reshape (IR.input_shape old_nn_nier) node Ir.Nier_simple.Node.reshape (IR.input_shape old_nn_nier) node
in in
...@@ -86,6 +97,10 @@ let create_new_nn env input_vars outputs : string = ...@@ -86,6 +97,10 @@ let create_new_nn env input_vars outputs : string =
(Ir.Nier_simple.Shape.of_array [| old_nn.nn_nb_outputs |]) (Ir.Nier_simple.Shape.of_array [| old_nn.nn_nb_outputs |])
out out
in in
out
and convert_old_nn_at_old_index old_nn old_index old_nn_args =
let out = convert_old_nn old_nn old_nn_args in
let old_index = Why3.Number.to_small_integer old_index in
Ir.Nier_simple.Node.gather_int out old_index Ir.Nier_simple.Node.gather_int out old_index
and convert_term term = and convert_term term =
match Why3.Term.Hterm.find_opt cache term with match Why3.Term.Hterm.find_opt cache term with
...@@ -148,7 +163,7 @@ let create_new_nn env input_vars outputs : string = ...@@ -148,7 +163,7 @@ let create_new_nn env input_vars outputs : string =
match (Language.lookup_nn ls_nn, Language.lookup_vector ls) with match (Language.lookup_nn ls_nn, Language.lookup_vector ls) with
| Some ({ nn_nb_inputs; _ } as old_nn), Some vector_length -> | Some ({ nn_nb_inputs; _ } as old_nn), Some vector_length ->
assert (nn_nb_inputs = vector_length && vector_length = List.length tl); assert (nn_nb_inputs = vector_length && vector_length = List.length tl);
convert_old_nn old_nn old_index tl convert_old_nn_at_old_index old_nn old_index tl
| _, _ -> | _, _ ->
Logging.code_error ~src (fun m -> Logging.code_error ~src (fun m ->
m "Neural network application without fixed NN or arguments: %a" m "Neural network application without fixed NN or arguments: %a"
......
...@@ -901,7 +901,7 @@ Test verify on acasxu ...@@ -901,7 +901,7 @@ Test verify on acasxu
caisar_1.onnx has 1 input nodes caisar_1.onnx has 1 input nodes
{'name': '179', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}} {'name': '179', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
caisar_2.onnx has 1 input nodes caisar_2.onnx has 1 input nodes
{'name': '392', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}} {'name': '327', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}}
caisar_3.onnx has 1 input nodes caisar_3.onnx has 1 input nodes
{'name': '619', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}} {'name': '554', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
4 files checked 4 files checked
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment