From d01b778a328371086373e04b71454bfd5c9d4a89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr> Date: Tue, 23 Apr 2024 11:58:34 +0200 Subject: [PATCH] [nn_native] convert vector term to node with shape of arbitrary size --- src/interpretation/interpreter_theory.ml | 41 +++--- src/transformations/native_nn_prover.ml | 172 +++++++++++------------ tests/acasxu_ci.t | 6 +- tests/interpretation_fail.t | 2 +- 4 files changed, 112 insertions(+), 109 deletions(-) diff --git a/src/interpretation/interpreter_theory.ml b/src/interpretation/interpreter_theory.ml index c235c43..6f079ff 100644 --- a/src/interpretation/interpreter_theory.ml +++ b/src/interpretation/interpreter_theory.ml @@ -385,25 +385,28 @@ module Model = struct () (* Logging.user_error ?loc:t1.t_loc (fun m -> m "Unexpected neural network model application: %a" Why3.Pretty.print_term t2) *)); - let th = Why3.Env.read_theory env [ "caisar"; "types" ] "Vector" in - let get = - Why3.Theory.ns_find_ls th.th_export [ Why3.Ident.op_get "" ] - in - let t0 = Why3.Term.t_app ls [ t1; t2 ] ty in - let args = - List.init nb_outputs ~f:(fun i -> - ( Why3.Term.fs_app get - [ - t0; - Why3.Term.t_const - (Why3.Constant.int_const_of_int i) - Why3.Ty.ty_int; - ] - ty_elt, - ty_elt )) - in - let op = ITypes.Vector (Language.create_vector env nb_outputs) in - IRE.value_term (ITypes.term_of_op ~args engine op ty) + if false + then + let th = Why3.Env.read_theory env [ "caisar"; "types" ] "Vector" in + let get = + Why3.Theory.ns_find_ls th.th_export [ Why3.Ident.op_get "" ] + in + let t0 = Why3.Term.t_app ls [ t1; t2 ] ty in + let args = + List.init nb_outputs ~f:(fun i -> + ( Why3.Term.fs_app get + [ + t0; + Why3.Term.t_const + (Why3.Constant.int_const_of_int i) + Why3.Ty.ty_int; + ] + ty_elt, + ty_elt )) + in + let op = ITypes.Vector (Language.create_vector env nb_outputs) in + IRE.value_term (ITypes.term_of_op ~args engine op ty) + else IRE.reconstruct_term () | _ -> IRE.reconstruct_term ()) | _ -> fail_on_unexpected_argument ls diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index e1c698d..f3c364b 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -72,10 +72,7 @@ let match_nn_app th_model term = assert (nn_nb_inputs = vector_length && vector_length = List.length tl); let c = Why3.Number.to_small_integer c in Some (nn, tl, c) - | _, _ -> - Logging.code_error ~src (fun m -> - m "Neural network application without fixed NN or arguments: %a" - Why3.Pretty.print_term term)) + | _, _ -> None) | _ -> None let create_new_nn env input_vars outputs : string = @@ -93,22 +90,14 @@ let create_new_nn env input_vars outputs : string = IR.Node.gather_int input i) in let cache = Why3.Term.Hterm.create 17 in - let nn_cache = Stdlib.Hashtbl.create 17 in - (* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed - into nodes. *) - let rec convert_old_nn old_nn old_nn_args = - let converted_args = List.map ~f:convert_term old_nn_args in - let id = - ( old_nn.Language.nn_filename, - List.map converted_args ~f:(fun n -> n.Nir.Node.id) ) - in - match Stdlib.Hashtbl.find_opt nn_cache id with - | None -> - let node_nn = convert_old_nn_aux old_nn converted_args in - Stdlib.Hashtbl.add nn_cache id node_nn; - node_nn - | Some node_nn -> node_nn - and convert_old_nn_aux old_nn converted_args = + (* Instantiate the input of [old_nn] with the [converted_arg] term transformed + into a node. *) + let convert_nn ?loc old_nn converted_arg = + if old_nn.Language.nn_nb_inputs + <> Nir.Shape.size converted_arg.Nir.Node.shape + then + Logging.user_error ?loc (fun m -> + m "Neural network applied with the wrong number of arguments"); let old_nn_nir = match Onnx.Reader.from_file old_nn.Language.nn_filename with | Error s -> @@ -121,76 +110,88 @@ let create_new_nn env input_vars outputs : string = in (* Create the graph to replace the old input of the nn *) let input () = - (* Regroup the terms into one node *) - let node = - IR.Node.create (Concat { inputs = converted_args; axis = 0 }) - in - IR.Node.reshape (IR.Ngraph.input_shape old_nn_nir) node + let node = converted_arg in + Nir.Node.reshape (Nir.Ngraph.input_shape old_nn_nir) node in let out = - IR.Node.replace_input input (IR.Ngraph.output old_nn_nir) + IR.Node.replace_input input (Nir.Ngraph.output old_nn_nir) |> IR.Node.reshape (Nir.Shape.of_array [| old_nn.nn_nb_outputs |]) in out - and convert_old_nn_at_old_index old_nn old_index old_nn_args = - let out = convert_old_nn old_nn old_nn_args in - Nir.Node.gather_int out old_index - and convert_term term = + in + let rec convert_term term = match Why3.Term.Hterm.find_opt cache term with | None -> let n = convert_term_aux term in Why3.Term.Hterm.add cache term n; n | Some n -> n - and convert_term_aux term : IR.Node.t = - if not (Why3.Ty.ty_equal (Option.value_exn term.Why3.Term.t_ty) th_f64.ty) + and convert_term_aux term : Nir.Node.t = + let loc = term.Why3.Term.t_loc in + if false + && not + (Why3.Ty.ty_equal (Option.value_exn term.Why3.Term.t_ty) th_f64.ty) then Logging.user_error ?loc:term.t_loc (fun m -> m "Cannot convert non Float64 term %a" Why3.Pretty.print_term term); - match match_nn_app th_model term with - | Some (old_nn, tl, old_index) -> - convert_old_nn_at_old_index old_nn old_index tl - | None -> ( - match term.Why3.Term.t_node with + match term.Why3.Term.t_node with + | Tapp (ls_get (* [ ] *), [ t1; { t_node = Tconst (ConstInt c); _ } ]) + when String.equal ls_get.ls_name.id_string (Why3.Ident.op_get "") -> + let c = Why3.Number.to_small_integer c in + let t1 = convert_term t1 in + if Nir.Shape.size t1.Nir.Node.shape <= c + then + Logging.user_error ?loc (fun m -> + m "Vector accessed outside its bounds"); + Nir.Node.gather_int t1 c + | Tapp (ls_atat (* @@ *), [ { t_node = Tapp (ls_nn (* nn *), _); _ }; t2 ]) + when Why3.Term.ls_equal ls_atat th_model.Symbols.Model.atat -> ( + match Language.lookup_nn ls_nn with + | Some nn -> convert_nn ?loc nn (convert_term t2) + | _ -> + Logging.code_error ~src (fun m -> + m "Neural network application without fixed NN: %a" + Why3.Pretty.print_term term)) + | Tapp (ls (* input vector *), tl (* input vars *)) + when Language.mem_vector ls -> + let inputs = List.map tl ~f:convert_term in + IR.Node.create (Concat { inputs; axis = 0 }) + | Tconst (Why3.Constant.ConstReal r) -> + IR.Node.create + (Constant + { + data = + Nir.Gentensor.create_1_float (Utils.float_of_real_constant r); + }) + | Tapp (ls, []) -> get_input ls + | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.add -> + IR.Node.create (Add { input1 = convert_term a; input2 = convert_term b }) + | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.sub -> + IR.Node.create (Sub { input1 = convert_term a; input2 = convert_term b }) + | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.mul -> + IR.Node.create (Mul { input1 = convert_term a; input2 = convert_term b }) + | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.div -> ( + match b.t_node with | Tconst (Why3.Constant.ConstReal r) -> + let f = Utils.float_of_real_constant r in + Nir.Node.div_float (convert_term a) f + | _ -> IR.Node.create - (Constant - { - data = - IR.Gentensor.create_1_float (Utils.float_of_real_constant r); - }) - | Tapp (ls, []) -> get_input ls - | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.add -> - IR.Node.create - (Add { input1 = convert_term a; input2 = convert_term b }) - | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.sub -> - IR.Node.create - (Sub { input1 = convert_term a; input2 = convert_term b }) - | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.mul -> - IR.Node.create - (Mul { input1 = convert_term a; input2 = convert_term b }) - | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.div -> ( - match b.t_node with - | Tconst (Why3.Constant.ConstReal r) -> - let f = Utils.float_of_real_constant r in - Nir.Node.div_float (convert_term a) f - | _ -> - IR.Node.create - (Div { input1 = convert_term a; input2 = convert_term b })) - | Tapp (ls, [ a ]) when Why3.Term.ls_equal ls th_f64.neg -> - Nir.Node.mul_float (convert_term a) (-1.) - | Tconst _ - | Tapp (_, _) - | Tif (_, _, _) - | Tlet (_, _) - | Tbinop (_, _, _) - | Tcase (_, _) - | Tnot _ | Ttrue | Tfalse -> - Logging.not_implemented_yet (fun m -> - m "Why3 term to IR: %a" Why3.Pretty.print_term term) - | Tvar _ | Teps _ | Tquant (_, _) -> - Logging.not_implemented_yet (fun m -> - m "Why3 term to IR (impossible?): %a" Why3.Pretty.print_term term)) + (Div { input1 = convert_term a; input2 = convert_term b })) + | Tapp (ls, [ a ]) when Why3.Term.ls_equal ls th_f64.neg -> + Nir.Node.mul_float (convert_term a) (-1.) + | Tconst _ + | Tapp (_, _) + | Tif (_, _, _) + | Tlet (_, _) + | Tbinop (_, _, _) + | Tcase (_, _) + | Tnot _ | Ttrue | Tfalse -> + Logging.not_implemented_yet (fun m -> + m "Why3 term to IR: %a" Why3.Pretty.print_term term) + | Tvar _ | Teps _ | Tquant (_, _) -> + Logging.not_implemented_yet (fun m -> + m "Why3 term to IR (impossible?): %a" Why3.Pretty.print_term term) in (* Start actual term conversion. *) let outputs = @@ -259,19 +260,18 @@ let check_if_new_nn_needed env input_vars outputs = (* Choose the term pattern for starting the conversion to ONNX. *) let has_start_pattern env term = - let th_model = Symbols.Model.create env in let th_f = Symbols.Float64.create env in - match match_nn_app th_model term with - | Some _ -> true - | None -> ( - match term.Why3.Term.t_node with - | Tapp ({ ls_value = Some ty; _ }, []) -> - (* input symbol *) Why3.Ty.ty_equal ty th_f.ty - | Tapp (ls_app, _) -> - List.mem - [ th_f.add; th_f.sub; th_f.mul; th_f.div ] - ~equal:Why3.Term.ls_equal ls_app - | _ -> false) + match term.Why3.Term.t_node with + | Tapp (ls_get (* [ ] *), _) + when String.equal ls_get.ls_name.id_string (Why3.Ident.op_get "") -> + true + | Tapp ({ ls_value = Some ty; _ }, []) -> + (* input symbol *) Why3.Ty.ty_equal ty th_f.ty + | Tapp (ls_app, _) -> + List.mem + [ th_f.add; th_f.sub; th_f.mul; th_f.div ] + ~equal:Why3.Term.ls_equal ls_app + | _ -> false (* Abstract terms that contains neural network applications. *) let abstract_nn_term env = diff --git a/tests/acasxu_ci.t b/tests/acasxu_ci.t index 72468ad..f274c7e 100644 --- a/tests/acasxu_ci.t +++ b/tests/acasxu_ci.t @@ -936,9 +936,9 @@ Test verify on acasxu out/caisar_1.onnx has 1 input nodes {'name': '221', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}} out/caisar_2.onnx has 1 input nodes - {'name': '440', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}} + {'name': '439', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}} out/caisar_3.onnx has 1 input nodes - {'name': '588', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}} + {'name': '587', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}} out/caisar_4.onnx has 1 input nodes - {'name': '815', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}} + {'name': '814', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}} 5 files checked diff --git a/tests/interpretation_fail.t b/tests/interpretation_fail.t index 72ca477..bd06a54 100644 --- a/tests/interpretation_fail.t +++ b/tests/interpretation_fail.t @@ -146,7 +146,7 @@ Test interpret fail > EOF $ caisar verify --prover nnenum file.mlw - [ERROR] "file.mlw", line 12, characters 14-23: + [ERROR] "file.mlw", line 12, characters 24-26: Index constant 10 is out-of-bounds [0,4] $ cat > file.mlw <<EOF -- GitLab