Skip to content
Snippets Groups Projects
Commit d01b778a authored by François Bobot's avatar François Bobot Committed by Michele Alberti
Browse files

[nn_native] convert vector term to node with shape of arbitrary size

parent 9817e8e9
No related branches found
No related tags found
No related merge requests found
......@@ -385,25 +385,28 @@ module Model = struct
()
(* Logging.user_error ?loc:t1.t_loc (fun m -> m "Unexpected neural
network model application: %a" Why3.Pretty.print_term t2) *));
let th = Why3.Env.read_theory env [ "caisar"; "types" ] "Vector" in
let get =
Why3.Theory.ns_find_ls th.th_export [ Why3.Ident.op_get "" ]
in
let t0 = Why3.Term.t_app ls [ t1; t2 ] ty in
let args =
List.init nb_outputs ~f:(fun i ->
( Why3.Term.fs_app get
[
t0;
Why3.Term.t_const
(Why3.Constant.int_const_of_int i)
Why3.Ty.ty_int;
]
ty_elt,
ty_elt ))
in
let op = ITypes.Vector (Language.create_vector env nb_outputs) in
IRE.value_term (ITypes.term_of_op ~args engine op ty)
if false
then
let th = Why3.Env.read_theory env [ "caisar"; "types" ] "Vector" in
let get =
Why3.Theory.ns_find_ls th.th_export [ Why3.Ident.op_get "" ]
in
let t0 = Why3.Term.t_app ls [ t1; t2 ] ty in
let args =
List.init nb_outputs ~f:(fun i ->
( Why3.Term.fs_app get
[
t0;
Why3.Term.t_const
(Why3.Constant.int_const_of_int i)
Why3.Ty.ty_int;
]
ty_elt,
ty_elt ))
in
let op = ITypes.Vector (Language.create_vector env nb_outputs) in
IRE.value_term (ITypes.term_of_op ~args engine op ty)
else IRE.reconstruct_term ()
| _ -> IRE.reconstruct_term ())
| _ -> fail_on_unexpected_argument ls
......
......@@ -72,10 +72,7 @@ let match_nn_app th_model term =
assert (nn_nb_inputs = vector_length && vector_length = List.length tl);
let c = Why3.Number.to_small_integer c in
Some (nn, tl, c)
| _, _ ->
Logging.code_error ~src (fun m ->
m "Neural network application without fixed NN or arguments: %a"
Why3.Pretty.print_term term))
| _, _ -> None)
| _ -> None
let create_new_nn env input_vars outputs : string =
......@@ -93,22 +90,14 @@ let create_new_nn env input_vars outputs : string =
IR.Node.gather_int input i)
in
let cache = Why3.Term.Hterm.create 17 in
let nn_cache = Stdlib.Hashtbl.create 17 in
(* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed
into nodes. *)
let rec convert_old_nn old_nn old_nn_args =
let converted_args = List.map ~f:convert_term old_nn_args in
let id =
( old_nn.Language.nn_filename,
List.map converted_args ~f:(fun n -> n.Nir.Node.id) )
in
match Stdlib.Hashtbl.find_opt nn_cache id with
| None ->
let node_nn = convert_old_nn_aux old_nn converted_args in
Stdlib.Hashtbl.add nn_cache id node_nn;
node_nn
| Some node_nn -> node_nn
and convert_old_nn_aux old_nn converted_args =
(* Instantiate the input of [old_nn] with the [converted_arg] term transformed
into a node. *)
let convert_nn ?loc old_nn converted_arg =
if old_nn.Language.nn_nb_inputs
<> Nir.Shape.size converted_arg.Nir.Node.shape
then
Logging.user_error ?loc (fun m ->
m "Neural network applied with the wrong number of arguments");
let old_nn_nir =
match Onnx.Reader.from_file old_nn.Language.nn_filename with
| Error s ->
......@@ -121,76 +110,88 @@ let create_new_nn env input_vars outputs : string =
in
(* Create the graph to replace the old input of the nn *)
let input () =
(* Regroup the terms into one node *)
let node =
IR.Node.create (Concat { inputs = converted_args; axis = 0 })
in
IR.Node.reshape (IR.Ngraph.input_shape old_nn_nir) node
let node = converted_arg in
Nir.Node.reshape (Nir.Ngraph.input_shape old_nn_nir) node
in
let out =
IR.Node.replace_input input (IR.Ngraph.output old_nn_nir)
IR.Node.replace_input input (Nir.Ngraph.output old_nn_nir)
|> IR.Node.reshape (Nir.Shape.of_array [| old_nn.nn_nb_outputs |])
in
out
and convert_old_nn_at_old_index old_nn old_index old_nn_args =
let out = convert_old_nn old_nn old_nn_args in
Nir.Node.gather_int out old_index
and convert_term term =
in
let rec convert_term term =
match Why3.Term.Hterm.find_opt cache term with
| None ->
let n = convert_term_aux term in
Why3.Term.Hterm.add cache term n;
n
| Some n -> n
and convert_term_aux term : IR.Node.t =
if not (Why3.Ty.ty_equal (Option.value_exn term.Why3.Term.t_ty) th_f64.ty)
and convert_term_aux term : Nir.Node.t =
let loc = term.Why3.Term.t_loc in
if false
&& not
(Why3.Ty.ty_equal (Option.value_exn term.Why3.Term.t_ty) th_f64.ty)
then
Logging.user_error ?loc:term.t_loc (fun m ->
m "Cannot convert non Float64 term %a" Why3.Pretty.print_term term);
match match_nn_app th_model term with
| Some (old_nn, tl, old_index) ->
convert_old_nn_at_old_index old_nn old_index tl
| None -> (
match term.Why3.Term.t_node with
match term.Why3.Term.t_node with
| Tapp (ls_get (* [ ] *), [ t1; { t_node = Tconst (ConstInt c); _ } ])
when String.equal ls_get.ls_name.id_string (Why3.Ident.op_get "") ->
let c = Why3.Number.to_small_integer c in
let t1 = convert_term t1 in
if Nir.Shape.size t1.Nir.Node.shape <= c
then
Logging.user_error ?loc (fun m ->
m "Vector accessed outside its bounds");
Nir.Node.gather_int t1 c
| Tapp (ls_atat (* @@ *), [ { t_node = Tapp (ls_nn (* nn *), _); _ }; t2 ])
when Why3.Term.ls_equal ls_atat th_model.Symbols.Model.atat -> (
match Language.lookup_nn ls_nn with
| Some nn -> convert_nn ?loc nn (convert_term t2)
| _ ->
Logging.code_error ~src (fun m ->
m "Neural network application without fixed NN: %a"
Why3.Pretty.print_term term))
| Tapp (ls (* input vector *), tl (* input vars *))
when Language.mem_vector ls ->
let inputs = List.map tl ~f:convert_term in
IR.Node.create (Concat { inputs; axis = 0 })
| Tconst (Why3.Constant.ConstReal r) ->
IR.Node.create
(Constant
{
data =
Nir.Gentensor.create_1_float (Utils.float_of_real_constant r);
})
| Tapp (ls, []) -> get_input ls
| Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.add ->
IR.Node.create (Add { input1 = convert_term a; input2 = convert_term b })
| Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.sub ->
IR.Node.create (Sub { input1 = convert_term a; input2 = convert_term b })
| Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.mul ->
IR.Node.create (Mul { input1 = convert_term a; input2 = convert_term b })
| Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.div -> (
match b.t_node with
| Tconst (Why3.Constant.ConstReal r) ->
let f = Utils.float_of_real_constant r in
Nir.Node.div_float (convert_term a) f
| _ ->
IR.Node.create
(Constant
{
data =
IR.Gentensor.create_1_float (Utils.float_of_real_constant r);
})
| Tapp (ls, []) -> get_input ls
| Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.add ->
IR.Node.create
(Add { input1 = convert_term a; input2 = convert_term b })
| Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.sub ->
IR.Node.create
(Sub { input1 = convert_term a; input2 = convert_term b })
| Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.mul ->
IR.Node.create
(Mul { input1 = convert_term a; input2 = convert_term b })
| Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.div -> (
match b.t_node with
| Tconst (Why3.Constant.ConstReal r) ->
let f = Utils.float_of_real_constant r in
Nir.Node.div_float (convert_term a) f
| _ ->
IR.Node.create
(Div { input1 = convert_term a; input2 = convert_term b }))
| Tapp (ls, [ a ]) when Why3.Term.ls_equal ls th_f64.neg ->
Nir.Node.mul_float (convert_term a) (-1.)
| Tconst _
| Tapp (_, _)
| Tif (_, _, _)
| Tlet (_, _)
| Tbinop (_, _, _)
| Tcase (_, _)
| Tnot _ | Ttrue | Tfalse ->
Logging.not_implemented_yet (fun m ->
m "Why3 term to IR: %a" Why3.Pretty.print_term term)
| Tvar _ | Teps _ | Tquant (_, _) ->
Logging.not_implemented_yet (fun m ->
m "Why3 term to IR (impossible?): %a" Why3.Pretty.print_term term))
(Div { input1 = convert_term a; input2 = convert_term b }))
| Tapp (ls, [ a ]) when Why3.Term.ls_equal ls th_f64.neg ->
Nir.Node.mul_float (convert_term a) (-1.)
| Tconst _
| Tapp (_, _)
| Tif (_, _, _)
| Tlet (_, _)
| Tbinop (_, _, _)
| Tcase (_, _)
| Tnot _ | Ttrue | Tfalse ->
Logging.not_implemented_yet (fun m ->
m "Why3 term to IR: %a" Why3.Pretty.print_term term)
| Tvar _ | Teps _ | Tquant (_, _) ->
Logging.not_implemented_yet (fun m ->
m "Why3 term to IR (impossible?): %a" Why3.Pretty.print_term term)
in
(* Start actual term conversion. *)
let outputs =
......@@ -259,19 +260,18 @@ let check_if_new_nn_needed env input_vars outputs =
(* Choose the term pattern for starting the conversion to ONNX. *)
let has_start_pattern env term =
let th_model = Symbols.Model.create env in
let th_f = Symbols.Float64.create env in
match match_nn_app th_model term with
| Some _ -> true
| None -> (
match term.Why3.Term.t_node with
| Tapp ({ ls_value = Some ty; _ }, []) ->
(* input symbol *) Why3.Ty.ty_equal ty th_f.ty
| Tapp (ls_app, _) ->
List.mem
[ th_f.add; th_f.sub; th_f.mul; th_f.div ]
~equal:Why3.Term.ls_equal ls_app
| _ -> false)
match term.Why3.Term.t_node with
| Tapp (ls_get (* [ ] *), _)
when String.equal ls_get.ls_name.id_string (Why3.Ident.op_get "") ->
true
| Tapp ({ ls_value = Some ty; _ }, []) ->
(* input symbol *) Why3.Ty.ty_equal ty th_f.ty
| Tapp (ls_app, _) ->
List.mem
[ th_f.add; th_f.sub; th_f.mul; th_f.div ]
~equal:Why3.Term.ls_equal ls_app
| _ -> false
(* Abstract terms that contains neural network applications. *)
let abstract_nn_term env =
......
......@@ -936,9 +936,9 @@ Test verify on acasxu
out/caisar_1.onnx has 1 input nodes
{'name': '221', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
out/caisar_2.onnx has 1 input nodes
{'name': '440', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
{'name': '439', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
out/caisar_3.onnx has 1 input nodes
{'name': '588', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}}
{'name': '587', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}}
out/caisar_4.onnx has 1 input nodes
{'name': '815', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
{'name': '814', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
5 files checked
......@@ -146,7 +146,7 @@ Test interpret fail
> EOF
$ caisar verify --prover nnenum file.mlw
[ERROR] "file.mlw", line 12, characters 14-23:
[ERROR] "file.mlw", line 12, characters 24-26:
Index constant 10 is out-of-bounds [0,4]
$ cat > file.mlw <<EOF
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment