Skip to content
Snippets Groups Projects
Commit 764ea561 authored by Julien Girard-Satabin's avatar Julien Girard-Satabin
Browse files

Harmonized with nnet interface

parent 287a6972
No related branches found
No related tags found
No related merge requests found
......@@ -4,12 +4,11 @@
(libraries base stdio piqirun.pb ocaml-protoc-plugin)
(synopsis "ONNX parser"))
(rule
(targets onnx_protoc.ml onnx_protoc.mli)
(targets onnx_protoc.ml)
(action
(run ./generate_onnx_interface.sh)
)
(deps
onnx_protoc.proto
generate_onnx_interface.sh)
(mode
promote))
)
#!/bin/sh
protoc --ocaml_out=. onnx_protoc.proto
ocamlfind ocamlc -package ocaml-protoc-plugin -i onnx_protoc.ml > onnx_protoc.mli
......@@ -11,38 +11,33 @@ module Oproto = Onnx_protoc (* Autogenerated during compilation *)
module Oprotom = Oproto.Onnx.ModelProto
type t = Oproto.Onnx.ModelProto.t [@@deriving show]
type t = {
n_inputs : int; (** Number of inputs. *)
n_outputs : int; (** Number of outputs. *)
}
(* ONNX format handling. *)
let parse_in_channel in_channel =
let open Result in
try
let buf = Stdio.In_channel.input_all in_channel in
let reader = Ocaml_protoc_plugin.Reader.create buf in
let res =
match Oprotom.from_proto reader with
| Ok r -> Ok r
| _ -> Error "Error parsing protobuf"
in
res
with
| Sys_error s -> Error s
| Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg)
let get_nested_dims (s : Oproto.Onnx.ValueInfoProto.t list) =
match List.nth s 0 with
| Some { type' = Some { value = `Tensor_type { shape = Some v; _ }; _ }; _ }
->
v
| _ -> []
let get_input_output_shape (model : t) =
let flattened_dim (dim : Oproto.Onnx.TensorShapeProto.Dimension.t list) =
List.fold ~init:1 dim ~f:(fun acc x ->
match x.value with
| `Dim_value v -> acc * v
| `Dim_param _ -> acc
| `not_set -> acc)
let get_input_output_dim (model : Oprotom.t) =
let ins, outs =
match model.graph with
| Some g -> (Some g.input, Some g.output)
| None -> (None, None)
in
let get_nested_dims (s : Oproto.Onnx.ValueInfoProto.t list) =
match List.nth s 0 with
| Some { type' = Some { value = `Tensor_type { shape = Some v; _ }; _ }; _ }
->
v
| _ -> []
in
let input_shape, output_shape =
match (ins, outs) with
| Some i, Some o -> (get_nested_dims i, get_nested_dims o)
......@@ -50,17 +45,24 @@ let get_input_output_shape (model : t) =
in
(* TODO: here we only get the flattened dimension of inputs and outputs, but
more interesting parsing could be done later on. *)
let flattened_dim (dim : Oproto.Onnx.TensorShapeProto.Dimension.t list) =
List.fold ~init:1 dim ~f:(fun acc x ->
match x.value with
| `Dim_value v -> acc * v
| `Dim_param _ -> acc
| `not_set -> acc)
in
let input_flat_dim = flattened_dim input_shape in
let output_flat_dim = flattened_dim output_shape in
(input_flat_dim, output_flat_dim)
let parse_in_channel in_channel =
let open Result in
try
let buf = Stdio.In_channel.input_all in_channel in
let reader = Ocaml_protoc_plugin.Reader.create buf in
match Oprotom.from_proto reader with
| Ok r ->
let n_inputs, n_outputs = get_input_output_dim r in
Ok { n_inputs; n_outputs }
| _ -> Error "Error parsing protobuf"
with
| Sys_error s -> Error s
| Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg)
let parse filename =
let in_channel = Stdlib.open_in filename in
Fun.protect
......
......@@ -4,14 +4,11 @@
(* *)
(**************************************************************************)
module Oproto = Onnx_protoc
type t = Onnx_protoc.Onnx.ModelProto.t
[@@deriving show]
type t = private {
n_inputs : int; (** Number of inputs. *)
n_outputs : int; (** Number of outputs. *)
}
(** ONNX model metadata. *)
val parse : string -> (t, string) Result.t
(** Parse an ONNX file. *)
val get_input_output_shape : t -> int*int
(** Get the flattened input and output shape of a neural
network in an ONNX model. *)
......@@ -6,7 +6,7 @@
open Base
(* -- Support for the NNet and ONNX neural network format. *)
(* -- Support for the NNet and ONNX neural network formats. *)
type ioshape = {
nb_inputs : int;
......@@ -55,9 +55,7 @@ let onnx_parser env _ filename _ =
let header = Onnx.parse filename in
match header with
| Error s -> Loc.errorm "%s" s
| Ok model ->
let input_flat_dim, output_flat_dim = Onnx.get_input_output_shape model in
register_astuple input_flat_dim output_flat_dim filename env
| Ok header -> register_astuple header.n_inputs header.n_outputs filename env
let register_nnet_support () =
Why3.(
......
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment