Skip to content
Snippets Groups Projects
Commit 6dbda66b authored by Julien Girard-Satabin's avatar Julien Girard-Satabin Committed by Michele Alberti
Browse files

Able to parse inputs and outputs shape for ONNX

parent 73f663dc
No related branches found
No related tags found
No related merge requests found
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
(* *) (* *)
(**************************************************************************) (**************************************************************************)
module Opiqi = Onnx_piqi
type t = Onnx_piqi.Model_proto.t type t = Onnx_piqi.Model_proto.t
val parse : string -> (t, string) Result.t val parse : string -> (t, string) Result.t
......
This diff is collapsed.
This diff is collapsed.
(executable (executable
(name main) (name main)
(public_name caisar) (public_name caisar)
(libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet why3 dune-site) (libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet onnx why3 dune-site)
(preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq)) (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq))
(package caisar) (package caisar)
) )
......
...@@ -6,18 +6,18 @@ ...@@ -6,18 +6,18 @@
open Base open Base
(* -- Support for the NNet neural network format. *) (* -- Support for the NNet and ONNX neural network format. *)
type nnet = { type net = {
nb_inputs : int; nb_inputs : int;
nb_outputs : int; nb_outputs : int;
ty_data : Why3.Ty.ty; ty_data : Why3.Ty.ty;
filename : string; filename : string;
} }
let loaded_nnets = Why3.Term.Hls.create 10 let loaded_nets = Why3.Term.Hls.create 10
let lookup_loaded_nnets = Why3.Term.Hls.find_opt loaded_nnets let lookup_loaded_nets = Why3.Term.Hls.find_opt loaded_nets
let nnet_parser env _ filename _ = let nnet_parser env _ filename _ =
let open Why3 in let open Why3 in
...@@ -41,7 +41,7 @@ let nnet_parser env _ filename _ = ...@@ -41,7 +41,7 @@ let nnet_parser env _ filename _ =
(List.init header.n_inputs ~f) (List.init header.n_inputs ~f)
(Ty.ty_tuple (List.init header.n_outputs ~f)) (Ty.ty_tuple (List.init header.n_outputs ~f))
in in
Why3.Term.Hls.add loaded_nnets ls_nnet_apply Why3.Term.Hls.add loaded_nets ls_nnet_apply
{ {
filename; filename;
nb_inputs = header.n_inputs; nb_inputs = header.n_inputs;
...@@ -54,7 +54,79 @@ let nnet_parser env _ filename _ = ...@@ -54,7 +54,79 @@ let nnet_parser env _ filename _ =
in in
Wstdlib.Mstr.singleton "AsTuple" (Pmodule.close_module th_uc) Wstdlib.Mstr.singleton "AsTuple" (Pmodule.close_module th_uc)
let onnx_parser env _ filename _ =
let open Why3 in
let header = Onnx.parse filename in
match header with
| Error s -> Loc.errorm "%s" s
| Ok model ->
let onnx = Pmodule.read_module env [ "caisar" ] "NNet" in
let onnx_input_type =
Ty.ty_app
Theory.(ns_find_ts onnx.mod_theory.th_export [ "input_type" ])
[]
in
let id_as_tuple = Ident.id_fresh "AsTuple" in
let th_uc = Pmodule.create_module env id_as_tuple in
let th_uc = Pmodule.use_export th_uc onnx in
let ins, outs = match model.graph with
| Some g -> Some g.input, Some g.output
| None -> None, None
in
let get_nested_dims (s:Onnx.Opiqi.value_info_proto list) = match List.nth_exn s 0 with
| {type_ = Some {tensor_type =
Some {shape =
Some v; _};_}; _} ->
v.dim
| _ -> []
in
let input_shape, output_shape = match ins, outs with
| Some i, Some o ->
get_nested_dims i,
get_nested_dims o
| _-> [], []
in
(*TODO: here we only get the flattened dimension
* of inputs and outputs, but more interesting parsing
* could be done later on *)
let flattened_dim (dim:Onnx.Opiqi.tensor_shape_proto_dimension list) =
List.fold ~init:1 dim
~f:(fun acc x -> match x.dim_value with
| Some v -> acc * (Int64.to_int_exn v)
| None -> acc)
in
let input_flat_dim, output_flat_dim =
(flattened_dim input_shape),
(flattened_dim output_shape)
in
let ls_onnx_apply =
(*TODO: find out input and output size for ONNX*)
let f _ = onnx_input_type in
Term.create_fsymbol
(Ident.id_fresh "onnx_apply")
(List.init input_flat_dim ~f)
(Ty.ty_tuple (List.init output_flat_dim ~f))
in
Why3.Term.Hls.add loaded_nets ls_onnx_apply
{
filename;
nb_inputs = input_flat_dim;
nb_outputs = output_flat_dim;
ty_data = onnx_input_type;
};
let th_uc =
Pmodule.add_pdecl ~vc:false th_uc
(Pdecl.create_pure_decl (Decl.create_param_decl ls_onnx_apply))
in
Wstdlib.Mstr.singleton "AsTuple" (Pmodule.close_module th_uc)
let register_nnet_support () = let register_nnet_support () =
Why3.( Why3.(
Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language
"NNet" [ "nnet" ] nnet_parser) "NNet" [ "nnet" ] nnet_parser)
let register_onnx_support () =
Why3.(
Env.register_format ~desc:"ONNX format" Pmodule.mlw_language
"ONNX" [ "ONNX" ] onnx_parser)
...@@ -4,15 +4,18 @@ ...@@ -4,15 +4,18 @@
(* *) (* *)
(**************************************************************************) (**************************************************************************)
type nnet = { type net = {
nb_inputs : int; nb_inputs : int;
nb_outputs : int; nb_outputs : int;
ty_data : Why3.Ty.ty; ty_data : Why3.Ty.ty;
filename : string; filename : string;
} }
val lookup_loaded_nnets : Why3.Term.lsymbol -> nnet option val lookup_loaded_nets : Why3.Term.lsymbol -> net option
(** @return the filename of a nnet Why3 representation. *) (** @return the filename of a nnet Why3 representation. *)
val register_nnet_support : unit -> unit val register_nnet_support : unit -> unit
(** Register nnet parser. *) (** Register NNet parser. *)
val register_onnx_support : unit -> unit
(** Register ONNX parser. *)
...@@ -24,7 +24,7 @@ let get_input_variables = ...@@ -24,7 +24,7 @@ let get_input_variables =
let rec aux acc (term : Why3.Term.term) = let rec aux acc (term : Why3.Term.term) =
match term.t_node with match term.t_node with
| Why3.Term.Tapp (ls, args) -> ( | Why3.Term.Tapp (ls, args) -> (
match Language.lookup_loaded_nnets ls with match Language.lookup_loaded_nets ls with
| None -> acc | None -> acc
| Some _ -> | Some _ ->
let add i acc = function let add i acc = function
...@@ -48,7 +48,7 @@ let simplify_goal env input_variables = ...@@ -48,7 +48,7 @@ let simplify_goal env input_variables =
let rec aux hls (term : Why3.Term.term) = let rec aux hls (term : Why3.Term.term) =
match term.t_node with match term.t_node with
| Why3.Term.Tapp (ls, _) -> ( | Why3.Term.Tapp (ls, _) -> (
match Language.lookup_loaded_nnets ls with match Language.lookup_loaded_nets ls with
| None -> Why3.Term.t_map (aux hls) term | None -> Why3.Term.t_map (aux hls) term
| Some nnet -> | Some nnet ->
let outputs = let outputs =
......
...@@ -8,6 +8,7 @@ open Base ...@@ -8,6 +8,7 @@ open Base
module Filename = Caml.Filename module Filename = Caml.Filename
let () = Language.register_nnet_support () let () = Language.register_nnet_support ()
let () = Language.register_onnx_support ()
let create_env loadpath = let create_env loadpath =
let config = Autodetection.autodetect ~debug:true () in let config = Autodetection.autodetect ~debug:true () in
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment