diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 9a5f1b845c8ebee477268074593c756ec15010fe..ceeabb2ab8d22ee1bc7e587fd8583325d3f5d9bf 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -57,10 +57,22 @@ let create_new_nn env input_vars outputs : string = Ir.Nier_simple.Node.gather_int input i) in let cache = Why3.Term.Hterm.create 17 in + let nn_cache = Stdlib.Hashtbl.create 17 in (* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed into nodes. *) - let rec convert_old_nn old_nn old_index old_nn_args = - let old_index = Why3.Number.to_small_integer old_index in + let rec convert_old_nn old_nn old_nn_args = + let converted_args = List.map ~f:convert_term old_nn_args in + let id = + ( old_nn.Language.nn_filename, + List.map converted_args ~f:(fun n -> n.Ir.Nier_simple.id) ) + in + match Stdlib.Hashtbl.find_opt nn_cache id with + | None -> + let n = convert_old_nn_aux old_nn converted_args in + Stdlib.Hashtbl.add nn_cache id n; + n + | Some n -> n + and convert_old_nn_aux old_nn converted_args = let old_nn_nier = match Onnx.Simple.parse old_nn.Language.nn_filename with | Error s -> @@ -75,8 +87,7 @@ let create_new_nn env input_vars outputs : string = let input () = (* Regroup the terms into one node *) let node = - IR.Node.create - (Concat { inputs = List.map ~f:convert_term old_nn_args; axis = 0 }) + IR.Node.create (Concat { inputs = converted_args; axis = 0 }) in Ir.Nier_simple.Node.reshape (IR.input_shape old_nn_nier) node in @@ -86,6 +97,10 @@ let create_new_nn env input_vars outputs : string = (Ir.Nier_simple.Shape.of_array [| old_nn.nn_nb_outputs |]) out in + out + and convert_old_nn_at_old_index old_nn old_index old_nn_args = + let out = convert_old_nn old_nn old_nn_args in + let old_index = Why3.Number.to_small_integer old_index in Ir.Nier_simple.Node.gather_int out old_index and convert_term term = match Why3.Term.Hterm.find_opt cache term with @@ -148,7 +163,7 @@ let create_new_nn env input_vars outputs : string = match (Language.lookup_nn ls_nn, Language.lookup_vector ls) with | Some ({ nn_nb_inputs; _ } as old_nn), Some vector_length -> assert (nn_nb_inputs = vector_length && vector_length = List.length tl); - convert_old_nn old_nn old_index tl + convert_old_nn_at_old_index old_nn old_index tl | _, _ -> Logging.code_error ~src (fun m -> m "Neural network application without fixed NN or arguments: %a" diff --git a/tests/acasxu.t b/tests/acasxu.t index f7fa6bd0062093c06c3f4c0be67e80922e22481c..2c22b3c7c602306aba93504d7365d2079db47edf 100644 --- a/tests/acasxu.t +++ b/tests/acasxu.t @@ -901,7 +901,7 @@ Test verify on acasxu caisar_1.onnx has 1 input nodes {'name': '179', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}} caisar_2.onnx has 1 input nodes - {'name': '392', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}} + {'name': '327', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}} caisar_3.onnx has 1 input nodes - {'name': '619', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}} + {'name': '554', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}} 4 files checked