diff --git a/lib/onnx/dune b/lib/onnx/dune
index 67bb0943c99bbe97d57ac1ccbf9261f0f9860edb..65ba5412c35cefcd13c2ea542342c4e022282fc5 100644
--- a/lib/onnx/dune
+++ b/lib/onnx/dune
@@ -4,12 +4,11 @@
   (libraries base stdio piqirun.pb ocaml-protoc-plugin)
   (synopsis "ONNX parser"))
 (rule
-  (targets onnx_protoc.ml onnx_protoc.mli)
+  (targets onnx_protoc.ml)
   (action
     (run ./generate_onnx_interface.sh)
     )
   (deps
     onnx_protoc.proto
     generate_onnx_interface.sh)
-  (mode
-    promote))
+  )
diff --git a/lib/onnx/generate_onnx_interface.sh b/lib/onnx/generate_onnx_interface.sh
index 8b38cd016e29e570e9e2fd6fda2249fa7fa4452a..5ba8d64d663c92f7e2e1026f4f8065a087777bc8 100755
--- a/lib/onnx/generate_onnx_interface.sh
+++ b/lib/onnx/generate_onnx_interface.sh
@@ -1,3 +1,2 @@
 #!/bin/sh
 protoc --ocaml_out=. onnx_protoc.proto
-ocamlfind ocamlc -package ocaml-protoc-plugin -i onnx_protoc.ml > onnx_protoc.mli
diff --git a/lib/onnx/onnx.ml b/lib/onnx/onnx.ml
index 4b169b93f19fa002d2eb0dfe1b5778768f47be08..427f0469a2acc6043e8dc5d6cca2b5201b91af00 100644
--- a/lib/onnx/onnx.ml
+++ b/lib/onnx/onnx.ml
@@ -11,38 +11,33 @@ module Oproto = Onnx_protoc (* Autogenerated during compilation *)
 
 module Oprotom = Oproto.Onnx.ModelProto
 
-type t = Oproto.Onnx.ModelProto.t [@@deriving show]
+type t = {
+  n_inputs : int;  (** Number of inputs. *)
+  n_outputs : int;  (** Number of outputs. *)
+}
 
 (* ONNX format handling. *)
 
-let parse_in_channel in_channel =
-  let open Result in
-  try
-    let buf = Stdio.In_channel.input_all in_channel in
-    let reader = Ocaml_protoc_plugin.Reader.create buf in
-    let res =
-      match Oprotom.from_proto reader with
-      | Ok r -> Ok r
-      | _ -> Error "Error parsing protobuf"
-    in
-    res
-  with
-  | Sys_error s -> Error s
-  | Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg)
+let get_nested_dims (s : Oproto.Onnx.ValueInfoProto.t list) =
+  match List.nth s 0 with
+  | Some { type' = Some { value = `Tensor_type { shape = Some v; _ }; _ }; _ }
+    ->
+    v
+  | _ -> []
 
-let get_input_output_shape (model : t) =
+let flattened_dim (dim : Oproto.Onnx.TensorShapeProto.Dimension.t list) =
+  List.fold ~init:1 dim ~f:(fun acc x ->
+    match x.value with
+    | `Dim_value v -> acc * v
+    | `Dim_param _ -> acc
+    | `not_set -> acc)
+
+let get_input_output_dim (model : Oprotom.t) =
   let ins, outs =
     match model.graph with
     | Some g -> (Some g.input, Some g.output)
     | None -> (None, None)
   in
-  let get_nested_dims (s : Oproto.Onnx.ValueInfoProto.t list) =
-    match List.nth s 0 with
-    | Some { type' = Some { value = `Tensor_type { shape = Some v; _ }; _ }; _ }
-      ->
-      v
-    | _ -> []
-  in
   let input_shape, output_shape =
     match (ins, outs) with
     | Some i, Some o -> (get_nested_dims i, get_nested_dims o)
@@ -50,17 +45,24 @@ let get_input_output_shape (model : t) =
   in
   (* TODO: here we only get the flattened dimension of inputs and outputs, but
      more interesting parsing could be done later on. *)
-  let flattened_dim (dim : Oproto.Onnx.TensorShapeProto.Dimension.t list) =
-    List.fold ~init:1 dim ~f:(fun acc x ->
-      match x.value with
-      | `Dim_value v -> acc * v
-      | `Dim_param _ -> acc
-      | `not_set -> acc)
-  in
   let input_flat_dim = flattened_dim input_shape in
   let output_flat_dim = flattened_dim output_shape in
   (input_flat_dim, output_flat_dim)
 
+let parse_in_channel in_channel =
+  let open Result in
+  try
+    let buf = Stdio.In_channel.input_all in_channel in
+    let reader = Ocaml_protoc_plugin.Reader.create buf in
+    match Oprotom.from_proto reader with
+    | Ok r ->
+      let n_inputs, n_outputs = get_input_output_dim r in
+      Ok { n_inputs; n_outputs }
+    | _ -> Error "Error parsing protobuf"
+  with
+  | Sys_error s -> Error s
+  | Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg)
+
 let parse filename =
   let in_channel = Stdlib.open_in filename in
   Fun.protect
diff --git a/lib/onnx/onnx.mli b/lib/onnx/onnx.mli
index b93ffd3b5699db96057b907c9ccb07e5051b1227..b017810b076de7a704598079545f57d161a5344e 100644
--- a/lib/onnx/onnx.mli
+++ b/lib/onnx/onnx.mli
@@ -4,14 +4,11 @@
 (*                                                                        *)
 (**************************************************************************)
 
-module Oproto = Onnx_protoc
-
-type t = Onnx_protoc.Onnx.ModelProto.t
-[@@deriving show]
+type t = private {
+  n_inputs : int;  (** Number of inputs. *)
+  n_outputs : int;  (** Number of outputs. *)
+}
+(** ONNX model metadata. *)
 
 val parse : string -> (t, string) Result.t
 (** Parse an ONNX file. *)
-
-val get_input_output_shape : t -> int*int
-(** Get the flattened input and output shape of a neural
-    network in an ONNX model. *)
diff --git a/src/language.ml b/src/language.ml
index a9e6d176a41aed8ac06baabac66b055e39769ab0..83f6936a76062be3f51547289df1b77810b47052 100644
--- a/src/language.ml
+++ b/src/language.ml
@@ -6,7 +6,7 @@
 
 open Base
 
-(* -- Support for the NNet and ONNX neural network format. *)
+(* -- Support for the NNet and ONNX neural network formats. *)
 
 type ioshape = {
   nb_inputs : int;
@@ -55,9 +55,7 @@ let onnx_parser env _ filename _ =
   let header = Onnx.parse filename in
   match header with
   | Error s -> Loc.errorm "%s" s
-  | Ok model ->
-    let input_flat_dim, output_flat_dim = Onnx.get_input_output_shape model in
-    register_astuple input_flat_dim output_flat_dim filename env
+  | Ok header -> register_astuple header.n_inputs header.n_outputs filename env
 
 let register_nnet_support () =
   Why3.(
diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index 3e84a490bb83c26db4d97f8c8c3574ce804b713f..15e77170ec6e9c4b6ba86803076840f2c217bab4 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -1,5 +1,4 @@
 (**************************************************************************)
-
 (*                                                                        *)
 (*  This file is part of CAISAR.                                          *)
 (*                                                                        *)