diff --git a/lib/onnx/dune b/lib/onnx/dune index 67bb0943c99bbe97d57ac1ccbf9261f0f9860edb..65ba5412c35cefcd13c2ea542342c4e022282fc5 100644 --- a/lib/onnx/dune +++ b/lib/onnx/dune @@ -4,12 +4,11 @@ (libraries base stdio piqirun.pb ocaml-protoc-plugin) (synopsis "ONNX parser")) (rule - (targets onnx_protoc.ml onnx_protoc.mli) + (targets onnx_protoc.ml) (action (run ./generate_onnx_interface.sh) ) (deps onnx_protoc.proto generate_onnx_interface.sh) - (mode - promote)) + ) diff --git a/lib/onnx/generate_onnx_interface.sh b/lib/onnx/generate_onnx_interface.sh index 8b38cd016e29e570e9e2fd6fda2249fa7fa4452a..5ba8d64d663c92f7e2e1026f4f8065a087777bc8 100755 --- a/lib/onnx/generate_onnx_interface.sh +++ b/lib/onnx/generate_onnx_interface.sh @@ -1,3 +1,2 @@ #!/bin/sh protoc --ocaml_out=. onnx_protoc.proto -ocamlfind ocamlc -package ocaml-protoc-plugin -i onnx_protoc.ml > onnx_protoc.mli diff --git a/lib/onnx/onnx.ml b/lib/onnx/onnx.ml index 4b169b93f19fa002d2eb0dfe1b5778768f47be08..427f0469a2acc6043e8dc5d6cca2b5201b91af00 100644 --- a/lib/onnx/onnx.ml +++ b/lib/onnx/onnx.ml @@ -11,38 +11,33 @@ module Oproto = Onnx_protoc (* Autogenerated during compilation *) module Oprotom = Oproto.Onnx.ModelProto -type t = Oproto.Onnx.ModelProto.t [@@deriving show] +type t = { + n_inputs : int; (** Number of inputs. *) + n_outputs : int; (** Number of outputs. *) +} (* ONNX format handling. *) -let parse_in_channel in_channel = - let open Result in - try - let buf = Stdio.In_channel.input_all in_channel in - let reader = Ocaml_protoc_plugin.Reader.create buf in - let res = - match Oprotom.from_proto reader with - | Ok r -> Ok r - | _ -> Error "Error parsing protobuf" - in - res - with - | Sys_error s -> Error s - | Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg) +let get_nested_dims (s : Oproto.Onnx.ValueInfoProto.t list) = + match List.nth s 0 with + | Some { type' = Some { value = `Tensor_type { shape = Some v; _ }; _ }; _ } + -> + v + | _ -> [] -let get_input_output_shape (model : t) = +let flattened_dim (dim : Oproto.Onnx.TensorShapeProto.Dimension.t list) = + List.fold ~init:1 dim ~f:(fun acc x -> + match x.value with + | `Dim_value v -> acc * v + | `Dim_param _ -> acc + | `not_set -> acc) + +let get_input_output_dim (model : Oprotom.t) = let ins, outs = match model.graph with | Some g -> (Some g.input, Some g.output) | None -> (None, None) in - let get_nested_dims (s : Oproto.Onnx.ValueInfoProto.t list) = - match List.nth s 0 with - | Some { type' = Some { value = `Tensor_type { shape = Some v; _ }; _ }; _ } - -> - v - | _ -> [] - in let input_shape, output_shape = match (ins, outs) with | Some i, Some o -> (get_nested_dims i, get_nested_dims o) @@ -50,17 +45,24 @@ let get_input_output_shape (model : t) = in (* TODO: here we only get the flattened dimension of inputs and outputs, but more interesting parsing could be done later on. *) - let flattened_dim (dim : Oproto.Onnx.TensorShapeProto.Dimension.t list) = - List.fold ~init:1 dim ~f:(fun acc x -> - match x.value with - | `Dim_value v -> acc * v - | `Dim_param _ -> acc - | `not_set -> acc) - in let input_flat_dim = flattened_dim input_shape in let output_flat_dim = flattened_dim output_shape in (input_flat_dim, output_flat_dim) +let parse_in_channel in_channel = + let open Result in + try + let buf = Stdio.In_channel.input_all in_channel in + let reader = Ocaml_protoc_plugin.Reader.create buf in + match Oprotom.from_proto reader with + | Ok r -> + let n_inputs, n_outputs = get_input_output_dim r in + Ok { n_inputs; n_outputs } + | _ -> Error "Error parsing protobuf" + with + | Sys_error s -> Error s + | Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg) + let parse filename = let in_channel = Stdlib.open_in filename in Fun.protect diff --git a/lib/onnx/onnx.mli b/lib/onnx/onnx.mli index b93ffd3b5699db96057b907c9ccb07e5051b1227..b017810b076de7a704598079545f57d161a5344e 100644 --- a/lib/onnx/onnx.mli +++ b/lib/onnx/onnx.mli @@ -4,14 +4,11 @@ (* *) (**************************************************************************) -module Oproto = Onnx_protoc - -type t = Onnx_protoc.Onnx.ModelProto.t -[@@deriving show] +type t = private { + n_inputs : int; (** Number of inputs. *) + n_outputs : int; (** Number of outputs. *) +} +(** ONNX model metadata. *) val parse : string -> (t, string) Result.t (** Parse an ONNX file. *) - -val get_input_output_shape : t -> int*int -(** Get the flattened input and output shape of a neural - network in an ONNX model. *) diff --git a/src/language.ml b/src/language.ml index a9e6d176a41aed8ac06baabac66b055e39769ab0..83f6936a76062be3f51547289df1b77810b47052 100644 --- a/src/language.ml +++ b/src/language.ml @@ -6,7 +6,7 @@ open Base -(* -- Support for the NNet and ONNX neural network format. *) +(* -- Support for the NNet and ONNX neural network formats. *) type ioshape = { nb_inputs : int; @@ -55,9 +55,7 @@ let onnx_parser env _ filename _ = let header = Onnx.parse filename in match header with | Error s -> Loc.errorm "%s" s - | Ok model -> - let input_flat_dim, output_flat_dim = Onnx.get_input_output_shape model in - register_astuple input_flat_dim output_flat_dim filename env + | Ok header -> register_astuple header.n_inputs header.n_outputs filename env let register_nnet_support () = Why3.( diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 3e84a490bb83c26db4d97f8c8c3574ce804b713f..15e77170ec6e9c4b6ba86803076840f2c217bab4 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -1,5 +1,4 @@ (**************************************************************************) - (* *) (* This file is part of CAISAR. *) (* *)