diff --git a/lib/ir/nier_cfg.ml b/lib/ir/nier_cfg.ml
deleted file mode 100644
index ad7f5bd5f84df85af6356fcbcd3c28cac917bca2..0000000000000000000000000000000000000000
--- a/lib/ir/nier_cfg.ml
+++ /dev/null
@@ -1,515 +0,0 @@
-open Base
-open Stdio
-open Bigarray
-
-module Tensor = struct
-  type ('a, 'b) t = ('a, 'b, c_layout) Genarray.t
-  type shape = int array [@@deriving show]
-
-  type ('a, 'b) t_kind =
-    | K_int : (int64, int64_elt) t_kind
-    | K_float : (float, float64_elt) t_kind
-
-  let create : type a b. shape -> (a, b) t_kind -> (a, b) t =
-   fun shape -> function
-    | K_float -> Genarray.create float64 c_layout shape
-    | K_int -> Genarray.create int64 c_layout shape
-
-  let unsqueeze ~sh1 ~sh2 =
-    let sh1, sh2 = (Array.to_list sh1, Array.to_list sh2) in
-    let longest, shortest =
-      match List.length sh1 > List.length sh2 with
-      | true -> (sh1, sh2)
-      | false -> (sh2, sh1)
-    in
-    (*find the index of the potential additional dimension*)
-    let where_zero =
-      match List.nth_exn longest 0 with
-      | 0 -> Some 0
-      | _ -> (
-        match List.last_exn longest with
-        | 0 -> Some (List.length longest - 1)
-        | _ -> None)
-    in
-    match where_zero with
-    | Some idx -> (
-      match List.sub longest ~pos:idx ~len:(List.length shortest) with
-      | [] -> None
-      | _ -> Some (Array.of_list longest))
-    | None -> None
-
-  let get t idx = Genarray.get t idx
-  let set t idx v = Genarray.set t idx v
-
-  let all_coords sh =
-    let sh = Array.to_list sh in
-    let rec ranges acc shape =
-      match shape with
-      | x :: y -> ranges (List.init x ~f:(fun i -> i) :: acc) y
-      | [] -> acc
-      (* a list containing, for each dimension, the list of all possible indexes *)
-    in
-    let xxs = ranges [] sh in
-    (* add to each element of the list of all possible coordinates all
-       possible indexes ... *)
-    let aux acc xs =
-      List.concat
-      @@ List.map xs ~f:(fun x -> List.map ~f:(fun lt -> x :: lt) acc)
-      (* ... for each dimension, starting from an empty list of
-         possible coordinates *)
-    in
-    List.fold xxs ~init:[ [] ] ~f:aux
-
-  let flatten t =
-    let shape = Genarray.dims t in
-    let all_idxs = all_coords shape in
-    List.init (List.length all_idxs) ~f:(fun i ->
-      get t (Array.of_list @@ List.nth_exn all_idxs i))
-
-  let get_shape t = Genarray.dims t
-
-  let equal f t1 t2 =
-    let t1_sh = get_shape t1 and t2_sh = get_shape t2 in
-    if Array.equal ( = ) t1_sh t2_sh
-    then
-      let all_idxs = all_coords (get_shape t1) in
-      List.fold
-        ~f:(fun acc x ->
-          if acc
-          then f (get t1 (Array.of_list x)) (get t2 (Array.of_list x))
-          else false)
-        all_idxs ~init:true
-    else false
-
-  let num_neurons sh = Array.fold ~init:1 ~f:(fun x y -> x * y) sh
-
-  let get_flatnd_idx ~idx ~sh flt =
-    let sh = Array.to_list sh in
-    let pop_sh = List.tl_exn sh @ [ 1 ] in
-    let rec get_factors_from_sh sh_f l =
-      match sh_f with
-      | [] -> List.rev l
-      | _ ->
-        get_factors_from_sh (List.tl_exn sh_f)
-          (List.fold ~f:(fun x y -> x * y) ~init:1 sh_f :: l)
-    in
-    let factors = get_factors_from_sh pop_sh [] in
-    let coord_in_data =
-      match
-        List.fold2
-          ~f:(fun x y z -> x + (y * z))
-          ~init:0 (Array.to_list idx) factors
-      with
-      | List.Or_unequal_lengths.Ok i -> i
-      | List.Or_unequal_lengths.Unequal_lengths ->
-        failwith "Unequal lengths in get_flatnd_idx"
-    in
-    List.nth_exn flt coord_in_data
-
-  let transpose_2d _t = assert false
-end
-
-(* TODO: maybe add markers for special nodes, to reflect that they are the
-   inputs and outputs of the neural network? *)
-module Node = struct
-  type shape = int array
-
-  let show_shape sh =
-    let sh = Array.to_list sh in
-    match sh with
-    | [] -> "[]"
-    | x :: y ->
-      "[" ^ Int.to_string x
-      ^ List.fold ~init:"" ~f:(fun str e -> str ^ ";" ^ Int.to_string e) y
-      ^ "]"
-
-  type operator =
-    | Add
-    | Sub
-    | Mul
-    | Div
-    | Matmul
-    | Gemm
-    | LogSoftmax
-    | ReLu
-    | Transpose
-    | Squeeze
-    | MaxPool
-    | Conv
-    | Reshape
-    | Flatten
-    | Identity
-    | Constant
-    | NO_OP
-    | RW_Linearized_ReLu
-    | Gather
-    | ReduceSum
-    | GatherND
-    | RandomNormal
-    | Abs
-    | Log
-
-  let str_op o =
-    match o with
-    | Add -> "Add"
-    | Sub -> "Sub"
-    | Mul -> "Mul"
-    | Div -> "Div"
-    | Matmul -> "Matmul"
-    | Gemm -> "Gemm"
-    | LogSoftmax -> "LogSoftmax"
-    | ReLu -> "ReLu"
-    | Transpose -> "Transpose"
-    | Squeeze -> "Squeeze"
-    | MaxPool -> "MaxPool"
-    | Conv -> "Conv"
-    | Reshape -> "Reshape"
-    | Flatten -> "Flatten"
-    | Identity -> "Identity"
-    | Constant -> "Constant"
-    | NO_OP -> "NO_OP"
-    | RW_Linearized_ReLu -> "RW_Linearized_ReLu"
-    | Gather -> "Gather"
-    | ReduceSum -> "ReduceSum"
-    | GatherND -> "GatherND"
-    | RandomNormal -> "RandomNormal"
-    | Abs -> "Abs"
-    | Log -> "Log"
-
-  type ksize = Ksize of shape
-  type stride = Stride of shape
-  type pads = Pads of shape
-  type dilations = Dilations of shape
-
-  type operator_parameters =
-    | Pool_params of (ksize * stride option * pads option * dilations option)
-    | Conv_params of (ksize * stride option * pads option * dilations option)
-    | Transpose_params of shape
-    | RW_Linearized_ReLu_params of
-        (bool list list * ((string, float) Base.Hashtbl.t list * int))
-    | Gather_params of int
-    | ReduceSum_params of int * int
-    | RandomNormal_params of int * float * float * float * shape
-
-  let str_op_params p =
-    match p with
-    | Transpose_params s ->
-      let str_sh = show_shape s in
-      "Transpose params: " ^ str_sh
-    | Pool_params (Ksize k, s, p, d) | Conv_params (Ksize k, s, p, d) ->
-      let str_k = show_shape k
-      and str_s = match s with None -> "" | Some (Stride ss) -> show_shape ss
-      and str_p = match p with None -> "" | Some (Pads pp) -> show_shape pp
-      and str_d =
-        match d with None -> "" | Some (Dilations dd) -> show_shape dd
-      in
-      "Pool params: KSIZE: " ^ str_k ^ ", Pads: " ^ str_p ^ ", Stride: " ^ str_s
-      ^ ", Dilations: " ^ str_d
-    | RW_Linearized_ReLu_params l ->
-      (* Only displays the activation scheme on the ReLU node *)
-      let activations = fst l in
-      let act' =
-        List.map
-          ~f:(fun l1 ->
-            List.map
-              ~f:(fun b -> match b with true -> "true" | false -> "false")
-              l1)
-          activations
-      in
-      let act'' =
-        List.map ~f:(fun l -> "[" ^ String.concat ~sep:";" l ^ "]") act'
-      in
-      let act''' = "[" ^ String.concat ~sep:";" act'' ^ "]" in
-      "RW_Linearized_ReLu_params: " ^ act'''
-    | Gather_params a -> "Gather params: " ^ Int.to_string a
-    | ReduceSum_params (a, b) ->
-      "ReduceSum params: " ^ Int.to_string a ^ Int.to_string b
-    | RandomNormal_params (a, b, c, d, s) ->
-      let str_sh = show_shape s in
-      "RandomNormal params: " ^ Int.to_string a ^ Float.to_string b
-      ^ Float.to_string c ^ Float.to_string d ^ str_sh
-
-  type ('a, 'b) t = {
-    id : int;
-    name : string option;
-    shape : shape;
-    operator : operator;
-    operator_parameters : operator_parameters option;
-    pred : string list;
-    succ : string list;
-    tensor : ('a, 'b) Tensor.t option;
-  }
-
-  let compare v1 v2 = Stdlib.compare v1.id v2.id
-  let hash (v : ('a, 'b) t) = v.id
-  let equal v1 v2 = v1.id = v2.id
-
-  let create ~id ~name ~sh ~op ~op_p ~pred ~succ ~tensor =
-    {
-      id;
-      name;
-      shape = sh;
-      operator = op;
-      operator_parameters = op_p;
-      pred;
-      succ;
-      tensor;
-    }
-
-  let get_name t = match t.name with Some n -> n | None -> "C_NODE"
-  let get_shape t = t.shape
-  let get_op t = t.operator
-  let get_tensor t = t.tensor
-  let get_pred_list t = t.pred
-  let get_succ_list t = t.succ
-  let is_data_node t = match get_tensor t with None -> false | Some _ -> true
-
-  (* TODO: some flags on the node would be cleaner than this*)
-  let is_input_node t = List.equal String.equal t.pred [ "NO_INPUT" ]
-  let is_output_node t = List.equal String.equal t.succ [ "NO_OUTPUT" ]
-
-  let num_neurons t =
-    match get_shape t with
-    | [||] -> 0
-    | l -> Array.fold ~init:1 ~f:(fun x acc -> x * acc) l
-
-  let show n f =
-    let id = Int.to_string n.id in
-    let name = get_name n
-    and operator = str_op n.operator
-    and operator_parameters =
-      match n.operator_parameters with
-      | Some p -> str_op_params p
-      | None -> "no parameters"
-    and shape = show_shape n.shape
-    and prevs =
-      List.fold_left ~f:(fun x y -> x ^ "," ^ y) ~init:"" (get_pred_list n)
-    and nexts =
-      List.fold_left ~f:(fun x y -> x ^ "," ^ y) ~init:"" (get_succ_list n)
-    and tensor =
-      match n.tensor with
-      (* limit the size of tensor strings, to comply with the
-         dot string size limit of 16 KB *)
-      | Some t ->
-        let display_indices =
-          let all_indices = Tensor.all_coords (Tensor.get_shape t) in
-          if List.length all_indices > 10
-          then
-            let rec firstk k xs =
-              match xs with
-              | [] -> failwith "firstk"
-              | x :: xs -> if k = 1 then [ x ] else x :: firstk (k - 1) xs
-            in
-            firstk 10 all_indices
-          else all_indices
-        in
-        let t_value_string f =
-          List.fold_left
-            ~f:(fun acc l ->
-              acc
-              ^ show_shape (Array.of_list l)
-              ^ ": "
-              ^ f (Tensor.get t (Array.of_list l))
-              ^ "\n")
-            ~init:"" display_indices
-        in
-        "Tensor value\n: " ^ t_value_string f ^ "\nShape: "
-        ^ show_shape (Tensor.get_shape t)
-      | None -> "No tensor in node"
-    in
-    "ID :" ^ id ^ "\nNAME: " ^ name ^ "\nOP: " ^ operator ^ "\nOP PARAMS:"
-    ^ operator_parameters ^ "\nSHAPE: " ^ shape ^ "\nPREVS: " ^ prevs
-    ^ "\nNEXTS: " ^ nexts ^ "\nTENSORS INFOS:" ^ tensor
-end
-
-module type VInput = sig
-  type l
-  type r
-
-  val convert_f : l -> string
-end
-
-module MakeVertex (I : VInput) = struct
-  type t = (I.l, I.r) Node.t
-
-  let compare = Node.compare
-  let hash = Node.hash
-  let equal = Node.equal
-  let convert_f = I.convert_f
-
-  type label = string
-
-  let label (n : t) = match n.Node.name with Some n -> n | None -> ""
-  let create _name = assert false
-end
-
-module Edge = struct
-  type t = string
-
-  let compare = Stdlib.compare
-  let equal = phys_equal
-  let default = ""
-end
-
-module NierCFG (I : VInput) = struct
-  module Vertex = MakeVertex (I)
-  include Graph.Imperative.Digraph.ConcreteBidirectionalLabeled (Vertex) (Edge)
-
-  let convert_f = Vertex.convert_f
-  let vertex_list g = fold_vertex (fun x l -> x :: l) g []
-
-  let input_nodes g =
-    let input_criterion (v : ('a, 'b) Node.t) acc =
-      match v.id with 0 -> Some v | _ -> acc
-    in
-    match fold_vertex (fun v acc -> input_criterion v acc) g None with
-    | Some r -> [ r ]
-    | None -> failwith "Something strange, no node for describing inputs found"
-
-  let preds g v = pred g v
-
-  let preds_names g v =
-    let preds_list = pred_e g v in
-    List.fold ~init:[] ~f:(fun acc (_, n, _) -> n :: acc) preds_list
-
-  let succs_names g v =
-    let succs_list = succ_e g v in
-    List.fold ~init:[] ~f:(fun acc (_, n, _) -> n :: acc) succs_list
-
-  let succs g v = succ g v
-  let init_cfg = create ()
-
-  let find_vertices g f =
-    fold_vertex (fun x l -> if f x then x :: l else l) g []
-
-  let data_node_of n g =
-    fold_pred (fun v _ -> if Node.is_data_node v then Some v else None) g n None
-
-  let infer_shape g n in_shape ~on_backward =
-    let op = Node.get_op n in
-    match op with
-    | Node.Add -> (
-      match data_node_of n g with
-      | Some d_n -> Node.get_shape d_n
-      | None -> failwith "Error, Add operator lacks a data node")
-    | Node.ReLu -> in_shape
-    | Node.Matmul ->
-      let pad_left = function
-        | [] -> failwith "Impossible to pad empty shape"
-        | [ a ] -> [ 1; a ]
-        | x -> x
-      in
-      let pad_right = function
-        | [] -> failwith "Impossible to pad empty shape"
-        | [ a ] -> [ a; 1 ]
-        | x -> x
-      in
-      let rec one_padding l i =
-        if i <= 0 then l else one_padding (1 :: l) (i - 1)
-      in
-      let dn_shape =
-        match data_node_of n g with
-        | Some dn -> Node.get_shape dn
-        | None -> failwith "Error, MatMul operator lacks a data node"
-      in
-      (* Expected semantic:
-       * Matrix multiplication C = AB
-       * A (shape [n;m]); B (shape [m;p]); C (shape [n;p])
-       * shape of b: b_sh
-       * shape of a: a_sh
-       * shape of c: c_sh
-       * It is expected here that B is the shape of the node
-       * yielding the data tensor in the NIER
-       *)
-      let check_matmul_size_ba ~b_sh ~a_sh =
-        let bdim2 = pad_left b_sh in
-        let adim2 = pad_right a_sh in
-        let bdim = one_padding bdim2 (List.length adim2 - List.length bdim2) in
-        let adim = one_padding adim2 (List.length bdim2 - List.length adim2) in
-        let rec infer_csize acc ad bd =
-          match (ad, bd) with
-          | [ m; n ], [ nn; p ] ->
-            if nn = n
-            then (n, List.append (List.rev acc) [ m; p ])
-            else failwith "size of matrices not adequate"
-          | a :: la, b :: lb ->
-            if a = b
-            then infer_csize (a :: acc) la lb
-            else if a = 1
-            then infer_csize (b :: acc) la lb
-            else if b = 1
-            then infer_csize (a :: acc) la lb
-            else failwith "Checking matmul_size failed: one discordance"
-          | _, _ -> failwith "Checking matmul_size failed"
-        in
-        infer_csize [] bdim adim
-      in
-      let check_matmul_size_bc ~b_sh ~c_sh =
-        let bdim2 = pad_left b_sh in
-        let cdim2 = pad_right c_sh in
-        let bdim = one_padding bdim2 (List.length cdim2 - List.length bdim2) in
-        let cdim = one_padding cdim2 (List.length bdim2 - List.length cdim2) in
-        let rec infer_asize acc bd cd =
-          match (bd, cd) with
-          | [ m; p ], [ n; pp ] ->
-            if pp = p
-            then (n, List.append (List.rev acc) [ n; m ])
-            else failwith "size of matrices not adequate"
-          | b :: lb, c :: lc ->
-            if b = c
-            then infer_asize (b :: acc) lb lc
-            else if b = 1
-            then infer_asize (b :: acc) lb lc
-            else if c = 1
-            then infer_asize (c :: acc) lb lc
-            else failwith "Checking matmul_size failed: one discordance"
-          | _, _ -> failwith "Checking matmul_size failed"
-        in
-        infer_asize [] bdim cdim
-      in
-      if on_backward
-      then
-        Array.of_list
-        @@ snd
-             (check_matmul_size_bc ~b_sh:(Array.to_list dn_shape)
-                ~c_sh:(Array.to_list in_shape))
-      else
-        Array.of_list
-        @@ snd
-             (check_matmul_size_ba ~b_sh:(Array.to_list in_shape)
-                ~a_sh:(Array.to_list dn_shape))
-    | a -> failwith (Printf.sprintf "operator %s not supported" (Node.str_op a))
-end
-
-module NierCFGInt = NierCFG (struct
-  type l = int64
-  type r = int64_elt
-
-  let convert_f = Int64.to_string
-end)
-
-module NierCFGFloat = NierCFG (struct
-  type l = float
-  type r = float64_elt
-
-  let convert_f = Float.to_string
-end)
-
-module NierCFGDot = Graph.Graphviz.Dot (struct
-  include NierCFGFloat (* use the graph module from above *)
-
-  let node_label (v : vertex) = Node.show v convert_f
-  let edge_attributes (_, e, _) = [ `Label e; `Color 4711 ]
-  let default_edge_attributes _ = []
-  let get_subgraph _ = None
-  let vertex_attributes v = [ `Shape `Box; `Label (node_label v) ]
-  let vertex_name (v : vertex) = Int.to_string v.id
-  let default_vertex_attributes _ = []
-  let graph_attributes _ = []
-end)
-
-let print_cfg_graph g = NierCFGDot.fprint_graph Stdlib.Format.std_formatter g
-
-let out_cfg_graph g =
-  let file = Out_channel.create "cfg.dot" in
-  NierCFGDot.output_graph file g;
-  Out_channel.close file
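
Note on the removed Tensor module: get_flatnd_idx implements standard
row-major addressing, deriving strides from the tail of the shape. A
minimal standalone sketch of the same computation (hypothetical helper in
plain OCaml Stdlib, not part of the removed API):

    (* Row-major flat index: scale the accumulator by each dimension size,
       then add that dimension's coordinate. *)
    let flat_index ~idx ~sh =
      let r = ref 0 in
      Array.iteri (fun k d -> r := (!r * d) + idx.(k)) sh;
      !r

    let () =
      (* element [|1; 2|] of a 3x4 tensor sits at flat offset 1*4 + 2 = 6 *)
      assert (flat_index ~idx:[| 1; 2 |] ~sh:[| 3; 4 |] = 6)

This is the same arithmetic as Shape.row_major in nier_simple.ml below;
Shape.unrow_major inverts it.
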
diff --git a/lib/ir/nier_cfg.mli b/lib/ir/nier_cfg.mli
deleted file mode 100644
index acacd98ac6781091b86950638de57e4d9f8b1bc1..0000000000000000000000000000000000000000
--- a/lib/ir/nier_cfg.mli
+++ /dev/null
@@ -1,283 +0,0 @@
-(** This module defines the structure and interfaces for a Neural IntermediatE
-    Representation (NIER).
-
-    It is primarily designed as an intermediate step in producing verifiable
-    terms from an ONNX model. *)
-
-open Base
-open Bigarray
-
-(** {1 Tensor module} *)
-
-(** Tensors are multidimensional arrays used to represent numerical data,
-    such as neural network weights. *)
-
-module Tensor : sig
-  type ('a, 'b) t = ('a, 'b, c_layout) Genarray.t
-  type shape = int array [@@deriving show]
-
-  val all_coords : shape -> int list list
-
-  (** [create sh k] initializes a tensor of shape [sh] with a default value
-      that depends on the kind [k] of the tensor. *)
-
-  type ('a, 'b) t_kind =
-    | K_int : (int64, int64_elt) t_kind
-    | K_float : (float, float64_elt) t_kind
-
-  val create : shape -> ('a, 'b) t_kind -> ('a, 'b) t
-
-  (** [get t idx] returns the value in tensor [t] stored at coordinates [idx].
-      Throws an error if the coordinates are invalid. *)
-
-  val get : ('a, 'b) t -> shape -> 'a
-
-  (** [set t idx v] sets value [v] for tensor [t] at [idx]. Throws an error
-      if the coordinates are invalid. *)
-
-  val set : ('a, 'b) t -> shape -> 'a -> unit
-
-  (** [equal f t1 t2] applies [f] pointwise to the values of [t1] and [t2];
-      returns true iff the shapes agree and all applications of [f] hold. *)
-
-  val equal : ('a -> 'a -> bool) -> ('a, 'b) t -> ('a, 'b) t -> bool
-
-  (** [get_shape t] returns the shape of [t]. *)
-
-  val get_shape : ('a, 'b) t -> shape
-
-  (** [flatten t] returns a flattened version of [t]. *)
-
-  val flatten : ('a, 'b) t -> 'a list
-
-  (** [num_neurons sh] returns the total number of neurons given a shape *)
-
-  val num_neurons : shape -> int
-
-  (** [get_flatnd_idx ~idx ~sh flt] returns the value that would be stored at
-      index [idx] in a tensor of shape [sh], given the flattened version [flt]
-      of this tensor. *)
-
-  val get_flatnd_idx : idx:shape -> sh:shape -> 'a list -> 'a
-
-  (** [transpose_2d t] returns a copy of the tensor [t] with its last two
-      dimensions exchanged. *)
-
-  val transpose_2d : ('a, 'b) t -> ('a, 'b) t
-
-  (** [unsqueeze ~sh1 ~sh2] returns the lowest common shape between [sh1] and
-      [sh2], and None if there is no common shape. A common shape exists when
-      the shape of higher rank has size 1 on the dimensions it does not share
-      with the other shape. *)
-
-  val unsqueeze : sh1:shape -> sh2:shape -> shape option
-end
-
-(** {1 Modules for graph generation} *)
-
-module Node : sig
-  type shape = int array
-
-  type operator =
-    | Add
-    | Sub
-    | Mul
-    | Div
-    | Matmul
-    | Gemm
-    | LogSoftmax
-    | ReLu
-    | Transpose
-    | Squeeze
-    | MaxPool
-    | Conv
-    | Reshape
-    | Flatten
-    | Identity
-    | Constant
-    | NO_OP
-    | RW_Linearized_ReLu
-    | Gather
-    | ReduceSum
-    | GatherND
-    | RandomNormal
-    | Abs
-    | Log
-
-  (** Type describing the different operations handled. Those operations are
-      inspired by those defined in the ONNX documentation.
-
-      @see <https://github.com/onnx/onnx/blob/master/docs/Operators.md>
-        for more information. They are to be coupled with the relevant
-        operator parameters. *)
-
-  val str_op : operator -> string
-  val show_shape : shape -> string
-
-  type ksize = Ksize of shape
-  type stride = Stride of shape
-  type pads = Pads of shape
-  type dilations = Dilations of shape
-
-  type operator_parameters =
-    | Pool_params of (ksize * stride option * pads option * dilations option)
-    | Conv_params of (ksize * stride option * pads option * dilations option)
-    | Transpose_params of shape
-    | RW_Linearized_ReLu_params of
-        (bool list list * ((string, float) Base.Hashtbl.t list * int))
-    | Gather_params of int
-    | ReduceSum_params of int * int
-    | RandomNormal_params of int * float * float * float * shape
-
-  val str_op_params : operator_parameters -> string
-
-  type ('a, 'b) t = {
-    id : int;
-    name : string option;
-    shape : shape;
-    operator : operator;
-    operator_parameters : operator_parameters option;
-    pred : string list;
-    succ : string list;
-    tensor : ('a, 'b) Tensor.t option;
-  }
-  (** Type encapsulating parameters for operations: for convolutions and
-      pooling, kernel size, padding and strides; for transpose, a shape. *)
-
-  val compare : ('a, 'b) t -> ('a, 'b) t -> int
-  val hash : ('a, 'b) t -> int
-  val equal : ('a, 'b) t -> ('a, 'b) t -> bool
-
-  val create :
-    id:int ->
-    name:string option ->
-    sh:shape ->
-    op:operator ->
-    op_p:operator_parameters option ->
-    pred:string list ->
-    succ:string list ->
-    tensor:('a, 'b) Tensor.t option ->
-    ('a, 'b) t
-
-  val get_name : ('a, 'b) t -> string
-  val get_shape : ('a, 'b) t -> shape
-  val get_op : ('a, 'b) t -> operator
-  val get_pred_list : ('a, 'b) t -> string list
-  val get_succ_list : ('a, 'b) t -> string list
-  val get_tensor : ('a, 'b) t -> ('a, 'b) Tensor.t option
-  val is_data_node : ('a, 'b) t -> bool
-  val is_input_node : ('a, 'b) t -> bool
-  val is_output_node : ('a, 'b) t -> bool
-  val num_neurons : ('a, 'b) t -> int
-  val show : ('a, 'b) t -> ('a -> string) -> string
-end
-
-module type VInput = sig
-  type l
-  type r
-
-  val convert_f : l -> string
-end
-
-module MakeVertex (I : VInput) : sig
-  include Graph.Sig.VERTEX with type t = (I.l, I.r) Node.t
-end
-
-module Edge : sig
-  type t = string
-
-  val compare : 'a -> 'a -> int
-  val equal : 'a -> 'a -> bool
-  val default : t
-end
-
-(** NIER is a graph {b (V,E)} where {b V} is the set of vertices (nodes) and
-    {b E} is the set of edges (connections between nodes). Nodes contain the
-    following information:
-
-    - unique id
-    - name coming from the original model, if it exists
-    - shape of the tensor resulting from the application of the node operation,
-      if it exists
-    - operation performed
-    - parameters of the operation
-    - an optional tensor storing the data
-
-    Note that tensors have their own shape; it must however be equal to the
-    shape of the NIER node. *)
-
-module NierCFG (I : VInput) : sig
-  include
-    Graph.Sig.I
-      with type V.t = MakeVertex(I).t
-       and type V.label = MakeVertex(I).t
-       and type E.t = MakeVertex(I).t * Edge.t * MakeVertex(I).t
-       and type E.label = Edge.t
-
-  val init_cfg : t
-  val vertex_list : t -> vertex list
-  val preds : t -> vertex -> vertex list
-
-  (** [preds_names g v] returns a list of names of predecessor nodes *)
-
-  val preds_names : t -> vertex -> string list
-  val succs : t -> vertex -> vertex list
-
-  (** [succs_names g v] returns a list of names of successor nodes *)
-
-  val succs_names : t -> vertex -> string list
-
-  (** [input_nodes g] returns the nodes considered as describing the inputs of
-      the neural network. *)
-
-  val input_nodes : t -> vertex list
-  val find_vertices : t -> (vertex -> bool) -> vertex list
-end
-
-module NierCFGFloat : sig
-  include
-    Graph.Sig.I
-      with type V.t = (float, Bigarray.float64_elt) Node.t
-       and type V.label = (float, Bigarray.float64_elt) Node.t
-       and type E.t =
-        (float, Bigarray.float64_elt) Node.t
-        * Edge.t
-        * (float, Bigarray.float64_elt) Node.t
-       and type E.label = Edge.t
-
-  val init_cfg : t
-  val vertex_list : t -> vertex list
-  val preds : t -> vertex -> vertex list
-
-  (** [preds_names g v] returns a list of names of predecessor nodes *)
-
-  val preds_names : t -> vertex -> string list
-  val succs : t -> vertex -> vertex list
-
-  (** [succs_names g v] returns a list of names of successor nodes *)
-
-  val succs_names : t -> vertex -> string list
-
-  (** [input_nodes g] returns the nodes considered as describing the inputs of
-      the neural network. *)
-
-  val input_nodes : t -> vertex list
-  val find_vertices : t -> (vertex -> bool) -> vertex list
-
-  (** [data_node_of n g] returns a node containing tensor data among the
-      predecessors of [n] in [g]. *)
-
-  val data_node_of : vertex -> t -> vertex option
-
-  (** [infer_shape g n sh o_b] returns the inferred shape of the output of node
-      [n] in NIER [g] with input shape [sh]. Shape inference is made using the
-      node operator and its predecessors' shapes. [o_b] is true when performing
-      backward propagation, to choose which matrix size to consider. *)
-
-  val infer_shape : t -> vertex -> Node.shape -> on_backward:bool -> Node.shape
-end
-
-(** {1 Pretty printers} *)
-
-val print_cfg_graph : NierCFGFloat.t -> unit
-val out_cfg_graph : NierCFGFloat.t -> unit
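
To make the removed interface concrete, here is a hedged sketch of building
a two-node CFG with it (module path Ir.Nier_cfg assumed from the old
caisar.ir library; ids, names and shapes are illustrative):

    open Ir.Nier_cfg

    let g = NierCFGFloat.init_cfg

    (* pred ["NO_INPUT"] marks a network input, succ ["NO_OUTPUT"] an output *)
    let input =
      Node.create ~id:0 ~name:(Some "input") ~sh:[| 1; 2 |] ~op:Node.NO_OP
        ~op_p:None ~pred:[ "NO_INPUT" ] ~succ:[ "relu" ] ~tensor:None

    let relu =
      Node.create ~id:1 ~name:(Some "relu") ~sh:[| 1; 2 |] ~op:Node.ReLu
        ~op_p:None ~pred:[ "input" ] ~succ:[ "NO_OUTPUT" ] ~tensor:None

    let () =
      NierCFGFloat.add_vertex g input;
      NierCFGFloat.add_vertex g relu;
      NierCFGFloat.add_edge g input relu;
      print_cfg_graph g

add_vertex and add_edge come from the included OCamlGraph Graph.Sig.I
signature; they are not repeated in the interface above.
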
diff --git a/lib/ir/nier_simple.ml b/lib/ir/nier_simple.ml
deleted file mode 100644
index c179160172e25e44e97ef583e3918a98324f7aff..0000000000000000000000000000000000000000
--- a/lib/ir/nier_simple.ml
+++ /dev/null
@@ -1,695 +0,0 @@
-open Base
-
-module Shape : sig
-  type t [@@deriving show, ord, eq]
-
-  val to_array : t -> int array
-  val of_array : int array -> t
-  val to_list : t -> int list
-  val of_list : int list -> t
-  val rank : t -> int
-  val size : t -> int
-  val get : t -> int -> int
-  val set : t -> int -> int -> t
-  val to_array_unsafe : t -> int array
-  val row_major : t -> int array -> int
-  val unrow_major : t -> int -> int array
-end = struct
-  type t = int array [@@deriving ord, eq]
-
-  let to_array = Array.copy
-  let to_array_unsafe x = x
-  let of_array = Array.copy
-  let to_list = Array.to_list
-  let of_list = Array.of_list
-  let rank = Array.length
-  let size t = Array.fold t ~f:( * ) ~init:1
-  let pp fmt x = Fmt.pf fmt "[%a]" Fmt.(array ~sep:semi int) x
-  let show s = Fmt.str "%a" pp s
-  let get = Array.get
-
-  let set t k v =
-    let t = Array.copy t in
-    Array.set t k v;
-    t
-
-  let row_major t a =
-    assert (Array.length t = Array.length a);
-    let r = ref 0 in
-    for i = 0 to Array.length t - 1 do
-      r := (!r * t.(i)) + a.(i)
-    done;
-    !r
-
-  let unrow_major t i =
-    let r = ref i in
-    let a = Array.create ~len:(Array.length t) 0 in
-    for i = Array.length t - 1 downto 0 do
-      a.(i) <- !r % t.(i);
-      r := !r / t.(i)
-    done;
-    a
-end
-
-module Tensor : sig
-  type ('a, 'b) t
-
-  val of_tensor : ('a, 'b) Nier_cfg.Tensor.t -> ('a, 'b) t
-  val to_tensor : ('a, 'b) t -> ('a, 'b) Nier_cfg.Tensor.t
-  val create_1_float : float -> (float, Bigarray.float64_elt) t
-  val create_1_int64 : int64 -> (int64, Bigarray.int64_elt) t
-  val shape : ('a, 'b) t -> Shape.t
-  val flatten : ('a, 'b) t -> 'a list
-
-  val of_array1 :
-    Shape.t -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t -> ('a, 'b) t
-
-  val get : ('a, 'b) t -> int array -> 'a
-end = struct
-  type ('a, 'b) t = ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t
-
-  let copy t =
-    let t' = Bigarray.Genarray.(create (kind t) Bigarray.c_layout (dims t)) in
-    Bigarray.Genarray.blit t t';
-    t'
-
-  let of_tensor = copy
-  let to_tensor = copy
-
-  let create_1_float v =
-    let t =
-      Bigarray.Genarray.(create Bigarray.float64 Bigarray.c_layout [| 1 |])
-    in
-    Bigarray.Genarray.set t [| 0 |] v;
-    t
-
-  let create_1_int64 v =
-    let t =
-      Bigarray.Genarray.(create Bigarray.int64 Bigarray.c_layout [| 1 |])
-    in
-    Bigarray.Genarray.set t [| 0 |] v;
-    t
-
-  let shape x = Shape.of_array @@ Nier_cfg.Tensor.get_shape x
-  let flatten = Nier_cfg.Tensor.flatten
-
-  let of_array1 shape t =
-    Bigarray.reshape
-      (copy @@ Bigarray.genarray_of_array1 t)
-      (Shape.to_array_unsafe shape)
-
-  let get = Bigarray.Genarray.get
-end
-
-module GenTensor = struct
-  type t =
-    | Float of (float, Bigarray.float64_elt) Tensor.t
-    | Int64 of (int64, Bigarray.int64_elt) Tensor.t
-
-  let create_1_float f = Float (Tensor.create_1_float f)
-  let create_1_int64 i = Int64 (Tensor.create_1_int64 i)
-
-  let of_int_array a =
-    Int64
-      (Tensor.of_array1
-         (Shape.of_array [| Array.length a |])
-         (Bigarray.Array1.of_array Int64 C_layout (Array.map a ~f:Int64.of_int)))
-
-  let shape = function Float f -> Tensor.shape f | Int64 i -> Tensor.shape i
-end
-
-(** TODO: add the information needed to compute the shape *)
-type descr =
-  | Constant of { data : GenTensor.t }
-  | Add of {
-      input1 : node;
-      input2 : node;
-    }
-  | Sub of {
-      input1 : node;
-      input2 : node;
-    }
-  | Mul of {
-      input1 : node;
-      input2 : node;
-    }
-  | Div of {
-      input1 : node;
-      input2 : node;
-    }
-  | Matmul of {
-      input1 : node;
-      input2 : node;
-    }
-  | Gemm of {
-      inputA : node;
-      inputB : node;
-      inputC : node option;
-      alpha : float;
-      beta : float;
-      transA : bool;
-      transB : bool;
-    }
-  | LogSoftmax
-  | ReLu of { input : node }
-  | Transpose of {
-      input : node;
-        (* called "data" in ONNX documentation :
-           https://onnx.ai/onnx/operators/onnx__Transpose.html*)
-      perm : int list;
-    }
-  | Squeeze of {
-      data : node;
-      axes : node option; (* data int64 *)
-    }
-  | MaxPool
-  | Conv
-  | Reshape of {
-      input : node;
-      shape : node; (* data int64 *)
-    }
-  | Flatten of {
-      input : node;
-      axis : int;
-    }
-  | Identity of { input : node }
-  | Input of { shape : Shape.t }
-  | RW_Linearized_ReLu
-  | Concat of {
-      inputs : node list;
-      axis : int;
-    }
-  | Gather of {
-      input : node;
-      indices : node;
-      axis : int;
-    }
-  | ReduceSum of {
-      input : node;
-      axes : node option;
-      keepdims : int;
-      noop_with_empty_axes : int;
-    }
-  | GatherND of {
-      data : node;
-      indices : node;
-      batch_dims : int;
-    }
-  | RandomNormal of {
-      dtype : int;
-      mean : float;
-      scale : float;
-      seed : float;
-      shape : int array;
-    }
-  | Abs of { input : node }
-  | Log of { input : node }
-
-and node = {
-  id : int;
-  descr : descr;
-  shape : Shape.t;
-  ty : ty;
-}
-
-and ty =
-  | Float
-  | Int64
-
-let pp_descr fmt p =
-  match p with
-  | Input { shape } -> Fmt.pf fmt "Input: %a" Shape.pp shape
-  | Transpose { perm; _ } ->
-    Fmt.pf fmt "Transpose: [%a]" Fmt.(list ~sep:semi int) perm
-  | Constant { data = Int64 b } when Shape.size (Tensor.shape b) < 3 ->
-    Fmt.pf fmt "Constant[%a]" Fmt.(list ~sep:comma int64) (Tensor.flatten b)
-  | Constant _ -> Fmt.pf fmt "Constant"
-  | Add _ -> Fmt.pf fmt "Add"
-  | Sub _ -> Fmt.pf fmt "Sub"
-  | Mul _ -> Fmt.pf fmt "Mul"
-  | Div _ -> Fmt.pf fmt "Div"
-  | Matmul _ -> Fmt.pf fmt "Matmul"
-  | Gemm _ -> Fmt.pf fmt "Gemm"
-  | LogSoftmax -> Fmt.pf fmt "LogSoftmax"
-  | ReLu _ -> Fmt.pf fmt "ReLu"
-  | Squeeze _ -> Fmt.pf fmt "Squeeze"
-  | MaxPool -> Fmt.pf fmt "MaxPool"
-  | Conv -> Fmt.pf fmt "Conv"
-  | Reshape _ -> Fmt.pf fmt "Reshape"
-  | Flatten _ -> Fmt.pf fmt "Flatten"
-  | Identity _ -> Fmt.pf fmt "Identity"
-  | RW_Linearized_ReLu -> Fmt.pf fmt "RW_Linearized_ReLu"
-  | Concat { axis; _ } -> Fmt.pf fmt "Concat[%i]" axis
-  | Gather _ -> Fmt.pf fmt "Gather"
-  | ReduceSum _ -> Fmt.pf fmt "ReduceSum"
-  | GatherND _ -> Fmt.pf fmt "GatherND"
-  | RandomNormal _ -> Fmt.pf fmt "RandomNormal"
-  | Abs _ -> Fmt.pf fmt "Abs"
-  | Log _ -> Fmt.pf fmt "Log"
-
-type t = {
-  output : node;
-  succs : (int, node list) Base.Hashtbl.t;
-}
-
-let output t = t.output
-
-module Node = struct
-  type t = node
-
-  let compare { id = id1; _ } { id = id2; _ } = Int.compare id1 id2
-  let equal { id = id1; _ } { id = id2; _ } = Int.equal id1 id2
-  let hash { id; _ } = id
-  let sexp_of_t node = Base.Int.sexp_of_t node.id
-  let pp fmt n = Fmt.pf fmt "@[%i: %a@]" n.id pp_descr n.descr
-  let show n = Fmt.str "%a" pp n
-
-  include Comparator.Make (struct
-    type nonrec t = t
-
-    let compare = compare
-    let sexp_of_t = sexp_of_t
-  end)
-
-  let rec compute_shape n = n.shape
-
-  and compute_shape_descr = function
-    | Add { input1; _ }
-    | Div { input1; _ }
-    | Mul { input1; _ }
-    | Sub { input1; _ } ->
-      compute_shape input1
-    | Flatten { input; axis } ->
-      (* output shape: (d_0 X d_1 X … X d_(axis-1), d_axis X d_(axis+1) X … X d_n) *)
-      let shape = compute_shape input in
-      let d1 = ref 1 in
-      let d2 = ref 1 in
-      for i = 0 to axis - 1 do
-        d1 := !d1 * Shape.get shape i
-      done;
-      for i = axis to Shape.rank shape - 1 do
-        d2 := !d2 * Shape.get shape i
-      done;
-      Shape.of_list [ !d1; !d2 ]
-    | Input { shape } -> shape
-    | ReLu { input } -> compute_shape input
-    | Transpose { input; perm = [] } ->
-      compute_shape input |> Shape.to_list |> List.rev |> Shape.of_list
-    | Transpose { input; perm } ->
-      let shape = compute_shape input in
-      let rank = Shape.rank shape in
-      assert (Int.equal rank (List.length perm));
-      let shape' = Array.create ~len:rank 0 in
-      let shape = Shape.to_array shape in
-      Base.List.iteri perm ~f:(fun i j ->
-        Array.set shape' i (Array.get shape j));
-      Shape.of_array shape'
-    | Constant { data } -> GenTensor.shape data
-    | Concat { inputs; axis } ->
-      let shapes = List.map ~f:compute_shape inputs in
-      let shape = List.hd_exn shapes in
-      let axis = if axis < 0 then Shape.rank shape + axis else axis in
-      let l = List.map ~f:(fun s -> Shape.get s axis) shapes in
-      let i = List.reduce_exn ~f:( + ) l in
-      Shape.set shape axis i
-    | Gather { input; indices; axis } -> (
-      let input_shape = compute_shape input in
-      let indices_shape = compute_shape indices in
-      let axis = if axis < 0 then Shape.rank input_shape + axis else axis in
-      match List.split_n (Shape.to_list input_shape) axis with
-      | _, [] -> failwith "axis is bigger than shape rank"
-      | before, _ :: after ->
-        Shape.of_list (before @ Shape.to_list indices_shape @ after))
-    | Matmul { input1; input2 } ->
-      let pad_left = function
-        | [] -> failwith "Impossible to pad empty shape"
-        | [ a ] -> [ 1; a ]
-        | x -> x
-      in
-      let pad_right = function
-        | [] -> failwith "Impossible to pad empty shape"
-        | [ a ] -> [ a; 1 ]
-        | x -> x
-      in
-      let rec one_padding l i =
-        if i <= 0 then l else one_padding (1 :: l) (i - 1)
-      in
-      (* Expected semantic:
-       * Matrix multiplication C = AB
-       * A (shape [n;m]); B (shape [m;p]); C (shape [n;p])
-       * shape of b: b_sh
-       * shape of a: a_sh
-       * shape of c: c_sh
-       *)
-      let check_matmul_size_ab ~a_sh ~b_sh =
-        let adim2 = pad_left a_sh in
-        let bdim2 = pad_right b_sh in
-        let adim = one_padding adim2 (List.length bdim2 - List.length adim2) in
-        let bdim = one_padding bdim2 (List.length adim2 - List.length bdim2) in
-        let rec infer_csize acc ad bd =
-          match (ad, bd) with
-          | [ m; n ], [ nn; p ] ->
-            if nn = n
-            then List.rev_append acc [ m; p ]
-            else failwith "size of matrices not adequate"
-          | a :: la, b :: lb ->
-            if a = b
-            then infer_csize (a :: acc) la lb
-            else if a = 1
-            then infer_csize (b :: acc) la lb
-            else if b = 1
-            then infer_csize (a :: acc) la lb
-            else failwith "Checking matmul_size failed: one discordance"
-          | _, _ -> failwith "Checking matmul_size failed"
-        in
-        infer_csize [] adim bdim
-      in
-      (* TODO: when pad_left or pad_right added a dimension, remove it. But it
-         is not clear what must be done when broadcasting occurs. *)
-      Shape.of_list
-        (check_matmul_size_ab
-           ~a_sh:(Shape.to_list (compute_shape input1))
-           ~b_sh:(Shape.to_list (compute_shape input2)))
-    | Reshape { input; shape; _ } ->
-      let shape =
-        match shape.descr with
-        | Constant { data = Int64 a } ->
-          List.map ~f:Int64.to_int_exn (Tensor.flatten a)
-        | _ ->
-          (* Some constant propagation could be useful in some cases, e.g.
-             patch-1 of VNNcomp *)
-          failwith "non-constant shape in reshape not supported"
-      in
-      List.iter shape ~f:(function
-        | -1 | 0 -> failwith "not implemented 0 -1 in shape for reshape"
-        | _ -> ());
-      let out = Shape.of_list shape in
-      if Shape.size out <> Shape.size input.shape
-      then
-        failwith
-          "Reshape: input shape and given shape do not have the same number \
-           of elements";
-      out
-    | Gemm { inputA; inputB; inputC = _; alpha = _; beta = _; transA; transB }
-      ->
-      let rank2 i =
-        match Shape.to_array_unsafe i.shape with
-        | [| k; n |] -> (k, n)
-        | _ -> failwith "Gemm input must be of size 2"
-      in
-      let tr trans (k, n) = if trans then (n, k) else (k, n) in
-      let a1, a2 = tr transA @@ rank2 inputA in
-      let b1, b2 = tr transB @@ rank2 inputB in
-      if not (Int.equal a2 b1)
-      then
-        Fmt.failwith "Gemm (M:%i,K:%i) (K:%i,N:%i) -> (M:%i,N:%i)" a1 a2 b1 b2
-          a1 b2;
-      Shape.of_array [| a1; b2 |]
-    | ( LogSoftmax | Squeeze _ | MaxPool | Conv | Identity _
-      | RW_Linearized_ReLu | ReduceSum _ | GatherND _ | RandomNormal _ | Abs _
-      | Log _ ) as n ->
-      failwith (Fmt.str "todo compute shape : %a" pp_descr n)
-
-  let compute_ty : _ -> ty = function
-    | Constant { data = Float _ } -> Float
-    | Constant { data = Int64 _ } -> Int64
-    | _ -> Float
-
-  let create =
-    let c = ref (-1) in
-    fun descr ->
-      Int.incr c;
-      {
-        id = !c;
-        descr;
-        shape = compute_shape_descr descr;
-        ty = compute_ty descr;
-      }
-
-  let constant_int_array a =
-    create (Constant { data = GenTensor.of_int_array a })
-
-  let reshape shape node =
-    if Shape.equal node.shape shape
-    then node
-    else
-      create
-        (Reshape
-           { input = node; shape = constant_int_array (Shape.to_array shape) })
-
-  let gather_int_as_matmul input i =
-    let input1 =
-      reshape (Shape.of_array [| 1; Shape.size input.shape |]) input
-    in
-    let selector = Array.create ~len:(Shape.size input1.shape) Float.zero in
-    Array.set selector i Float.one;
-    let selector =
-      GenTensor.Float
-        (Tensor.of_array1
-           (Shape.of_array [| Array.length selector; 1 |])
-           (Bigarray.Array1.of_array Float64 C_layout selector))
-    in
-    let input2 = create (Constant { data = selector }) in
-    let result = create (Matmul { input1; input2 }) in
-    reshape (Shape.of_array [| 1 |]) result
-
-  let gather_int ?(encode = true) input i =
-    if encode
-    then gather_int_as_matmul input i
-    else
-      let indices =
-        create (Constant { data = GenTensor.create_1_int64 (Int64.of_int i) })
-      in
-      create (Gather { input; indices; axis = 0 })
-
-  let mul_float input f =
-    let input1 = reshape (Shape.of_array [| 1; 1 |]) input in
-    let f = Array.create ~len:1 f in
-    let f =
-      GenTensor.Float
-        (Tensor.of_array1
-           (Shape.of_array [| Array.length f; 1 |])
-           (Bigarray.Array1.of_array Float64 C_layout f))
-    in
-    let input2 = create (Constant { data = f }) in
-    let result = create (Matmul { input1; input2 }) in
-    reshape (Shape.of_array [| 1 |]) result
-
-  let div_float ?(encode = true) input f =
-    if encode
-    then
-      let f = Float.one /. f in
-      mul_float input f
-    else
-      let input1 = reshape (Shape.of_array [| 1; 1 |]) input in
-      let f = Array.create ~len:1 f in
-      let f =
-        GenTensor.Float
-          (Tensor.of_array1
-             (Shape.of_array [| Array.length f; 1 |])
-             (Bigarray.Array1.of_array Float64 C_layout f))
-      in
-      let input2 = create (Constant { data = f }) in
-      let result = create (Div { input1; input2 }) in
-      reshape (Shape.of_array [| 1 |]) result
-
-  let concat_0 = function
-    | [ n ] -> n
-    | [] -> failwith "empty concat"
-    | inputs -> create (Concat { inputs; axis = 0 })
-
-  let preds node =
-    match node.descr with
-    | Constant _ | Input _ -> []
-    | Add { input1; input2 }
-    | Sub { input1; input2 }
-    | Mul { input1; input2 }
-    | Div { input1; input2 }
-    | Matmul { input1; input2 } ->
-      [ input1; input2 ]
-    | Gather { input; indices; axis = _ } -> [ input; indices ]
-    | GatherND { data; indices; batch_dims = _ } -> [ data; indices ]
-    | ReLu { input } | Abs { input } | Log { input } -> [ input ]
-    | Concat { inputs; axis = _ } -> inputs
-    | ReduceSum { input; axes = Some x; _ } -> [ input; x ]
-    | ReduceSum { input; axes = None; _ } -> [ input ]
-    | RandomNormal _ -> []
-    | Transpose { input; _ } -> [ input ]
-    | Flatten { input; _ } -> [ input ]
-    | Identity { input } -> [ input ]
-    | Gemm { inputA; inputB; inputC = Some x; _ } -> [ inputA; inputB; x ]
-    | Gemm { inputA; inputB; inputC = None; _ } -> [ inputA; inputB ]
-    | Squeeze { data; _ } -> [ data ]
-    | Reshape { input; shape; _ } -> [ input; shape ]
-    | LogSoftmax | MaxPool | Conv | RW_Linearized_ReLu -> []
-
-  let map f n =
-    match n.descr with
-    | Constant _ | Input _ -> n
-    | Add { input1; input2 } ->
-      create (Add { input1 = f input1; input2 = f input2 })
-    | Sub { input1; input2 } ->
-      create (Sub { input1 = f input1; input2 = f input2 })
-    | Mul { input1; input2 } ->
-      create (Mul { input1 = f input1; input2 = f input2 })
-    | Div { input1; input2 } ->
-      create (Div { input1 = f input1; input2 = f input2 })
-    | Matmul { input1; input2 } ->
-      create (Matmul { input1 = f input1; input2 = f input2 })
-    | ReLu { input } -> create (ReLu { input = f input })
-    | Abs { input } -> create (Abs { input = f input })
-    | Log { input } -> create (Log { input = f input })
-    | RandomNormal _ as descr -> create descr
-    | ReduceSum { input; axes; keepdims; noop_with_empty_axes } ->
-      create
-        (ReduceSum { input = f input; axes; keepdims; noop_with_empty_axes })
-    | Gather { input; indices; axis } ->
-      create (Gather { input = f input; indices = f indices; axis })
-    | GatherND { data; indices; batch_dims } ->
-      create (GatherND { data = f data; indices = f indices; batch_dims })
-    | Transpose t -> create (Transpose { t with input = f t.input })
-    | Flatten t -> create (Flatten { t with input = f t.input })
-    | Identity { input } -> create (Identity { input = f input })
-    | Concat { inputs; axis } ->
-      create (Concat { inputs = List.map ~f inputs; axis })
-    | Gemm t ->
-      create
-        (Gemm
-           {
-             t with
-             inputA = f t.inputA;
-             inputB = f t.inputB;
-             inputC = Base.Option.map t.inputC ~f;
-           })
-    | Squeeze t -> create (Squeeze { t with data = f t.data })
-    | Reshape t -> create (Reshape { t with input = f t.input })
-    | LogSoftmax | MaxPool | Conv | RW_Linearized_ReLu -> n (* todo *)
-
-
-  let replace_input f node =
-    let h = Base.Hashtbl.create (module Base.Int) in
-    let rec aux n =
-      Base.Hashtbl.find_or_add h n.id ~default:(fun () ->
-        match n.descr with Input _ -> f () | _ -> map aux n)
-    in
-    aux node
-
-  (* map over the nodes accessible from [node] ([node] included), without
-     repetition *)
-  let map_rec f node =
-    let h = Base.Hashtbl.create (module Base.Int) in
-    let rec aux n =
-      Base.Hashtbl.find_or_add h n.id ~default:(fun () -> f (map aux n))
-    in
-    aux node
-
-  let iter_rec f node =
-    let h = Base.Hashtbl.create (module Base.Int) in
-    let rec aux n =
-      Base.Hashtbl.find_or_add h n.id ~default:(fun () ->
-        List.iter ~f:aux (preds n);
-        f n)
-    in
-    aux node
-end
-
-(** TODO: some other invariants must be checked, e.g. only one input *)
-let create output =
-  let succs = Base.Hashtbl.create (module Base.Int) in
-  let check_node node =
-    List.iter
-      ~f:(fun p -> Base.Hashtbl.add_multi succs ~key:p.id ~data:node)
-      (Node.preds node)
-  in
-  (* Add the key of the output node so that all the nodes are in succs *)
-  Base.Hashtbl.add_exn succs ~key:output.id ~data:[];
-  Node.iter_rec check_node output;
-  { output; succs }
-
-let input_shape g =
-  let r = ref None in
-  let check_node n =
-    match n.descr with Input { shape } -> r := Some shape | _ -> ()
-  in
-  Node.iter_rec check_node g.output;
-  Option.value_exn !r
-
-let succs t node = Base.Hashtbl.find_exn t.succs node.id
-let iter_vertex f t = Node.iter_rec f t.output
-let iter_succ f t node = List.iter ~f (Base.Hashtbl.find_multi t.succs node.id)
-let pp fmt t = iter_vertex (fun v -> Fmt.pf fmt "@[%a@]@ " pp_descr v.descr) t
-
-let pp_debug fmt t =
-  iter_vertex
-    (fun v ->
-      Fmt.pf fmt "@[%i: %a(%a) : %a@]@ " v.id pp_descr v.descr
-        Fmt.(list ~sep:comma (using (fun x -> x.id) int))
-        (Node.preds v) Shape.pp v.shape)
-    t
-
-let nodes t =
-  let l = ref [] in
-  iter_vertex (fun v -> l := v :: !l) t;
-  !l
-
-module M = Graph.Topological.Make
-
-module GFloat = struct
-  type nonrec t = t
-
-  let iter_edges_e f t =
-    iter_vertex (fun n -> List.iter ~f:(fun n' -> f (n', n)) (Node.preds n)) t
-
-  module Node = struct
-    type t = Node.t
-
-    let compare = Node.compare
-    let equal = Node.equal
-    let hash = Node.hash
-    let sexp_of_t = Node.sexp_of_t
-    let create = Node.create
-    let show = Node.show
-  end
-
-  module V = Node
-
-  module E = struct
-    type t = V.t * V.t
-
-    let src = fst
-    let dst = snd
-  end
-
-  let iter_vertex = iter_vertex
-  let iter_succ = iter_succ
-end
-
-module Dot = Graph.Graphviz.Dot (struct
-  include GFloat (* use the graph module from above *)
-
-  let node_label (v : Node.t) = Node.show v
-  let edge_attributes (_, _) = []
-  let default_edge_attributes _ = []
-  let get_subgraph _ = None
-  let vertex_attributes v = [ `Shape `Box; `Label (node_label v) ]
-  let vertex_name (v : Node.t) = Int.to_string v.id
-  let default_vertex_attributes _ = []
-  let graph_attributes _ = []
-end)
-
-let grapheasy g =
-  try
-    let cin, cout =
-      Unix.open_process_args "graph-easy"
-        [| "graph-easy"; "--from=graphviz"; "--as=boxart" |]
-    in
-    Dot.output_graph cout g;
-    Stdlib.close_out cout;
-    let ascii = Stdio.In_channel.input_all cin in
-    ignore (Unix.close_process (cin, cout));
-    ascii
-  with exn ->
-    Fmt.str "Error graph-easy call: %s" (Stdlib.Printexc.to_string exn)
diff --git a/lib/ir/nier_simple.mli b/lib/ir/nier_simple.mli
deleted file mode 100644
index d1a6d322c5af7ffdd7e237f928a0cf974b2b8aff..0000000000000000000000000000000000000000
--- a/lib/ir/nier_simple.mli
+++ /dev/null
@@ -1,251 +0,0 @@
-module Shape : sig
-  type t [@@deriving show, ord, eq]
-
-  val to_array : t -> int array
-  val of_array : int array -> t
-  val to_list : t -> int list
-  val of_list : int list -> t
-  val rank : t -> int
-  val size : t -> int
-  val get : t -> int -> int
-  val set : t -> int -> int -> t
-  val row_major : t -> int array -> int
-  val unrow_major : t -> int -> int array
-end
-
-module Tensor : sig
-  (** Immutable tensors *)
-
-  type ('a, 'b) t
-
-  val of_tensor : ('a, 'b) Nier_cfg.Tensor.t -> ('a, 'b) t
-  val to_tensor : ('a, 'b) t -> ('a, 'b) Nier_cfg.Tensor.t
-  val create_1_float : float -> (float, Bigarray.float64_elt) t
-  val create_1_int64 : int64 -> (int64, Bigarray.int64_elt) t
-  val shape : ('a, 'b) t -> Shape.t
-  val flatten : ('a, 'b) t -> 'a list
-
-  val of_array1 :
-    Shape.t -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t -> ('a, 'b) t
-
-  val get : ('a, 'b) t -> int array -> 'a
-end
-
-module GenTensor : sig
-  type t =
-    | Float of (float, Bigarray.float64_elt) Tensor.t
-    | Int64 of (int64, Bigarray.int64_elt) Tensor.t
-
-  val create_1_float : float -> t
-  val create_1_int64 : int64 -> t
-  val shape : t -> Shape.t
-  val of_int_array : int array -> t
-end
-
-type descr =
-  | Constant of { data : GenTensor.t }
-  | Add of {
-      input1 : node;
-      input2 : node;
-    }
-  | Sub of {
-      input1 : node;
-      input2 : node;
-    }
-  | Mul of {
-      input1 : node;
-      input2 : node;
-    }
-  | Div of {
-      input1 : node;
-      input2 : node;
-    }
-  | Matmul of {
-      input1 : node;
-      input2 : node;
-    }
-  | Gemm of {
-      inputA : node;
-      inputB : node;
-      inputC : node option;
-      alpha : float;
-      beta : float;
-      transA : bool;
-      transB : bool;
-    }
-  | LogSoftmax
-  | ReLu of { input : node }
-  | Transpose of {
-      input : node;
-        (* called "data" in ONNX documentation :
-           https://onnx.ai/onnx/operators/onnx__Transpose.html*)
-      perm : int list;
-    }
-  | Squeeze of {
-      data : node;
-      axes : node option; (* int64 *)
-    }
-  | MaxPool
-  | Conv
-  | Reshape of {
-      input : node;
-      shape : node; (* int64 *)
-    }
-  | Flatten of {
-      input : node;
-      axis : int;
-    }
-  | Identity of { input : node }
-  | Input of { shape : Shape.t }
-  | RW_Linearized_ReLu
-  | Concat of {
-      inputs : node list;
-      axis : Base.int;
-    }
-  | Gather of {
-      input : node;
-      indices : node;
-      axis : int;
-    }
-  | ReduceSum of {
-      input : node;
-      axes : node option;
-      keepdims : int;
-      noop_with_empty_axes : int;
-    }
-  | GatherND of {
-      data : node;
-      indices : node;
-      batch_dims : int;
-    }
-  | RandomNormal of {
-      dtype : int;
-      mean : float;
-      scale : float;
-      seed : float;
-      shape : int array;
-    }
-  | Abs of { input : node }
-  | Log of { input : node }
-
-and node = private {
-  id : int; (* unique identifier *)
-  descr : descr;
-  shape : Shape.t;
-  ty : ty;
-}
-
-and ty =
-  | Float
-  | Int64
-
-module Node : sig
-  type t = node [@@deriving show]
-
-  val equal : t -> t -> bool
-
-  include Base.Hashtbl.Key.S with type t := t
-  include Base.Comparator.S with type t := t
-
-  val create : descr -> t
-
-  val gather_int : ?encode:bool -> t -> int -> t
-  (** create a node by selection at a given index. *)
-  (* Implemented via a [Matmul] if [encode] (true by default).
-
-     TODO: [encode] should not be a parameter, but rather depend on the prover. *)
-
-  val mul_float : t -> float -> t
-  (* Implemented via a [Matmul]. *)
-
-  val div_float : ?encode:bool -> t -> float -> t
-  (* Implemented via a [Matmul] if [encode] (true by default).
-
-     TODO: [encode] should not be a parameter, but rather depend on the prover. *)
-
-  val constant_int_array : int array -> t
-  (** create a node for a constant array *)
-
-  val reshape : Shape.t -> t -> t
-  (** create if necessary a reshape node *)
-
-  val concat_0 : t list -> t
-  (** create if necessary a concat node for the first axis *)
-
-  val map : (t -> t) -> t -> t
-  (** [map f n] replaces the direct inputs [i] of [n] by [f i] *)
-
-  val map_rec : (t -> t) -> t -> t
-  (** [map_rec f n] replaces, top-down, the nodes [i] accessible from [n] by
-      [f i] *)
-
-  val replace_input : (unit -> t) -> t -> t
-  (** [replace_input f n] replaces the input in [n] by [f ()] *)
-
-  val preds : t -> t list
-  (** Direct predecessors of a node *)
-
-  val iter_rec : (t -> unit) -> t -> unit
-  (** Iterates on the predecessors of a node and on the node itself. Respects
-      topological order. *)
-
-  val compute_shape : t -> Shape.t
-end
-
-type t
-
-val pp : t Fmt.t
-val pp_debug : t Fmt.t
-
-val create : node -> t
-(** Create a network from its output node; the network must have only one input *)
-
-val output : t -> node
-(** Output node of the network *)
-
-val input_shape : t -> Shape.t
-(** Input shape of the network *)
-
-val succs : t -> node -> node list
-(** successors of a node *)
-
-val iter_vertex : (Node.t -> unit) -> t -> unit
-val iter_succ : (Node.t -> unit) -> t -> Node.t -> unit
-val nodes : t -> Node.t list
-
-(** Respects part of the OCamlGraph signature *)
-module GFloat : sig
-  type nonrec t = t
-
-  module Node : sig
-    type t = node
-
-    val equal : t -> t -> bool
-    val compare : t -> t -> int
-    val hash : t -> int
-    val create : descr -> t
-    val sexp_of_t : t -> Sexplib0.Sexp.t
-  end
-
-  module V = Node
-
-  module E : sig
-    type t = V.t * V.t
-
-    val src : t -> V.t
-    val dst : t -> V.t
-  end
-
-  val iter_vertex : (V.t -> unit) -> t -> unit
-  val iter_succ : (V.t -> unit) -> t -> V.t -> unit
-  val iter_edges_e : (E.t -> unit) -> t -> unit
-end
-
-module Dot : sig
-  val fprint_graph : Format.formatter -> GFloat.t -> unit
-  val output_graph : out_channel -> GFloat.t -> unit
-end
-
-val grapheasy : t -> string
-(** @return
-      an ASCII representation of the graph, using the "graph-easy" command *)
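
As a usage note, replace_input composed with create is enough to retarget a
whole network onto a fresh input node, e.g. after a shape change. A hedged
sketch (hypothetical helper, module path Ir.Nier_simple assumed):

    open Ir.Nier_simple

    (* rebuild [net] with its (unique) input node swapped for a fresh one *)
    let with_new_input net shape =
      let fresh () = Node.create (Input { shape }) in
      create (Node.replace_input fresh (output net))

replace_input memoizes on node ids, so shared subgraphs are rewritten only
once.
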
diff --git a/lib/ir/dune b/lib/nir/dune
similarity index 88%
rename from lib/ir/dune
rename to lib/nir/dune
index 14a9bd69afbf4d9e5a36e6dc4b88b4ef669ddcc7..7a7030df19a39b1f7a548ed832a1ce43d79aa9b0 100644
--- a/lib/ir/dune
+++ b/lib/nir/dune
@@ -1,6 +1,6 @@
 (library
- (name ir)
- (public_name caisar.ir)
+ (name nir)
+ (public_name caisar.nir)
  (preprocess
   (pps
    ppx_inline_test
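
Downstream users presumably need the matching rename in their dune stanzas;
an illustrative snippet (library name hypothetical):

    (library
     (name my_analysis)
     (libraries caisar.nir)) ; was: caisar.ir
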
diff --git a/lib/nir/gentensor.ml b/lib/nir/gentensor.ml
new file mode 100644
index 0000000000000000000000000000000000000000..5a2eae003d00794fab5a6b7283f1692a7a177f21
--- /dev/null
+++ b/lib/nir/gentensor.ml
@@ -0,0 +1,37 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2023                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+open Base
+
+type t =
+  | Float of (float, Bigarray.float64_elt) Tensor.t
+  | Int64 of (int64, Bigarray.int64_elt) Tensor.t
+
+let create_1_float f = Float (Tensor.create_1_float f)
+let create_1_int64 i = Int64 (Tensor.create_1_int64 i)
+
+let of_int_array a =
+  Int64
+    (Tensor.of_array1
+       (Shape.of_array [| Array.length a |])
+       (Bigarray.Array1.of_array Int64 C_layout (Array.map a ~f:Int64.of_int)))
+
+let shape = function Float f -> Tensor.shape f | Int64 i -> Tensor.shape i
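
Gentensor extracts the old GenTensor of nier_simple.ml into its own
compilation unit. A hedged usage sketch (module path Nir.Gentensor assumed
from the new caisar.nir library, and Nir.Shape assumed to keep the pp
printer of the old Shape module):

    let () =
      (* of_int_array wraps the array as a rank-1 Int64 tensor: shape [3] *)
      let t = Nir.Gentensor.of_int_array [| 4; 5; 6 |] in
      Fmt.pr "%a@." Nir.Shape.pp (Nir.Gentensor.shape t)
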
diff --git a/lib/onnx/onnx.mli b/lib/nir/gentensor.mli
similarity index 77%
rename from lib/onnx/onnx.mli
rename to lib/nir/gentensor.mli
index d7e52c2872088defca6c63cf1ffc5ed3bb566788..9c9e1c31e4406063d4bb7767b9f2daa6742cee72 100644
--- a/lib/onnx/onnx.mli
+++ b/lib/nir/gentensor.mli
@@ -20,19 +20,11 @@
 (*                                                                        *)
 (**************************************************************************)
 
-module G = Ir.Nier_cfg.NierCFGFloat
+type t =
+  | Float of (float, Bigarray.float64_elt) Tensor.t
+  | Int64 of (int64, Bigarray.int64_elt) Tensor.t
 
-type t = private {
-  n_inputs : int;  (** Number of inputs. *)
-  n_outputs : int;  (** Number of outputs. *)
-  nier : (G.t, string) Result.t;  (** Intermediate representation. *)
-}
-(** ONNX model metadata and intermediate representation. *)
-
-val parse : string -> (t, string) Result.t
-(** Parse an ONNX file into a NIER. *)
-
-val write : G.t -> string -> unit
-(** Write a NIER into an ONNX file. *)
-
-module Simple = Simple
+val create_1_float : float -> t
+val create_1_int64 : int64 -> t
+val of_int_array : int array -> t
+val shape : t -> Shape.t
diff --git a/lib/nir/ngraph.ml b/lib/nir/ngraph.ml
new file mode 100644
index 0000000000000000000000000000000000000000..9abfd51034d82b82ff9d98c64e82bc67e6534bc6
--- /dev/null
+++ b/lib/nir/ngraph.ml
@@ -0,0 +1,120 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2023                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+open Base
+
+type t = {
+  output : Node.t;
+  succs : (int, Node.t list) Base.Hashtbl.t;
+}
+
+let output t = t.output
+
+(** TODO: other invariants should also be checked, e.g. that there is only one input. *)
+let create output =
+  let succs = Base.Hashtbl.create (module Base.Int) in
+  let check_node node =
+    List.iter
+      ~f:(fun p -> Base.Hashtbl.add_multi succs ~key:p.id ~data:node)
+      (Node.preds node)
+  in
+  (* Add the key of the output node so that every node appears in succs *)
+  Base.Hashtbl.add_exn succs ~key:output.Node.id ~data:[];
+  Node.iter_rec check_node output;
+  { output; succs }
+
+let input_shape g =
+  let r = ref None in
+  let check_node n =
+    match n.Node.descr with Input { shape } -> r := Some shape | _ -> ()
+  in
+  Node.iter_rec check_node g.output;
+  Option.value_exn !r
+
+let succs t node = Base.Hashtbl.find_exn t.succs node.Node.id
+let iter_vertex f t = Node.iter_rec f t.output
+
+let iter_succ f t node =
+  List.iter ~f (Base.Hashtbl.find_multi t.succs node.Node.id)
+
+let pp fmt t =
+  iter_vertex (fun v -> Fmt.pf fmt "@[%a@]@ " Node.pp_descr v.descr) t
+
+let pp_debug fmt t =
+  iter_vertex
+    (fun v ->
+      Fmt.pf fmt "@[%i: %a(%a) : %a@]@ " v.id Node.pp_descr v.descr
+        Fmt.(list ~sep:comma (using (fun x -> x.Node.id) int))
+        (Node.preds v) Shape.pp v.shape)
+    t
+
+let nodes t =
+  let l = ref [] in
+  iter_vertex (fun v -> l := v :: !l) t;
+  !l
+
+module M = Graph.Topological.Make
+
+module GFloat = struct
+  type nonrec t = t
+
+  let iter_edges_e f t =
+    iter_vertex (fun n -> List.iter ~f:(fun n' -> f (n', n)) (Node.preds n)) t
+
+  module V = Node
+
+  module E = struct
+    type t = V.t * V.t
+
+    let src = fst
+    let dst = snd
+  end
+
+  let iter_vertex = iter_vertex
+  let iter_succ = iter_succ
+end
+
+module Dot = Graph.Graphviz.Dot (struct
+  include GFloat (* use the graph module from above *)
+
+  let node_label (v : Node.t) = Node.show v
+  let edge_attributes (_, _) = []
+  let default_edge_attributes _ = []
+  let get_subgraph _ = None
+  let vertex_attributes v = [ `Shape `Box; `Label (node_label v) ]
+  let vertex_name (v : Node.t) = Int.to_string v.id
+  let default_vertex_attributes _ = []
+  let graph_attributes _ = []
+end)
+
+let grapheasy g =
+  try
+    let cin, cout =
+      Unix.open_process_args "graph-easy"
+        [| "graph-easy"; "--from=graphviz"; "--as=boxart" |]
+    in
+    Dot.output_graph cout g;
+    Stdlib.close_out cout;
+    let ascii = Stdio.In_channel.input_all cin in
+    ignore (Unix.close_process (cin, cout));
+    ascii
+  with exn ->
+    Fmt.str "Error calling graph-easy: %s" (Stdlib.Printexc.to_string exn)
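The Dot functor instantiation above provides DOT serialization; a hedged sketch of driving it from outside the module (assuming this patch's API):

    (* Write the DOT rendering of a graph [g] to [filename]. *)
    let dump_dot filename g =
      let oc = Stdlib.open_out filename in
      Ngraph.Dot.output_graph oc g;
      Stdlib.close_out oc

grapheasy builds on the same serialization, piping the DOT output through the external graph-easy tool; if the call fails (e.g. the tool is not installed), the function returns the error message as a string instead of raising.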
diff --git a/lib/nir/ngraph.mli b/lib/nir/ngraph.mli
new file mode 100644
index 0000000000000000000000000000000000000000..a96ea8cdf229ef48482adc9524c22612f6ace6b4
--- /dev/null
+++ b/lib/nir/ngraph.mli
@@ -0,0 +1,74 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2023                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+
+(** {1 Neural Intermediate Representation (NIR)} *)
+
+(** NIR is a graph describing the control flow of a machine learning model.
+
+    A graph is described starting from its output node. *)
+
+type t
+
+val pp : t Fmt.t
+val pp_debug : t Fmt.t
+
+val create : Node.t -> t
+(** Create a network from its output node. *)
+
+val output : t -> Node.t
+(** Output node of the network. *)
+
+val nodes : t -> Node.t list
+(** All nodes of the network. *)
+
+val input_shape : t -> Shape.t
+(** Input shape of the network. *)
+
+val succs : t -> Node.t -> Node.t list
+(** Successors of a node. *)
+
+val iter_vertex : (Node.t -> unit) -> t -> unit
+val iter_succ : (Node.t -> unit) -> t -> Node.t -> unit
+val grapheasy : t -> string
+
+(** Respects part of the OCamlGraph signature. *)
+module GFloat : sig
+  type nonrec t = t
+
+  module V = Node
+
+  module E : sig
+    type t = V.t * V.t
+
+    val src : t -> V.t
+    val dst : t -> V.t
+  end
+
+  val iter_vertex : (V.t -> unit) -> t -> unit
+  val iter_succ : (V.t -> unit) -> t -> V.t -> unit
+  val iter_edges_e : (E.t -> unit) -> t -> unit
+end
+
+module Dot : sig
+  val fprint_graph : Format.formatter -> GFloat.t -> unit
+  val output_graph : out_channel -> GFloat.t -> unit
+end
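For illustration, a hedged sketch of the intended workflow with this interface (assuming the Node and Shape modules from this patch):

    (* Build a one-node network from its output and traverse it. *)
    let () =
      let input = Node.create (Node.Input { shape = Shape.of_list [ 1; 4 ] }) in
      let g = Ngraph.create (Node.create (Node.ReLu { input })) in
      assert (Shape.equal (Ngraph.input_shape g) (Shape.of_list [ 1; 4 ]));
      Ngraph.iter_vertex (fun n -> Fmt.pr "%a@." Node.pp_descr n.Node.descr) g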
diff --git a/lib/nir/node.ml b/lib/nir/node.ml
new file mode 100644
index 0000000000000000000000000000000000000000..656d928f12fe52ed3a5d1d5c0b804622fb148b4b
--- /dev/null
+++ b/lib/nir/node.ml
@@ -0,0 +1,479 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2023                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+open Base
+
+type ty =
+  | Float
+  | Int64
+[@@deriving show]
+
+(** TODO: add the information needed to compute the shape *)
+type descr =
+  | Constant of { data : Gentensor.t }
+  | Add of {
+      input1 : t;
+      input2 : t;
+    }
+  | Sub of {
+      input1 : t;
+      input2 : t;
+    }
+  | Mul of {
+      input1 : t;
+      input2 : t;
+    }
+  | Div of {
+      input1 : t;
+      input2 : t;
+    }
+  | Matmul of {
+      input1 : t;
+      input2 : t;
+    }
+  | Gemm of {
+      inputA : t;
+      inputB : t;
+      inputC : t option;
+      alpha : float;
+      beta : float;
+      transA : bool;
+      transB : bool;
+    }
+  | LogSoftmax
+  | ReLu of { input : t }
+  | Transpose of {
+      input : t;
+        (* Called "data" in the ONNX documentation:
+           https://onnx.ai/onnx/operators/onnx__Transpose.html *)
+      perm : int list;
+    }
+  | Squeeze of {
+      data : t;
+      axes : t option; (* data int64 *)
+    }
+  | MaxPool
+  | Conv
+  | Reshape of {
+      input : t;
+      shape : t; (* data int64 *)
+    }
+  | Flatten of {
+      input : t;
+      axis : int;
+    }
+  | Identity of { input : t }
+  | Input of { shape : Shape.t }
+  | RW_Linearized_ReLu
+  | Concat of {
+      inputs : t list;
+      axis : int;
+    }
+  | Gather of {
+      input : t;
+      indices : t;
+      axis : int;
+    }
+  | ReduceSum of {
+      input : t;
+      axes : t option;
+      keepdims : int;
+      noop_with_empty_axes : int;
+    }
+  | GatherND of {
+      data : t;
+      indices : t;
+      batch_dims : int;
+    }
+  | RandomNormal of {
+      dtype : int;
+      mean : float;
+      scale : float;
+      seed : float;
+      shape : int array;
+    }
+  | Abs of { input : t }
+  | Log of { input : t }
+
+and t = {
+  id : int;
+  descr : descr; [@printer fun fmt d -> pp_descr fmt d]
+  shape : Shape.t;
+  ty : ty;
+}
+
+let pp_descr fmt p =
+  match p with
+  | Input { shape } -> Fmt.pf fmt "Input: %a" Shape.pp shape
+  | Transpose { perm; _ } ->
+    Fmt.pf fmt "Transpose: [%a]" Fmt.(list ~sep:semi int) perm
+  | Constant { data = Int64 b } when Shape.size (Tensor.shape b) < 3 ->
+    Fmt.pf fmt "Constant[%a]" Fmt.(list ~sep:comma int64) (Tensor.flatten b)
+  | Constant _ -> Fmt.pf fmt "Constant"
+  | Add _ -> Fmt.pf fmt "Add"
+  | Sub _ -> Fmt.pf fmt "Sub"
+  | Mul _ -> Fmt.pf fmt "Mul"
+  | Div _ -> Fmt.pf fmt "Div"
+  | Matmul _ -> Fmt.pf fmt "Matmul"
+  | Gemm _ -> Fmt.pf fmt "Gemm"
+  | LogSoftmax -> Fmt.pf fmt "LogSoftmax"
+  | ReLu _ -> Fmt.pf fmt "ReLu"
+  | Squeeze _ -> Fmt.pf fmt "Squeeze"
+  | MaxPool -> Fmt.pf fmt "MaxPool"
+  | Conv -> Fmt.pf fmt "Conv"
+  | Reshape _ -> Fmt.pf fmt "Reshape"
+  | Flatten _ -> Fmt.pf fmt "Flatten"
+  | Identity _ -> Fmt.pf fmt "Identity"
+  | RW_Linearized_ReLu -> Fmt.pf fmt "RW_Linearized_ReLu"
+  | Concat { axis; _ } -> Fmt.pf fmt "Concat[%i]" axis
+  | Gather _ -> Fmt.pf fmt "Gather"
+  | ReduceSum _ -> Fmt.pf fmt "ReduceSum"
+  | GatherND _ -> Fmt.pf fmt "GatherND"
+  | RandomNormal _ -> Fmt.pf fmt "RandomNormal"
+  | Abs _ -> Fmt.pf fmt "Abs"
+  | Log _ -> Fmt.pf fmt "Log"
+
+let show_descr t = Fmt.str "%a" pp_descr t
+let compare { id = id1; _ } { id = id2; _ } = Int.compare id1 id2
+let equal { id = id1; _ } { id = id2; _ } = Int.equal id1 id2
+let hash { id; _ } = id
+let sexp_of_t node = Base.Int.sexp_of_t node.id
+let pp fmt n = Fmt.pf fmt "@[%i: %a@]" n.id pp_descr n.descr
+let show n = Fmt.str "%a" pp n
+
+include Base.Comparator.Make (struct
+  type nonrec t = t
+
+  let compare = compare
+  let sexp_of_t = sexp_of_t
+end)
+
+let rec compute_shape n = n.shape
+
+and compute_shape_descr = function
+  | Add { input1; _ }
+  | Div { input1; _ }
+  | Mul { input1; _ }
+  | Sub { input1; _ } ->
+    compute_shape input1
+  | Flatten { input; axis } ->
+    (* (d_0 X d_1 … d_(axis-1), d_axis X d_(axis+1) … X dn). *)
+    let shape = compute_shape input in
+    let d1 = ref 1 in
+    let d2 = ref 1 in
+    for i = 0 to axis - 1 do
+      d1 := !d1 * Shape.get shape i
+    done;
+    for i = axis to Shape.rank shape - 1 do
+      d2 := !d2 * Shape.get shape i
+    done;
+    Shape.of_list [ !d1; !d2 ]
+  | Input { shape } -> shape
+  | ReLu { input } -> compute_shape input
+  | Transpose { input; perm = [] } ->
+    compute_shape input |> Shape.to_list |> List.rev |> Shape.of_list
+  | Transpose { input; perm } ->
+    let shape = compute_shape input in
+    let rank = Shape.rank shape in
+    assert (Int.equal rank (List.length perm));
+    let shape' = Array.create ~len:rank 0 in
+    let shape = Shape.to_array shape in
+    Base.List.iteri perm ~f:(fun i j -> Array.set shape' i (Array.get shape j));
+    Shape.of_array shape'
+  | Constant { data } -> Gentensor.shape data
+  | Concat { inputs; axis } ->
+    let shapes = List.map ~f:compute_shape inputs in
+    let shape = List.hd_exn shapes in
+    let axis = if axis < 0 then Shape.rank shape + axis else axis in
+    let l = List.map ~f:(fun s -> Shape.get s axis) shapes in
+    let i = List.reduce_exn ~f:( + ) l in
+    Shape.set shape axis i
+  | Gather { input; indices; axis } -> (
+    let input_shape = compute_shape input in
+    let indices_shape = compute_shape indices in
+    let axis = if axis < 0 then Shape.rank input_shape + axis else axis in
+    match List.split_n (Shape.to_list input_shape) axis with
+    | _, [] -> failwith "axis is bigger than shape rank"
+    | before, _ :: after ->
+      Shape.of_list (before @ Shape.to_list indices_shape @ after))
+  | Matmul { input1; input2 } ->
+    let pad_left = function
+      | [] -> failwith "Impossible to pad empty shape"
+      | [ a ] -> [ 1; a ]
+      | x -> x
+    in
+    let pad_right = function
+      | [] -> failwith "Impossible to pad empty shape"
+      | [ a ] -> [ a; 1 ]
+      | x -> x
+    in
+    let rec one_padding l i =
+      if i <= 0 then l else one_padding (1 :: l) (i - 1)
+    in
+    (* Expected semantic:
+     * Matrix multiplication C = AB
+     * A (shape [n;m]); B (shape [m;p]); C (shape [n;p])
+     * shape of b: b_sh
+     * shape of a: a_sh
+     * shape of c: c_sh
+     *)
+    let check_matmul_size_ab ~a_sh ~b_sh =
+      let adim2 = pad_left a_sh in
+      let bdim2 = pad_right b_sh in
+      let adim = one_padding adim2 (List.length bdim2 - List.length adim2) in
+      let bdim = one_padding bdim2 (List.length adim2 - List.length bdim2) in
+      let rec infer_csize acc ad bd =
+        match (ad, bd) with
+        | [ m; n ], [ nn; p ] ->
+          if nn = n
+          then List.rev_append acc [ m; p ]
+          else failwith "size of matrices not adequate"
+        | a :: la, b :: lb ->
+          if a = b
+          then infer_csize (a :: acc) la lb
+          else if a = 1
+          then infer_csize (b :: acc) la lb
+          else if b = 1
+          then infer_csize (a :: acc) la lb
+          else failwith "Checking matmul_size failed: one discordance"
+        | _, _ -> failwith "Checking matmul_size failed"
+      in
+      infer_csize [] adim bdim
+    in
+    (* TODO: in case of pad_left and pad_right, remove the added dimension.
+       It is unclear what should be done when broadcasting occurs. *)
+    Shape.of_list
+      (check_matmul_size_ab
+         ~a_sh:(Shape.to_list (compute_shape input1))
+         ~b_sh:(Shape.to_list (compute_shape input2)))
+  | Reshape { input; shape; _ } ->
+    let shape =
+      match shape.descr with
+      | Constant { data = Int64 a } ->
+        List.map ~f:Int64.to_int_exn (Tensor.flatten a)
+      | _ ->
+        (* Some constant propagation could be useful in some cases eg. patch-1
+           VNNcomp *)
+        failwith "non-constant shape in reshape not supported"
+    in
+    List.iter shape ~f:(function
+      | -1 | 0 -> failwith "not implemented 0 -1 in shape for reshape"
+      | _ -> ());
+    let out = Shape.of_list shape in
+    if Shape.size out <> Shape.size input.shape
+    then
+      failwith
+        "Reshape: input shape and given shape do not have the same number of \
+         elements";
+    out
+  | Gemm { inputA; inputB; inputC = _; alpha = _; beta = _; transA; transB } ->
+    let rank2 i =
+      match Shape.to_array_unsafe i.shape with
+      | [| k; n |] -> (k, n)
+      | _ -> failwith "Gemm input must be of size 2"
+    in
+    let tr trans (k, n) = if trans then (n, k) else (k, n) in
+    let a1, a2 = tr transA @@ rank2 inputA in
+    let b1, b2 = tr transB @@ rank2 inputB in
+    if not (Int.equal a2 b1)
+    then
+      Fmt.failwith "Gemm (M:%i,K:%i) (K:%i,N:%i) -> (M:%i,N:%i)" a1 a2 b1 b2 a1
+        b2;
+    Shape.of_array [| a1; b2 |]
+  | ( LogSoftmax | Squeeze _ | MaxPool | Conv | Identity _ | RW_Linearized_ReLu
+    | ReduceSum _ | GatherND _ | RandomNormal _ | Abs _ | Log _ ) as n ->
+    failwith (Fmt.str "todo compute shape : %a" pp_descr n)
+
+let compute_ty : _ -> ty = function
+  | Constant { data = Float _ } -> Float
+  | Constant { data = Int64 _ } -> Int64
+  | _ -> Float
+
+let create =
+  let c = ref (-1) in
+  fun descr ->
+    Int.incr c;
+    { id = !c; descr; shape = compute_shape_descr descr; ty = compute_ty descr }
+
+let constant_int_array a = create (Constant { data = Gentensor.of_int_array a })
+
+let reshape shape node =
+  if Shape.equal node.shape shape
+  then node
+  else
+    create
+      (Reshape
+         { input = node; shape = constant_int_array (Shape.to_array shape) })
+
+let gather_int_as_matmul input i =
+  let input1 = reshape (Shape.of_array [| 1; Shape.size input.shape |]) input in
+  let selector = Array.create ~len:(Shape.size input1.shape) Float.zero in
+  Array.set selector i Float.one;
+  let selector =
+    Gentensor.Float
+      (Tensor.of_array1
+         (Shape.of_array [| Array.length selector; 1 |])
+         (Bigarray.Array1.of_array Float64 C_layout selector))
+  in
+  let input2 = create (Constant { data = selector }) in
+  let result = create (Matmul { input1; input2 }) in
+  reshape (Shape.of_array [| 1 |]) result
+
+let gather_int ?(encode = true) input i =
+  if encode
+  then gather_int_as_matmul input i
+  else
+    let indices =
+      create (Constant { data = Gentensor.create_1_int64 (Int64.of_int i) })
+    in
+    create (Gather { input; indices; axis = 0 })
+
+let mul_float input f =
+  let input1 = reshape (Shape.of_array [| 1; 1 |]) input in
+  let f = Array.create ~len:1 f in
+  let f =
+    Gentensor.Float
+      (Tensor.of_array1
+         (Shape.of_array [| Array.length f; 1 |])
+         (Bigarray.Array1.of_array Float64 C_layout f))
+  in
+  let input2 = create (Constant { data = f }) in
+  let result = create (Matmul { input1; input2 }) in
+  reshape (Shape.of_array [| 1 |]) result
+
+let div_float ?(encode = true) input f =
+  if encode
+  then
+    let f = Float.one /. f in
+    mul_float input f
+  else
+    let input1 = reshape (Shape.of_array [| 1; 1 |]) input in
+    let f = Array.create ~len:1 f in
+    let f =
+      Gentensor.Float
+        (Tensor.of_array1
+           (Shape.of_array [| Array.length f; 1 |])
+           (Bigarray.Array1.of_array Float64 C_layout f))
+    in
+    let input2 = create (Constant { data = f }) in
+    let result = create (Div { input1; input2 }) in
+    reshape (Shape.of_array [| 1 |]) result
+
+let concat_0 = function
+  | [ n ] -> n
+  | [] -> failwith "empty concat"
+  | inputs -> create (Concat { inputs; axis = 0 })
+
+let preds node =
+  match node.descr with
+  | Constant _ | Input _ -> []
+  | Add { input1; input2 }
+  | Sub { input1; input2 }
+  | Mul { input1; input2 }
+  | Div { input1; input2 }
+  | Matmul { input1; input2 } ->
+    [ input1; input2 ]
+  | Gather { input; indices; axis = _ } -> [ input; indices ]
+  | GatherND { data; indices; batch_dims = _ } -> [ data; indices ]
+  | ReLu { input } | Abs { input } | Log { input } -> [ input ]
+  | Concat { inputs; axis = _ } -> inputs
+  | ReduceSum { input; axes = Some x; _ } -> [ input; x ]
+  | ReduceSum { input; axes = None; _ } -> [ input ]
+  | RandomNormal _ -> []
+  | Transpose { input; _ } -> [ input ]
+  | Flatten { input; _ } -> [ input ]
+  | Identity { input } -> [ input ]
+  | Gemm { inputA; inputB; inputC = Some x; _ } -> [ inputA; inputB; x ]
+  | Gemm { inputA; inputB; inputC = None; _ } -> [ inputA; inputB ]
+  | Squeeze { data; _ } -> [ data ]
+  | Reshape { input; shape; _ } -> [ input; shape ]
+  | LogSoftmax | MaxPool | Conv | RW_Linearized_ReLu -> []
+
+let map f n =
+  match n.descr with
+  | Constant _ | Input _ -> n
+  | Add { input1; input2 } ->
+    create (Add { input1 = f input1; input2 = f input2 })
+  | Sub { input1; input2 } ->
+    create (Sub { input1 = f input1; input2 = f input2 })
+  | Mul { input1; input2 } ->
+    create (Mul { input1 = f input1; input2 = f input2 })
+  | Div { input1; input2 } ->
+    create (Div { input1 = f input1; input2 = f input2 })
+  | Matmul { input1; input2 } ->
+    create (Matmul { input1 = f input1; input2 = f input2 })
+  | ReLu { input } -> create (ReLu { input = f input })
+  | Abs { input } -> create (Abs { input = f input })
+  | Log { input } -> create (Log { input = f input })
+  | RandomNormal _ as descr -> create descr
+  | ReduceSum { input; axes; keepdims; noop_with_empty_axes } ->
+    create (ReduceSum { input = f input; axes; keepdims; noop_with_empty_axes })
+  | Gather { input; indices; axis } ->
+    create (Gather { input = f input; indices = f indices; axis })
+  | GatherND { data; indices; batch_dims } ->
+    create (GatherND { data = f data; indices = f indices; batch_dims })
+  | Transpose t -> create (Transpose { t with input = f t.input })
+  | Flatten t -> create (Flatten { t with input = f t.input })
+  | Identity { input } -> create (Identity { input = f input })
+  | Concat { inputs; axis } ->
+    create (Concat { inputs = List.map ~f inputs; axis })
+  | Gemm t ->
+    create
+      (Gemm
+         {
+           t with
+           inputA = f t.inputA;
+           inputB = f t.inputB;
+           inputC = Base.Option.map t.inputC ~f;
+         })
+  | Squeeze t -> create (Squeeze { t with data = f t.data })
+  | Reshape t -> create (Reshape { t with input = f t.input })
+  | LogSoftmax | MaxPool | Conv | RW_Linearized_ReLu -> n (* todo *)
+
+let replace_input f node =
+  let h = Base.Hashtbl.create (module Base.Int) in
+  let rec aux n =
+    Base.Hashtbl.find_or_add h n.id ~default:(fun () ->
+      match n.descr with Input _ -> f () | _ -> map aux n)
+  in
+  aux node
+
+(* map f on the nodes accessible from [node] ([node] included), without
+   repetition *)
+let map_rec f node =
+  let h = Base.Hashtbl.create (module Base.Int) in
+  let rec aux n =
+    Base.Hashtbl.find_or_add h n.id ~default:(fun () -> f (map aux n))
+  in
+  aux node
+
+(* iter f on the nodes accessible from [node] ([node] included), without
+   repetition, predecessors first *)
+let iter_rec f node =
+  let h = Base.Hashtbl.create (module Base.Int) in
+  let rec aux n =
+    Base.Hashtbl.find_or_add h n.id ~default:(fun () ->
+      List.iter ~f:aux (preds n);
+      f n)
+  in
+  aux node
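The gather_int_as_matmul encoding above replaces an element lookup by a product with a one-hot selector, presumably because downstream tools handle linear operations better than Gather. A plain-Stdlib sketch of the underlying arithmetic (hypothetical values, for illustration only):

    (* Selecting x.(2) as a dot product with the one-hot vector [0; 0; 1]. *)
    let () =
      let x = [| 10.; 20.; 30. |] and i = 2 in
      let one_hot = Array.init 3 (fun j -> if j = i then 1. else 0.) in
      let dot = Array.fold_left ( +. ) 0. (Array.map2 ( *. ) x one_hot) in
      assert (Float.equal dot 30.)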
diff --git a/lib/nir/node.mli b/lib/nir/node.mli
new file mode 100644
index 0000000000000000000000000000000000000000..e23d5a57ffaf9929c7112607eecb1b43a456dacb
--- /dev/null
+++ b/lib/nir/node.mli
@@ -0,0 +1,171 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2023                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+
+(** {1 Nodes descriptions} *)
+
+(** A node is composed of
+
+    - a unique [id] of type int
+    - a node description of type [descr]
+
+    [descr] describes several operations. When an operation shares its name
+    with an ONNX operation, it follows the ONNX IR v8 and ONNX Opset v13
+    standards, described here: https://onnx.ai/onnx/operators/index.html.
+
+    Nodes only record their inputs: it is assumed that a node returns exactly
+    one value. *)
+
+open Base
+
+type ty =
+  | Float
+  | Int64
+[@@deriving show]
+
+type descr =
+  | Constant of { data : Gentensor.t }
+      (** A constant tensor, used to store non-varying parameters during
+          inference. *)
+  | Add of {
+      input1 : t;
+      input2 : t;
+    }
+  | Sub of {
+      input1 : t;
+      input2 : t;
+    }
+  | Mul of {
+      input1 : t;
+      input2 : t;
+    }
+  | Div of {
+      input1 : t;
+      input2 : t;
+    }
+  | Matmul of {
+      input1 : t;
+      input2 : t;
+    }
+  | Gemm of {
+      inputA : t;
+      inputB : t;
+      inputC : t option;
+      alpha : float;
+      beta : float;
+      transA : bool;
+      transB : bool;
+    }
+  | LogSoftmax
+  | ReLu of { input : t }
+  | Transpose of {
+      input : t;
+        (** Called "data" in ONNX documentation :
+            https://onnx.ai/onnx/operators/onnx__Transpose.html .*)
+      perm : int list;
+    }
+  | Squeeze of {
+      data : t;
+      axes : t option; (* Expects an int64 tensor. *)
+    }
+  | MaxPool
+  | Conv
+  | Reshape of {
+      input : t;
+      shape : t; (* Expects an int64 tensor. *)
+    }
+  | Flatten of {
+      input : t;
+      axis : int;
+    }
+  | Identity of { input : t }
+  | Input of { shape : Shape.t }
+  | RW_Linearized_ReLu
+  | Concat of {
+      inputs : t list;
+      axis : Base.int;
+    }
+  | Gather of {
+      input : t;
+      indices : t;
+      axis : int;
+    }
+  | ReduceSum of {
+      input : t;
+      axes : t option;
+      keepdims : int;
+      noop_with_empty_axes : int;
+    }
+  | GatherND of {
+      data : t;
+      indices : t;
+      batch_dims : int;
+    }
+  | RandomNormal of {
+      dtype : int;
+      mean : float;
+      scale : float;
+      seed : float;
+      shape : int array;
+    }
+  | Abs of { input : t }
+  | Log of { input : t }
+[@@deriving show]
+
+and t = private {
+  id : int;
+  descr : descr;
+  shape : Shape.t;  (** Shape of the result of the node computation. *)
+  ty : ty;  (** Element type of the result of the node computation. *)
+}
+
+val equal : t -> t -> bool
+
+include Base.Hashtbl.Key.S with type t := t
+include Base.Comparator.S with type t := t
+
+val create : descr -> t
+(** [create descr] returns a node with a fresh [id] and a shape computed
+    according to the ONNX semantics. *)
+
+val gather_int : ?encode:bool -> t -> int -> t
+
+val map : (t -> t) -> t -> t
+(** [map f n] replaces the direct inputs [i] of [n] by [f i]. *)
+
+val map_rec : (t -> t) -> t -> t
+(** [map_rec f n] replaces, top-down, the nodes [i] accessible from [n] by
+    [f i]. *)
+
+val replace_input : (unit -> t) -> t -> t
+(** [replace_input f n] replaces the input of [n] by [f ()]. *)
+
+val preds : t -> t list
+(** Direct predecessors of a node. *)
+
+val iter_rec : (t -> unit) -> t -> unit
+(** Iterates on a node and its predecessors, without repetition, respecting
+    topological order. *)
+
+val compute_shape : t -> Shape.t
+val mul_float : t -> float -> t
+val div_float : ?encode:bool -> t -> float -> t
+val concat_0 : t list -> t
+val reshape : Shape.t -> t -> t
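A hedged sketch of combining these builders (it assumes only the modules from this patch): replace_input lets one specialize a network to a constant input while reusing the rest of the graph.

    (* Build x |-> 2 * x, then specialize it to the constant input 21. *)
    let () =
      let input = Node.create (Node.Input { shape = Shape.of_list [ 1 ] }) in
      let doubled = Node.mul_float input 2.0 in
      let const () =
        Node.create (Node.Constant { data = Gentensor.create_1_float 21.0 })
      in
      ignore (Node.replace_input const doubled)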
diff --git a/lib/nir/shape.ml b/lib/nir/shape.ml
new file mode 100644
index 0000000000000000000000000000000000000000..b07e361922a9549b1834803fc43b00418abe736e
--- /dev/null
+++ b/lib/nir/shape.ml
@@ -0,0 +1,58 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2023                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+
+open Base
+
+type t = int array [@@deriving ord, eq]
+
+let to_array = Array.copy
+let to_array_unsafe x = x
+let of_array = Array.copy
+let to_list = Array.to_list
+let of_list = Array.of_list
+let rank = Array.length
+let size t = Array.fold t ~f:( * ) ~init:1
+let pp fmt x = Fmt.pf fmt "[%a]" Fmt.(array ~sep:semi int) x
+let show s = Fmt.str "%a" pp s
+let get = Array.get
+
+let set t k v =
+  let t = Array.copy t in
+  Array.set t k v;
+  t
+
+let row_major t a =
+  assert (Array.length t = Array.length a);
+  let r = ref 0 in
+  for i = 0 to Array.length t - 1 do
+    r := (!r * t.(i)) + a.(i)
+  done;
+  !r
+
+let unrow_major t i =
+  let r = ref i in
+  let a = Array.create ~len:(Array.length t) 0 in
+  for i = Array.length t - 1 downto 0 do
+    a.(i) <- !r % t.(i);
+    r := !r / t.(i)
+  done;
+  a
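row_major and unrow_major implement the usual C-layout linearization of coordinates and its inverse; a small sketch checking them on a concrete shape (values are illustrative):

    (* For shape [2; 3; 4], coordinate [|1; 2; 3|] maps to
       ((1 * 3) + 2) * 4 + 3 = 23, and back. *)
    let () =
      let sh = Shape.of_list [ 2; 3; 4 ] in
      assert (Shape.row_major sh [| 1; 2; 3 |] = 23);
      assert (Stdlib.( = ) (Shape.unrow_major sh 23) [| 1; 2; 3 |])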
diff --git a/lib/nir/shape.mli b/lib/nir/shape.mli
new file mode 100644
index 0000000000000000000000000000000000000000..79934bb842614e376ae7799cdf1ac21888b8c108
--- /dev/null
+++ b/lib/nir/shape.mli
@@ -0,0 +1,35 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2023                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+
+type t [@@deriving show, ord, eq]
+
+val to_array : t -> int array
+val of_array : int array -> t
+val to_list : t -> int list
+val of_list : int list -> t
+val rank : t -> int
+val size : t -> int
+val get : t -> int -> int
+val set : t -> int -> int -> t
+val row_major : t -> int array -> int
+val unrow_major : t -> int -> int array
+val to_array_unsafe : t -> int array
diff --git a/lib/nir/tensor.ml b/lib/nir/tensor.ml
new file mode 100644
index 0000000000000000000000000000000000000000..1eb1034a0eb664761d5f121abb32abd0f67897e7
--- /dev/null
+++ b/lib/nir/tensor.ml
@@ -0,0 +1,57 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2023                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+open Base
+
+type ('a, 'b) t = ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t
+
+let copy t =
+  let t' = Bigarray.Genarray.(create (kind t) Bigarray.c_layout (dims t)) in
+  Bigarray.Genarray.blit t t';
+  t'
+
+let of_tensor = copy
+let to_tensor = copy
+
+let create_1_float v =
+  let t =
+    Bigarray.Genarray.(create Bigarray.float64 Bigarray.c_layout [| 1 |])
+  in
+  Bigarray.Genarray.set t [| 0 |] v;
+  t
+
+let create_1_int64 v =
+  let t = Bigarray.Genarray.(create Bigarray.int64 Bigarray.c_layout [| 1 |]) in
+  Bigarray.Genarray.set t [| 0 |] v;
+  t
+
+let shape x = Shape.of_array @@ Bigarray.Genarray.dims x
+
+let flatten t =
+  let a = Bigarray.reshape_1 t (Shape.size (shape t)) in
+  List.init (Bigarray.Array1.dim a) ~f:(fun i -> Bigarray.Array1.get a i)
+
+let of_array1 shape t =
+  Bigarray.reshape
+    (copy @@ Bigarray.genarray_of_array1 t)
+    (Shape.to_array_unsafe shape)
+
+let get = Bigarray.Genarray.get
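A hedged round-trip sketch for the tensor module above (Stdlib Bigarray plus this patch's API):

    (* Build a 2x3 tensor from a flat Array1 and read it back. *)
    let () =
      let a =
        Bigarray.Array1.of_array Bigarray.Float64 Bigarray.C_layout
          [| 1.; 2.; 3.; 4.; 5.; 6. |]
      in
      let t = Tensor.of_array1 (Shape.of_list [ 2; 3 ]) a in
      assert (Stdlib.( = ) (Tensor.flatten t) [ 1.; 2.; 3.; 4.; 5.; 6. ]);
      (* row-major: coordinate [|1; 2|] is the sixth element, value 6. *)
      assert (Float.equal (Tensor.get t [| 1; 2 |]) 6.)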
diff --git a/lib/nir/tensor.mli b/lib/nir/tensor.mli
new file mode 100644
index 0000000000000000000000000000000000000000..0d29ee42e95200586e6fb95dc473ef87fea36dbf
--- /dev/null
+++ b/lib/nir/tensor.mli
@@ -0,0 +1,63 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2023                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+
+(** {1 Immutable tensor module} *)
+
+(** Tensors are multidimensional arrays used to represent numerical data, such
+    as neural network parameters.
+
+    This library relies on Bigarray.Genarray to instantiate tensors. *)
+
+type ('a, 'b) t
+
+val of_tensor : ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t -> ('a, 'b) t
+val to_tensor : ('a, 'b) t -> ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t
+
+val create_1_float : float -> (float, Bigarray.float64_elt) t
+(** [create_1_float f] returns a one-dimensional tensor containing the single
+    floating-point value [f]. *)
+
+val create_1_int64 : int64 -> (int64, Bigarray.int64_elt) t
+(** [create_1_int64 i] returns a one-dimensional tensor containing the single
+    int64 value [i]. *)
+
+val shape : ('a, 'b) t -> Shape.t
+
+val flatten : ('a, 'b) t -> 'a list
+(** [flatten t] returns all values stored in [t] as a flat list. *)
+
+val of_array1 :
+  Shape.t -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t -> ('a, 'b) t
+(** [of_array1 sh a] builds a tensor of shape [sh] from the elements of [a]. *)
+
+val get : ('a, 'b) t -> int array -> 'a
+(** [get t idx] returns the value stored at coordinates [idx] in [t].
+
+    @raise Invalid_argument
+      if [idx] does not have exactly one coordinate per dimension of [t], or if
+      [idx] is out of bounds. *)
diff --git a/lib/onnx/dune b/lib/onnx/dune
index 1dda954451b84ed7d64565d34978d2541cc04fa0..782b1e9e717874a9b68a8d74b9c4946c67544ef7 100644
--- a/lib/onnx/dune
+++ b/lib/onnx/dune
@@ -6,7 +6,7 @@
   stdio
   ocaml-protoc-plugin
   ocplib-endian
-  caisar.ir
+  caisar.nir
   caisar_logging)
  (synopsis "ONNX parser for CAISAR"))
 
diff --git a/lib/onnx/onnx.ml b/lib/onnx/onnx.ml
deleted file mode 100644
index a647157b3620cccba1e603083b080d20b190be01..0000000000000000000000000000000000000000
--- a/lib/onnx/onnx.ml
+++ /dev/null
@@ -1,694 +0,0 @@
-(**************************************************************************)
-(*                                                                        *)
-(*  This file is part of CAISAR.                                          *)
-(*                                                                        *)
-(*  Copyright (C) 2023                                                    *)
-(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
-(*         alternatives)                                                  *)
-(*                                                                        *)
-(*  You can redistribute it and/or modify it under the terms of the GNU   *)
-(*  Lesser General Public License as published by the Free Software       *)
-(*  Foundation, version 2.1.                                              *)
-(*                                                                        *)
-(*  It is distributed in the hope that it will be useful,                 *)
-(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
-(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
-(*  GNU Lesser General Public License for more details.                   *)
-(*                                                                        *)
-(*  See the GNU Lesser General Public License version 2.1                 *)
-(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
-(*                                                                        *)
-(**************************************************************************)
-
-open Base
-module Format = Stdlib.Format
-module Fun = Stdlib.Fun
-module Oproto = Onnx_protoc (* Autogenerated during compilation *)
-module Oprotom = Oproto.Onnx.ModelProto
-module NCFG = Ir.Nier_cfg
-module G = NCFG.NierCFGFloat
-module NSimple = Ir.Nier_simple
-module GFloat = Ir.Nier_simple.GFloat
-module Dot = Ir.Nier_simple.Dot
-
-exception ParseError of string
-
-type t = {
-  n_inputs : int; (* Number of inputs. *)
-  n_outputs : int; (* Number of outputs. *)
-  nier : (G.t, string) Result.t; (* Intermediate representation. *)
-}
-
-(* ONNX format handling. *)
-type op_attribute = Oproto.Onnx.AttributeProto.t
-
-type tensordata =
-  | Raw of bytes
-  | Float of float list
-
-let (no_attr : op_attribute) =
-  {
-    name = None;
-    ref_attr_name = None;
-    doc_string = None;
-    type' = None;
-    f = None;
-    i = None;
-    s = None;
-    t = None;
-    g = None;
-    floats = [];
-    ints = [];
-    strings = [];
-    tensors = [];
-    graphs = [];
-    sparse_tensor = None;
-    tp = None;
-    sparse_tensors = [];
-    type_protos = [];
-  }
-
-let get_nested_dims (s : Oproto.Onnx.ValueInfoProto.t list) =
-  match List.nth s 0 with
-  | Some { type' = Some { value = `Tensor_type { shape = Some v; _ }; _ }; _ }
-    ->
-    v
-  | _ -> []
-
-let flattened_dim (dim : Oproto.Onnx.TensorShapeProto.Dimension.t list) =
-  List.fold ~init:1 dim ~f:(fun acc x ->
-    match x.value with
-    | `Dim_value v -> acc * Int64.to_int_exn v
-    | `Dim_param _ -> acc
-    | `not_set -> acc)
-
-let get_input_output_dim (model : Oprotom.t) =
-  let input_shape, output_shape =
-    match model.graph with
-    | Some g -> (get_nested_dims g.input, get_nested_dims g.output)
-    | _ -> ([], [])
-  in
-  (* TODO: here we only get the flattened dimension of inputs and outputs, but
-     more interesting parsing could be done later on. *)
-  let input_flat_dim = flattened_dim input_shape in
-  let output_flat_dim = flattened_dim output_shape in
-  (input_flat_dim, output_flat_dim)
-
-let produce_cfg (g : Oproto.Onnx.GraphProto.t) =
-  let open Oproto.Onnx in
-  let nodes = g.node
-  and inputs = g.input
-  and outputs = g.output
-  and initi = g.initializer' in
-  let fold_vip_names acc n =
-    match n.ValueInfoProto.name with
-    | Some str -> Some str :: acc
-    | None -> None :: acc
-  in
-  let i_nodes, o_nodes =
-    ( List.fold inputs ~init:[] ~f:fold_vip_names,
-      List.fold outputs ~init:[] ~f:fold_vip_names )
-  and c_nodes = List.init (List.length nodes) ~f:(fun _ -> None) in
-  let fold_nodes_ops_cfg ns =
-    let get_node_operator_cfg x =
-      match x.NodeProto.op_type with
-      | None -> NCFG.Node.NO_OP
-      | Some o -> (
-        match o with
-        | "Add" -> NCFG.Node.Add
-        | "Sub" -> NCFG.Node.Sub
-        | "Mul" -> NCFG.Node.Mul
-        | "Div" -> NCFG.Node.Div
-        | "Relu" -> NCFG.Node.ReLu
-        | "MatMul" -> NCFG.Node.Matmul
-        | "Gemm" -> NCFG.Node.Gemm
-        | "LogSoftmax" -> NCFG.Node.LogSoftmax
-        | "Transpose" -> NCFG.Node.Transpose
-        | "Squeeze" -> NCFG.Node.Squeeze
-        | "MaxPool" -> NCFG.Node.MaxPool
-        | "Constant" -> NCFG.Node.Constant
-        | "Conv" -> NCFG.Node.Conv
-        | "Reshape" -> NCFG.Node.Reshape
-        | "Flatten" -> NCFG.Node.Flatten
-        | "Identity" -> NCFG.Node.Identity
-        | "Gather" -> NCFG.Node.Gather
-        (* | "ReduceSum" -> NCFG.Node.ReduceSum | "GatherND" ->
-           NCFG.Node.GatherND | "RandomNormal" -> NCFG.Node.RandomNormal | "Abs"
-           -> NCFG.Node.Abs | "Log" -> NCFG.Node.Log *)
-        | _ -> raise (ParseError ("Unsupported ONNX operator " ^ o)))
-    in
-    List.fold ~f:(fun acc n -> get_node_operator_cfg n :: acc) ~init:[] ns
-  in
-  let c_ops = List.rev @@ fold_nodes_ops_cfg nodes
-  and i_ops, o_ops =
-    ( List.init ~f:(fun _ -> NCFG.Node.NO_OP) (List.length i_nodes),
-      List.init ~f:(fun _ -> NCFG.Node.NO_OP) (List.length o_nodes) )
-  in
-  let fold_nodes_attr ns =
-    let get_node_attr n = n.NodeProto.attribute in
-    List.fold ~f:(fun acc n -> get_node_attr n :: acc) ~init:[] ns
-  in
-
-  let c_attr = List.rev @@ fold_nodes_attr nodes
-  and i_attr, o_attr =
-    ( List.init ~f:(fun _ -> [ no_attr ]) (List.length i_nodes),
-      List.init ~f:(fun _ -> [ no_attr ]) (List.length o_nodes) )
-  in
-  let c_nodes_inputs, c_nodes_outputs =
-    List.unzip
-    @@ List.fold
-         ~f:(fun acc n -> (n.NodeProto.input, n.NodeProto.output) :: acc)
-         ~init:[] (List.rev nodes)
-  and i_nodes_inputs, i_nodes_outputs, o_nodes_inputs, o_nodes_outputs =
-    ( List.init ~f:(fun _ -> [ "NO_INPUT" ]) (List.length i_nodes),
-      List.init ~f:(fun _ -> [ "" ]) (List.length i_nodes),
-      List.init ~f:(fun _ -> [ "" ]) (List.length o_nodes),
-      List.init ~f:(fun _ -> [ "NO_OUTPUT" ]) (List.length o_nodes) )
-  in
-  let data_dict =
-    let dict_tensors_cfg ts =
-      let get_float_from_index index data sh =
-        let index = Array.to_list index and sh = Array.to_list sh in
-        let pop_sh = List.tl_exn sh @ [ 1 ] in
-        (* Returns the factors by which multiply each coordinate *)
-        let rec get_factors_from_sh sh_f l =
-          match sh_f with
-          | [] -> List.rev l
-          | _ ->
-            get_factors_from_sh (List.tl_exn sh_f)
-              (List.fold ~f:(fun x y -> x * y) ~init:1 sh_f :: l)
-        in
-        let factors = get_factors_from_sh pop_sh [] in
-        let coord_in_data =
-          List.fold2_exn ~f:(fun x y z -> x + (y * z)) ~init:0 index factors
-        in
-        match data with
-        | Raw raw ->
-          let offset = 4 * coord_in_data in
-          (* Each float is coded on 4 bytes*)
-          let res = EndianBytes.LittleEndian.get_float raw offset in
-          res
-        | Float f -> List.nth_exn f coord_in_data
-      in
-      let build_tensor_from_data sh data =
-        let open NCFG.Tensor in
-        let sh = Array.of_list @@ sh in
-        let tensor = create sh K_float in
-        let coords = all_coords (get_shape tensor) in
-        let rec init_tensor t idx r =
-          match idx with
-          | x :: y ->
-            let value =
-              get_float_from_index (Array.of_list x) r (get_shape t)
-            in
-            set t (Array.of_list x) value;
-            init_tensor t y r
-          | [] -> t
-        in
-        init_tensor tensor coords data
-      in
-      let t_name x =
-        match x.TensorProto.name with Some n -> n | None -> "C_NODE"
-      in
-      let t_dim x = List.map ~f:Int64.to_int_exn x.TensorProto.dims in
-      let t_data x =
-        match x.TensorProto.raw_data with
-        | Some rd -> Some (build_tensor_from_data (t_dim x) (Raw rd))
-        | None -> (
-          match x.TensorProto.float_data with
-          | [] -> None
-          | f -> Some (build_tensor_from_data (t_dim x) (Float f)))
-      in
-      List.fold
-        ~f:(fun m x -> Map.add_exn m ~key:(t_name x) ~data:(t_data x))
-        ~init:(Map.empty (module String))
-        ts
-    in
-    dict_tensors_cfg initi
-  in
-  let unpack v =
-    match v with
-    | Some v -> v
-    | None -> failwith "Unpack found an unexpected None"
-  in
-  let tensor_list =
-    List.init
-      ~f:(fun i ->
-        match Map.find data_dict (unpack (List.nth_exn i_nodes i)) with
-        | Some v -> v
-        | None -> None)
-      (List.length i_nodes)
-  in
-  let tensor_list_full = Map.to_alist data_dict in
-  let tensor_list_rev = List.rev tensor_list in
-  let vip_dims v =
-    let val_t =
-      match v.ValueInfoProto.type' with
-      | Some t -> t
-      | None -> failwith "No type in value info"
-    in
-    let tns_t =
-      match val_t.TypeProto.value with
-      | `Tensor_type t -> t
-      | `not_set ->
-        failwith "No tensor type in value info"
-        (* TODO: support more tensor types *)
-      | _ -> raise (ParseError "Unknown tensor type")
-    in
-    let tns_s =
-      match tns_t.shape with
-      | Some s -> s
-      | None -> failwith "No tensor shape in value info"
-    in
-    List.rev
-    @@ List.fold tns_s ~init:[] ~f:(fun acc d ->
-         match d.value with
-         | `Dim_value d -> d :: acc
-         | `not_set | _ -> 0L :: acc)
-  in
-
-  let c_tensordim_list = List.init (List.length nodes) ~f:(fun _ -> [])
-  and c_tensorraw_list = List.init (List.length nodes) ~f:(fun _ -> None)
-  and o_tensordim_list =
-    List.fold ~f:(fun acc n -> vip_dims n :: acc) ~init:[] outputs
-  and o_tensorraw_list = List.init (List.length o_nodes) ~f:(fun _ -> None)
-  and i_tensordim_list =
-    List.fold ~f:(fun acc n -> vip_dims n :: acc) ~init:[] inputs
-  and i_tensorraw_list = tensor_list_rev in
-  let nodes_names = i_nodes @ c_nodes @ o_nodes in
-  let ops = i_ops @ c_ops @ o_ops in
-  let attrs = i_attr @ c_attr @ o_attr in
-  let prevs_list = i_nodes_inputs @ c_nodes_inputs @ o_nodes_inputs in
-  let nexts_list = i_nodes_outputs @ c_nodes_outputs @ o_nodes_outputs in
-  let tensor_dims =
-    List.map
-      ~f:(List.map ~f:Int64.to_int_exn)
-      (i_tensordim_list @ c_tensordim_list @ o_tensordim_list)
-  in
-  let tensors = i_tensorraw_list @ c_tensorraw_list @ o_tensorraw_list in
-  let operator_parameters (attr : AttributeProto.t list) op =
-    match op with
-    | NCFG.Node.Transpose ->
-      let ints_params =
-        Array.map ~f:Int64.to_int_exn
-        @@ Array.of_list (List.nth_exn attr 0).ints
-      in
-      Some (NCFG.Node.Transpose_params ints_params)
-    (*TODO: maxpool and conv operators: match attr.name in attributes to
-     * create the correct value for each attribute*)
-    (* | NCFG.Vertex.MaxPool -> *)
-    (* | NCFG.Vertex.Conv -> *)
-    | _ -> None
-  in
-  let rec build_op_param_list attrs ops l =
-    match (attrs, ops) with
-    | a :: b, c :: d -> build_op_param_list b d (operator_parameters a c :: l)
-    | [], [] ->
-      List.rev l
-      (*All other list constructions are folding right, so we need to put a
-        final revert *)
-    | _ ->
-      raise (ParseError "Operator and attribute lists have not the same size")
-  in
-  let op_params_cfg = build_op_param_list attrs ops [] in
-  let cfg = G.init_cfg in
-  (* adding inputs, outputs and cnodes to the cfg *)
-  let unkerasize l = List.map ~f:(fun x -> if x = 0 then 1 else x) l in
-  for i = 0 to List.length nodes_names - 1 do
-    let (v : G.V.t) =
-      NCFG.Node.create ~id:i
-        ~name:(List.nth_exn nodes_names i)
-        ~sh:(Array.of_list @@ unkerasize (List.nth_exn tensor_dims i))
-        ~op:(List.nth_exn ops i)
-        ~op_p:(List.nth_exn op_params_cfg i)
-        ~pred:(List.nth_exn prevs_list i)
-        ~succ:(List.nth_exn nexts_list i)
-        ~tensor:(List.nth_exn tensors i)
-    in
-    G.add_vertex cfg v
-  done;
-  (* Adding edges between vertices *)
-  (* For each unnamed vertex (= a calculation node) in the cfg ... *)
-  (* return true if l1 has at least one common element wih l2 *)
-  let rec shared_elm l1 l2 =
-    match l1 with
-    | x :: y -> List.mem l2 x ~equal:String.equal || shared_elm y l2
-    | [] -> false
-  in
-  List.iter
-    ~f:(fun (v : G.V.t) ->
-      match v.name with
-      | None ->
-        let pred = v.pred and succ = v.succ in
-        let prev_v =
-          (* ... get all vertices in cfg that have the current vertex preds
-           * in their succ (at least one of their succ is inside our preds )*)
-          G.find_vertices cfg (fun (x : G.V.t) ->
-            if shared_elm pred x.succ then true else false)
-        (* ... get all vertices in cfg that have the current vertex preds
-         * in their name (they are named the same as one of our preds )*)
-        and named_pred =
-          G.find_vertices cfg (fun (x : G.V.t) ->
-            match x.name with
-            | Some name -> if shared_elm pred [ name ] then true else false
-            | None -> false)
-        (* ... get all vertices in cfg that have the current vertex succ
-         * in their name (they are named the same as one of our succs )*)
-        and named_succ =
-          G.find_vertices cfg (fun (x : G.V.t) ->
-            match x.name with
-            | Some name -> if shared_elm succ [ name ] then true else false
-            | None -> false)
-        (* get all vertices in cfg that have the current vertex succs
-         * in their preds (at least one of their preds is inside our succ )*)
-        and next_v =
-          G.find_vertices cfg (fun (x : G.V.t) ->
-            if shared_elm succ x.pred then true else false)
-        in
-        (* add edges between current vertex and identified preds and succs *)
-        let v_predecessors = prev_v @ named_pred
-        and v_successors = next_v @ named_succ in
-        let unpack_tname (x : G.V.t) =
-          match x.NCFG.Node.name with Some n -> n | None -> ""
-        in
-        List.iter
-          ~f:(fun (x : G.V.t) ->
-            let label =
-              match List.nth x.succ 0 with
-              | Some "NO_OUTPUT" ->
-                let pred_name = unpack_tname x in
-                if List.mem ~equal:String.equal v.NCFG.Node.pred pred_name
-                then pred_name
-                else ""
-              | Some l -> l
-              | None -> ""
-            in
-            G.add_edge_e cfg (x, label, v))
-          v_predecessors;
-        (* add successor edges, filtering out those
-         * that already have an edge *)
-        List.iter
-          ~f:(fun (x : G.V.t) ->
-            let all_preds = G.preds cfg x and all_succs = G.succs cfg x in
-            if List.mem ~equal:NCFG.Node.equal all_preds v
-               || List.mem ~equal:NCFG.Node.equal all_succs v
-            then ()
-            else
-              let label =
-                match List.nth_exn x.pred 0 with
-                | "NO_INPUT" ->
-                  let succ_name = unpack_tname x in
-                  if List.mem ~equal:String.equal v.NCFG.Node.succ succ_name
-                  then succ_name
-                  else ""
-                | l -> l
-              in
-              G.add_edge_e cfg (v, label, x))
-          v_successors
-      | _ -> ())
-    (G.vertex_list cfg);
-
-  (* Rationale of the following:
-   * PyTorch stores network nodes in the field "inputs" of
-   * the ONNX graph g, and network parameters as a list of tensors
-   * in the ONNX initializer_.
-   * To make the two correspond, elements of g.inputs and g.initializer_
-   * share the same value in the field "name".
-   * In Keras, elements of g.initializer_ have a name, but they do not
-   * correspond to any name in g.inputs.
-   * So far, the actual NIER cfg has been built following the PyTorch
-   * convention.
-   * Below, we complete the cfg with Keras data as follows:
-   *  * create a NIER node for each tensor in the ONNX initializer_
-   *  * for each NIER node, check whether another node shares its name
-   *  * if so, remove the one with the highest ID (those are the
-   *    initializer nodes just added; since the cfg already contains a
-   *    node with this name, we do not need them)
-   *  * if not, check for each NIER node whether its name appears among
-   *    another node's preds, and if so add the corresponding edge
-   * *)
-
-  (* adding initializer vertices to the cfg *)
-  for i = 0 to List.length tensor_list_full - 1 do
-    let shape =
-      match snd (List.nth_exn tensor_list_full i) with
-      | Some t -> unkerasize (Array.to_list @@ NCFG.Tensor.get_shape t)
-      | None -> []
-    in
-    let (v : G.V.t) =
-      NCFG.Node.create
-        ~id:(i + List.length nodes_names)
-        ~name:(Some (fst (List.nth_exn tensor_list_full i)))
-        ~sh:(Array.of_list @@ unkerasize shape)
-        ~op:NO_OP ~op_p:None ~pred:[] ~succ:[]
-        ~tensor:(snd (List.nth_exn tensor_list_full i))
-    in
-    G.add_vertex cfg v
-  done;
-  (* build a list of pairs of nodes
-   * sharing a name but with different ids *)
-  let same_name_diff_ids =
-    let aux (x : G.V.t) =
-      G.fold_vertex
-        (fun y acc ->
-          match (x.name, y.name) with
-          | Some xa, Some ya ->
-            if (not (y.id = x.id)) && String.equal xa ya
-            then (x, y) :: acc
-            else acc
-          | _ -> acc)
-        cfg []
-    in
-    G.fold_vertex (fun x l -> aux x :: l) cfg []
-  in
-  let highest_ids =
-    List.fold
-      ~f:(fun acc x ->
-        match x with
-        | a :: _ ->
-          let maxval = max (fst a).NCFG.Node.id (snd a).NCFG.Node.id in
-          maxval :: acc
-        | [] -> acc)
-      ~init:[] same_name_diff_ids
-  in
-  (* removing nodes with the highest id, those are the
-     ones we just added *)
-  List.iter
-    ~f:(fun x ->
-      match x with
-      | l :: _ ->
-        let v1 = fst l in
-        if List.mem ~equal:( = ) highest_ids v1.NCFG.Node.id
-        then
-          (* Printf.printf "Removing id %d \n%!" *)
-          (*   v1.NCFG.Vertex.id; *)
-          G.remove_vertex cfg v1
-        else ()
-      | [] -> ())
-    same_name_diff_ids;
-  (* Now it is Keras time.
-   * Look for nodes sharing name and preds,
-   * then create edge *)
-  let shared_name_preds =
-    let aux (x : G.V.t) =
-      match x.name with
-      (* check other vertices to see whether this name
-       * appears among their predecessors *)
-      | Some n -> G.find_vertices cfg (fun x -> shared_elm [ n ] x.pred)
-      | None -> []
-    in
-    G.fold_vertex (fun x l -> (x, aux x) :: l) cfg []
-  in
-  List.iter
-    ~f:(fun x ->
-      let orgn = fst x and to_edge = snd x in
-      List.iter
-        ~f:(fun t ->
-          if not (G.mem_edge cfg orgn t)
-          then G.add_edge_e cfg (orgn, unpack orgn.NCFG.Node.name, t)
-          else ())
-        to_edge)
-    shared_name_preds;
-  (* else (); *)
-  cfg
-
-let nier_of_onnx_protoc (model : Oprotom.t) =
-  match model.graph with
-  | Some g -> produce_cfg g
-  | None -> raise (ParseError "No graph in ONNX input file found")
-
-let default_opset_info =
-  let open Oproto.Onnx in
-  let onnx_domain = "" in
-  OperatorSetIdProto.make ~domain:onnx_domain ~version:8L ()
-
-let nier_to_onnx_protoc nier =
-  (* TODO: get tensor data, and operator params *)
-  let vertices = G.vertex_list nier in
-  let open NCFG.Node in
-  let protocs =
-    (* match on names of NO_OP nodes and add their outputs to corresponding
-     * C_NODEs inputs *)
-    let vertex_to_protoc v =
-      let name = get_name v in
-      let input, output = (get_pred_list v, get_succ_list v) in
-      let node, initi =
-        match get_op v with
-        | NO_OP | RW_Linearized_ReLu ->
-          (* ONNX initializers are named ONNX tensors.
-           * If an initializer's name matches an existing
-           * ONNX node input name, the initializer is assigned as
-           * the input of that node. *)
-          let initi =
-            match get_tensor v with
-            | None -> None
-            | Some t ->
-              Some
-                (Oproto.Onnx.TensorProto.make ~data_type:1
-                   ~dims:
-                     (Array.to_list @@ Array.map ~f:Int64.of_int
-                    @@ NCFG.Tensor.get_shape t)
-                   ~float_data:(NCFG.Tensor.flatten t) ~name ())
-          in
-          let node = None in
-          (node, initi)
-        | _ ->
-          let op_type = str_op (get_op v) in
-          let attribute =
-            match v.operator_parameters with
-            | None | Some (RW_Linearized_ReLu_params _) -> []
-            | Some
-                (Pool_params
-                  (Ksize k, Some (Stride s), Some (Pads p), Some (Dilations d)))
-            | Some
-                (Conv_params
-                  (Ksize k, Some (Stride s), Some (Pads p), Some (Dilations d)))
-              ->
-              let ksize =
-                Oproto.Onnx.AttributeProto.make ~name:"ksize"
-                  ~ints:(Array.to_list @@ Array.map ~f:Int64.of_int k)
-                  ()
-              in
-              let stride =
-                Oproto.Onnx.AttributeProto.make ~name:"stride"
-                  ~ints:(Array.to_list @@ Array.map ~f:Int64.of_int s)
-                  ()
-              in
-              let pads =
-                Oproto.Onnx.AttributeProto.make ~name:"pads"
-                  ~ints:(Array.to_list @@ Array.map ~f:Int64.of_int p)
-                  ()
-              in
-              let dilations =
-                Oproto.Onnx.AttributeProto.make ~name:"dilations"
-                  ~ints:(Array.to_list @@ Array.map ~f:Int64.of_int d)
-                  ()
-              in
-              [ ksize; stride; pads; dilations ]
-            | Some (Transpose_params s) ->
-              [
-                Oproto.Onnx.AttributeProto.make ~name:"perms"
-                  ~ints:(Array.to_list @@ Array.map ~f:Int64.of_int s)
-                  ();
-              ]
-            | Some (Gather_params a) ->
-              [
-                Oproto.Onnx.AttributeProto.make ~name:"axis"
-                  ~ints:[ Int64.of_int a ]
-                  ();
-              ]
-            | Some (ReduceSum_params (a, b)) ->
-              [
-                Oproto.Onnx.AttributeProto.make ~name:"keepdims"
-                  ~i:(Int64.of_int a) ();
-                Oproto.Onnx.AttributeProto.make ~name:"noop_with_empty_axes"
-                  ~i:(Int64.of_int b) ();
-              ]
-            | Some (RandomNormal_params (a, b, c, d, s)) ->
-              [
-                Oproto.Onnx.AttributeProto.make ~name:"dtype"
-                  ~i:(Int64.of_int a) ();
-                Oproto.Onnx.AttributeProto.make ~name:"mean" ~f:b ();
-                Oproto.Onnx.AttributeProto.make ~name:"scale" ~f:c ();
-                Oproto.Onnx.AttributeProto.make ~name:"seed" ~f:d ();
-                Oproto.Onnx.AttributeProto.make ~name:"shape"
-                  ~ints:(Array.to_list @@ Array.map ~f:Int64.of_int s)
-                  ();
-              ]
-            | _ -> []
-          in
-          let node =
-            Some
-              (Oproto.Onnx.NodeProto.make ~input ~output ~name ~op_type
-                 ~attribute ~doc_string:"" ())
-          in
-          let initi = None in
-          (node, initi)
-      in
-      (node, initi)
-    in
-    List.fold ~init:([], [])
-      ~f:(fun (accn, acci) v ->
-        let node, initi = vertex_to_protoc v in
-        match (node, initi) with
-        | Some n, Some t -> (n :: accn, t :: acci)
-        | Some n, None -> (n :: accn, acci)
-        | None, Some t -> (accn, t :: acci)
-        | None, None -> (accn, acci))
-      vertices
-  in
-  let docstr =
-    "This ONNX model was generated from the Neural Intermediate Representation \
-     of CAISAR"
-  in
-  let protog =
-    Oproto.Onnx.GraphProto.make ~name:"ONNX CAISAR Export" ~node:(fst protocs)
-      ~initializer':(snd protocs) ~sparse_initializer:[]
-      ~doc_string:"ONNX graph generated from CAISAR NIER" ~input:[] ~output:[]
-      ~value_info:[] ~quantization_annotation:[] ()
-  in
-  let protom =
-    Oproto.Onnx.ModelProto.make ~ir_version:8L
-      ~opset_import:[ default_opset_info ] ~producer_name:"CAISAR"
-      ~producer_version:"1.0" ~domain:"" ~model_version:(-1L) ~doc_string:docstr
-      ~graph:protog ~metadata_props:[] ~training_info:[] ~functions:[] ()
-  in
-  let writer = Oprotom.to_proto protom in
-  Ocaml_protoc_plugin.Writer.contents writer
-
-let write_nier_to_onnx nier out_channel =
-  let onnx = nier_to_onnx_protoc nier in
-  Stdio.Out_channel.output_string out_channel onnx
-
-let parse_in_channel in_channel =
-  let open Result in
-  try
-    let buf = Stdio.In_channel.input_all in_channel in
-    let reader = Ocaml_protoc_plugin.Reader.create buf in
-    match Oprotom.from_proto reader with
-    | Ok r ->
-      let n_inputs, n_outputs = get_input_output_dim r in
-      let nier =
-        try Ok (nier_of_onnx_protoc r) with
-        | ParseError s | Sys_error s -> Error s
-        | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg)
-      in
-      Ok { n_inputs; n_outputs; nier }
-    | _ -> Error "Cannot read protobuf"
-  with
-  | Sys_error s -> Error s
-  | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg)
-
-let parse filename =
-  let in_channel = Stdlib.open_in filename in
-  Fun.protect
-    ~finally:(fun () -> Stdlib.close_in in_channel)
-    (fun () -> parse_in_channel in_channel)
-
-let write nier filename =
-  let out_chan = Stdlib.open_out filename in
-  write_nier_to_onnx nier out_chan;
-  Stdlib.close_out out_chan
-
-module Simple = Simple
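A note on I/O discipline in the deleted code above: parse guards its input
channel with Fun.protect, while write closes its output channel without a
guard, so an exception raised during serialization would leak the
descriptor. A minimal sketch of the symmetric, exception-safe variant (the
name write_safe is hypothetical, not part of this patch):

  let write_safe nier filename =
    let out_chan = Stdlib.open_out filename in
    Stdlib.Fun.protect
      ~finally:(fun () -> Stdlib.close_out out_chan)
      (fun () -> write_nier_to_onnx nier out_chan)

The new Onnx.Writer.to_file introduced later in this patch keeps the
unguarded pattern as well.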
diff --git a/lib/onnx/simple.ml b/lib/onnx/reader.ml
similarity index 55%
rename from lib/onnx/simple.ml
rename to lib/onnx/reader.ml
index 268ac5469fe49e3a76f6206fdd669f2cc5683dc7..739c0c7663655e8baa0cc94b7cd7f4b49360d223 100644
--- a/lib/onnx/simple.ml
+++ b/lib/onnx/reader.ml
@@ -25,24 +25,22 @@ module Format = Stdlib.Format
 module Fun = Stdlib.Fun
 module Oproto = Onnx_protoc (* Autogenerated during compilation *)
 module Oprotom = Oproto.Onnx.ModelProto
-module NCFG = Ir.Nier_simple
-module G = NCFG.GFloat
 
 exception ParseError of string
 
 type t = {
   n_inputs : int; (* Number of inputs. *)
   n_outputs : int; (* Number of outputs. *)
-  nier : (G.t, string) Result.t; (* Intermediate representation. *)
+  nir : (Nir.Ngraph.t, string) Result.t; (* Intermediate representation. *)
 }
 
 (* ONNX format handling. *)
 module Convert : sig
-  val nier_of_onnx_protoc : Oproto.Onnx.ModelProto.t -> G.t
+  val nir_of_onnx_protoc : Oproto.Onnx.ModelProto.t -> Nir.Ngraph.t
   val get_input_output_dim : Oproto.Onnx.ModelProto.t -> int * int
 end = struct
   let get_shape_of_dims (s : Oproto.Onnx.TensorShapeProto.t) =
-    Ir.Nier_simple.Shape.of_list
+    Nir.Shape.of_list
     @@ List.map s ~f:(function
          | { value = `Dim_value v; _ } -> Int64.to_int_exn v
         | { value = `Dim_param _; _ } -> failwith "Parametric shape"
@@ -62,7 +60,7 @@ end = struct
     | _ -> []
 
   let flattened_dim (s : Oproto.Onnx.TensorShapeProto.Dimension.t list) =
-    Ir.Nier_simple.Shape.size (get_shape_of_dims s)
+    Nir.Shape.size (get_shape_of_dims s)
 
   let get_input_output_dim (model : Oprotom.t) =
     let input_shape, output_shape =
@@ -76,13 +74,10 @@ end = struct
     let output_flat_dim = flattened_dim output_shape in
     (input_flat_dim, output_flat_dim)
 
-  let convert_tensor (ts : Oproto.Onnx.TensorProto.t) :
-    Ir.Nier_simple.GenTensor.t =
+  let convert_tensor (ts : Oproto.Onnx.TensorProto.t) : Nir.Gentensor.t =
     let open Oproto.Onnx in
-    let dims =
-      Ir.Nier_simple.Shape.of_list @@ List.map ~f:Int64.to_int_exn ts.dims
-    in
-    let size = Ir.Nier_simple.Shape.size dims in
+    let dims = Nir.Shape.of_list @@ List.map ~f:Int64.to_int_exn ts.dims in
+    let size = Nir.Shape.size dims in
     let read_raw ~get kind =
       match ts.raw_data with
       | None ->
@@ -93,7 +88,7 @@ end = struct
           let v = get data i in
           Bigarray.Array1.set t i v
         done;
-        Ir.Nier_simple.Tensor.of_array1 dims t
+        Nir.Tensor.of_array1 dims t
     in
     let read_gen ~get elt_size kind custom_data =
       match custom_data with
@@ -106,7 +101,7 @@ end = struct
       | l ->
         let t = Bigarray.(Array1.create kind c_layout size) in
         List.iteri l ~f:(fun i f -> Bigarray.Array1.set t i f);
-        Ir.Nier_simple.Tensor.of_array1 dims t
+        Nir.Tensor.of_array1 dims t
     in
     match Option.map ~f:TensorProto.DataType.from_int ts.data_type with
     | None -> failwith "TensorProto should have a type"
@@ -173,7 +168,7 @@ end = struct
       | _ -> failwith "graph with more than one input node (unsupported)"
     in
     Hashtbl.add_exn converted ~key:input_name
-      ~data:(Ir.Nier_simple.Node.create (Input { shape = input_shape }));
+      ~data:(Nir.Node.create (Input { shape = input_shape }));
     (* converter *)
     let rec convert output =
       Hashtbl.findi_or_add ~default:convert_aux converted output
@@ -194,50 +189,37 @@ end = struct
             (module String)
             (List.map ~f:(fun a -> (Option.value_exn a.name, a)) n.attribute)
         in
-        let get_attr ?default name m =
-          match Hashtbl.find attrs name with
-          | Some v -> m v
-          | None -> (
-            match default with
-            | Some v -> v
-            | None -> Fmt.failwith "Required attribute %s missing" name)
-        in
-        let get_float ?default name : float =
-          get_attr ?default name (function
-            | { type' = Some AttributeProto.AttributeType.FLOAT; f = Some f; _ }
-              ->
-              f
-            | _ -> failwith "Attribute wrongly typed")
+        let get_float name : float =
+          match Hashtbl.find_exn attrs name with
+          | { type' = Some AttributeProto.AttributeType.FLOAT; f = Some f; _ }
+            ->
+            f
+          | _ -> failwith "Attribute wrongly typed"
         in
-        let get_int ?default name : int =
-          get_attr ?default name (function
-            | { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ }
-              ->
-              Int64.to_int_exn i
-            | _ -> failwith "Attribute wrongly typed")
+        let get_int name : int =
+          match Hashtbl.find_exn attrs name with
+          | { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ } ->
+            Int64.to_int_exn i
+          | _ -> failwith "Attribute wrongly typed"
         in
-        let get_ints ?default name : int list =
-          get_attr ?default name (function
-            | { type' = Some AttributeProto.AttributeType.INTS; ints = l; _ } ->
-              List.map ~f:Int64.to_int_exn l
-            | _ -> failwith "Attribute wrongly typed")
+        let get_ints name : int list =
+          match Hashtbl.find_exn attrs name with
+          | { type' = Some AttributeProto.AttributeType.INTS; ints = l; _ } ->
+            List.map ~f:Int64.to_int_exn l
+          | _ -> failwith "Attribute wrongly typed"
         in
-        let get_bool ?default name : bool =
-          get_attr ?default name (function
-            | { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ }
-              ->
-              not (Int64.equal i 0L)
-            | _ -> failwith "Attribute wrongly typed")
+        let get_bool name : bool =
+          match Hashtbl.find_exn attrs name with
+          | { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ } ->
+            not (Int64.equal i 0L)
+          | _ -> failwith "Attribute wrongly typed"
         in
-        let get_tensor ?default name : Ir.Nier_simple.GenTensor.t =
-          get_attr ?default name (function
-            | {
-                type' = Some AttributeProto.AttributeType.TENSOR;
-                t = Some t;
-                _;
-              } ->
-              convert_tensor t
-            | _ -> failwith "Attribute wrongly typed")
+        let get_tensor name : Nir.Gentensor.t =
+          match Hashtbl.find_exn attrs name with
+          | { type' = Some AttributeProto.AttributeType.TENSOR; t = Some t; _ }
+            ->
+            convert_tensor t
+          | _ -> failwith "Attribute wrongly typed"
         in
         let n' =
           match n.op_type with
@@ -246,26 +228,22 @@ end = struct
             match s with
             | "Add" ->
               let input1, input2 = two_arg n.input in
-              Ir.Nier_simple.Add
-                { input1 = convert input1; input2 = convert input2 }
+              Nir.Node.Add { input1 = convert input1; input2 = convert input2 }
             | "Sub" ->
               let input1, input2 = two_arg n.input in
-              Ir.Nier_simple.Sub
-                { input1 = convert input1; input2 = convert input2 }
+              Nir.Node.Sub { input1 = convert input1; input2 = convert input2 }
             | "Mul" ->
               let input1, input2 = two_arg n.input in
-              Ir.Nier_simple.Mul
-                { input1 = convert input1; input2 = convert input2 }
+              Nir.Node.Mul { input1 = convert input1; input2 = convert input2 }
             | "Div" ->
               let input1, input2 = two_arg n.input in
-              Ir.Nier_simple.Div
-                { input1 = convert input1; input2 = convert input2 }
+              Nir.Node.Div { input1 = convert input1; input2 = convert input2 }
             | "Relu" ->
               let input1 = one_arg n.input in
-              Ir.Nier_simple.ReLu { input = convert input1 }
+              Nir.Node.ReLu { input = convert input1 }
             | "MatMul" ->
               let input1, input2 = two_arg n.input in
-              Ir.Nier_simple.Matmul
+              Nir.Node.Matmul
                 { input1 = convert input1; input2 = convert input2 }
             | "Gemm" ->
               let inputA, inputB, inputC =
@@ -274,19 +252,19 @@ end = struct
                 | [ inputA; inputB; inputC ] -> (inputA, inputB, Some inputC)
                 | _ -> failwith "Gemm must have 2 or 3 inputs"
               in
-              Ir.Nier_simple.Gemm
+              Nir.Node.Gemm
                 {
                   inputA = convert inputA;
                   inputB = convert inputB;
                   inputC = Option.map ~f:convert inputC;
-                  alpha = get_float ~default:1.0 "alpha";
-                  beta = get_float ~default:1.0 "beta";
-                  transA = get_bool ~default:false "transA";
-                  transB = get_bool ~default:false "transB";
+                  alpha = get_float "alpha";
+                  beta = get_float "beta";
+                  transA = get_bool "transA";
+                  transB = get_bool "transB";
                 }
-            | "LogSoftmax" -> Ir.Nier_simple.LogSoftmax
+            | "LogSoftmax" -> Nir.Node.LogSoftmax
             | "Transpose" ->
-              Ir.Nier_simple.Transpose
+              Nir.Node.Transpose
                 { input = convert (one_arg n.input); perm = get_ints "perm" }
             | "Squeeze" ->
               let data, axes =
@@ -295,7 +273,7 @@ end = struct
                 | [ data; axes ] -> (convert data, Some (convert axes))
                 | _ -> failwith "Squeeze must have 1 or 2 inputs"
               in
-              Ir.Nier_simple.Squeeze { data; axes }
+              Nir.Node.Squeeze { data; axes }
             | "MaxPool" -> MaxPool
             | "Constant" -> Constant { data = get_tensor "value" }
             | "Conv" -> Conv
@@ -304,10 +282,10 @@ end = struct
                 { input = convert @@ one_arg n.input; axis = get_int "axis" }
             (* | "Reshape" -> NCFG.Node.Reshape | "Identity" ->
                NCFG.Node.Identity | "Gather" -> NCFG.Node.Gather *)
-            | "Abs" -> Ir.Nier_simple.Abs { input = convert @@ one_arg n.input }
-            | "Log" -> Ir.Nier_simple.Log { input = convert @@ one_arg n.input }
+            | "Abs" -> Nir.Node.Abs { input = convert @@ one_arg n.input }
+            | "Log" -> Nir.Node.Log { input = convert @@ one_arg n.input }
             | "RandomNormal" ->
-              Ir.Nier_simple.RandomNormal
+              Nir.Node.RandomNormal
                 {
                   dtype = get_int "dtype";
                   mean = get_float "mean";
@@ -318,15 +296,14 @@ end = struct
             (* TODO: ReduceSum, GatherND *)
            | s -> failwith (Printf.sprintf "Unknown operator %s" s))
         in
-        Ir.Nier_simple.Node.create n'
-      | Tensor t ->
-        Ir.Nier_simple.Node.create (Constant { data = convert_tensor t })
+        Nir.Node.create n'
+      | Tensor t -> Nir.Node.create (Constant { data = convert_tensor t })
     in
     let output' = convert output in
-    assert (Ir.Nier_simple.Shape.equal output'.shape output_shape);
-    Ir.Nier_simple.create output'
+    assert (Nir.Shape.equal output'.shape output_shape);
+    Nir.Ngraph.create output'
 
-  let nier_of_onnx_protoc (model : Oprotom.t) =
+  let nir_of_onnx_protoc (model : Oprotom.t) =
     (match model.ir_version with
     | None -> failwith "IR version not specified"
     | Some (3L | 4L | 5L | 6L | 7L | 8L) -> ()
@@ -351,151 +328,19 @@ let parse_in_channel in_channel =
     match Oprotom.from_proto reader with
     | Ok r ->
       let n_inputs, n_outputs = Convert.get_input_output_dim r in
-      let nier =
-        try Ok (Convert.nier_of_onnx_protoc r) with
+      let nir =
+        try Ok (Convert.nir_of_onnx_protoc r) with
         | ParseError s | Sys_error s -> Error s
         | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg)
       in
-      Ok { n_inputs; n_outputs; nier }
+      Ok { n_inputs; n_outputs; nir }
     | _ -> Error "Cannot read protobuf"
   with
   | Sys_error s -> Error s
   | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg)
 
-let parse filename =
+let from_file filename =
   let in_channel = Stdlib.open_in filename in
   Fun.protect
     ~finally:(fun () -> Stdlib.close_in in_channel)
     (fun () -> parse_in_channel in_channel)
-
-let value_info_from_tensor_shape ~name (n : Ir.Nier_simple.Node.t) =
-  let open Oproto.Onnx in
-  let dim =
-    List.map (Ir.Nier_simple.Shape.to_list n.shape) ~f:(fun i ->
-      TensorShapeProto.Dimension.make ~value:(`Dim_value (Int64.of_int i)) ())
-  in
-  let shape = TensorShapeProto.make ~dim () in
-  let ty : TensorProto.DataType.t =
-    match n.ty with Float -> FLOAT | Int64 -> INT64
-  in
-  let value =
-    `Tensor_type
-      (TypeProto.Tensor.make
-         ~elem_type:TensorProto.DataType.(to_int ty)
-         ~shape ())
-  in
-  let type' = TypeProto.make ~value () in
-  let value_info = ValueInfoProto.make ~name ~type' () in
-  value_info
-
-let convert_into_tensor ?name (t : Ir.Nier_simple.GenTensor.t) =
-  let mk data_type =
-    Oproto.Onnx.TensorProto.make
-      ~data_type:Oproto.Onnx.TensorProto.DataType.(to_int data_type)
-      ~dims:
-        (List.map ~f:Int64.of_int @@ Ir.Nier_simple.Shape.to_list
-        @@ Ir.Nier_simple.GenTensor.shape t)
-      ?name
-  in
-  match t with
-  | Float t -> mk FLOAT ~float_data:(Ir.Nier_simple.Tensor.flatten t) ()
-  | Int64 t -> mk INT64 ~int64_data:(Ir.Nier_simple.Tensor.flatten t) ()
-
-let default_opset_import =
-  let open Oproto.Onnx in
-  let onnx_domain = "" in
-  OperatorSetIdProto.make ~domain:onnx_domain ~version:13L ()
-
-let nier_simple_to_onnx_protoc (nier_simple : Ir.Nier_simple.GFloat.t) =
-  let open Oproto.Onnx in
-  let get_name (v : Ir.Nier_simple.Node.t) = Int.to_string v.id in
-  let protocs, input =
-    let acc = Queue.create () in
-    let g_input = ref None in
-    let vertex_to_protoc (v : Ir.Nier_simple.Node.t) =
-      let name = get_name v in
-      let input = List.map ~f:get_name (Ir.Nier_simple.Node.preds v) in
-      let output = [ name ] in
-      let make op_type attribute =
-        Queue.enqueue acc
-          (Oproto.Onnx.NodeProto.make ~input ~output ~name ~op_type ~attribute
-             ~doc_string:"" ())
-      in
-      let mk_int name i =
-        AttributeProto.make ~name ~type':INT ~i:(Int64.of_int i) ()
-      in
-      let mk_ints name ints =
-        AttributeProto.make ~name ~type':INTS
-          ~ints:(List.map ~f:Int64.of_int ints)
-          ()
-      in
-      let mk_float name f = AttributeProto.make ~name ~type':FLOAT ~f () in
-      let mk_tensor name t = AttributeProto.make ~name ~type':TENSOR ~t () in
-      match v.descr with
-      | Gemm _ | LogSoftmax | Transpose _ | Squeeze _ | MaxPool | Conv
-      | Identity _ | RW_Linearized_ReLu | GatherND _ | ReduceSum _ ->
-        failwith (Fmt.str "Not implemented export: %a" Ir.Nier_simple.Node.pp v)
-      | Flatten { axis; _ } -> make "Flatten" [ mk_int "axis" axis ]
-      | Reshape _ -> make "Reshape" []
-      | Constant { data } ->
-        let data = convert_into_tensor data in
-        make "Constant" [ mk_tensor "value" data ]
-      | Add _ -> make "Add" []
-      | Sub _ -> make "Sub" []
-      | Mul _ -> make "Mul" []
-      | Div _ -> make "Div" []
-      | Matmul _ -> make "MatMul" []
-      | ReLu _ -> make "Relu" []
-      | Input _ -> g_input := Some v
-      | Concat { axis; _ } -> make "Concat" [ mk_int "axis" axis ]
-      | Gather { axis; _ } -> make "Gather" [ mk_int "axis" axis ]
-      | Abs _ -> make "Abs" []
-      | Log _ -> make "Log" []
-      | RandomNormal { dtype; mean; scale; seed; shape } ->
-        make "RandomNormal"
-          [
-            mk_int "dtype" dtype;
-            mk_float "mean" mean;
-            mk_float "scale" scale;
-            mk_float "seed" seed;
-            mk_ints "shape" (Array.to_list shape);
-          ]
-    in
-    Ir.Nier_simple.iter_vertex vertex_to_protoc nier_simple;
-    (Queue.to_list acc, Option.value_exn !g_input)
-  in
-  let docstr =
-    "This ONNX model was generated from the Neural Intermediate Representation \
-     of CAISAR"
-  in
-  let input = [ value_info_from_tensor_shape ~name:(get_name input) input ] in
-  let output =
-    let output = Ir.Nier_simple.output nier_simple in
-    [ value_info_from_tensor_shape ~name:(get_name output) output ]
-  in
-  let value_info =
-    List.map (Ir.Nier_simple.nodes nier_simple) ~f:(fun v ->
-      value_info_from_tensor_shape ~name:(get_name v) v)
-  in
-  let protog =
-    GraphProto.make ~name:"ONNX CAISAR Export" ~node:protocs ~initializer':[]
-      ~sparse_initializer:[] ~doc_string:"ONNX graph generated from CAISAR NIER"
-      ~input ~output ~value_info ~quantization_annotation:[] ()
-  in
-  let protom =
-    Oproto.Onnx.ModelProto.make ~ir_version:8L
-      ~opset_import:[ default_opset_import ] ~producer_name:"CAISAR"
-      ~producer_version:"1.0" ~domain:"" ~model_version:(-1L) ~doc_string:docstr
-      ~graph:protog ~metadata_props:[] ~training_info:[] ~functions:[] ()
-  in
-  let writer = Oprotom.to_proto protom in
-  Ocaml_protoc_plugin.Writer.contents writer
-
-let write_to_onnx nier out_channel =
-  let onnx = nier_simple_to_onnx_protoc nier in
-  Stdio.Out_channel.output_string out_channel onnx
-
-let write nier filename =
-  let out_chan = Stdlib.open_out filename in
-  write_to_onnx nier out_chan;
-  Stdlib.close_out out_chan
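One behavioral change in this rename is easy to miss: the old getters took
an optional ?default and fell back to it when an attribute was absent,
which is how Gemm previously picked up the ONNX spec defaults (alpha = 1.0,
beta = 1.0, transA = transB = false). The new getters use Hashtbl.find_exn,
so a Gemm node that omits any of these attributes now raises. A sketch of a
defaulted variant, assuming the same surrounding opens as the code above:

  let get_float ?default attrs name : float =
    match Hashtbl.find attrs name with
    | Some { type' = Some AttributeProto.AttributeType.FLOAT; f = Some f; _ }
      ->
      f
    | Some _ -> failwith "Attribute wrongly typed"
    | None -> (
      match default with
      | Some f -> f
      | None -> Fmt.failwith "Required attribute %s missing" name)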
diff --git a/lib/onnx/simple.mli b/lib/onnx/reader.mli
similarity index 86%
rename from lib/onnx/simple.mli
rename to lib/onnx/reader.mli
index f2a0528ea8843bb93b8038406063d98145242c67..18e8ddb1925b31480a1498b27f238de049a3c1f1 100644
--- a/lib/onnx/simple.mli
+++ b/lib/onnx/reader.mli
@@ -23,13 +23,9 @@
 type t = private {
   n_inputs : int;  (** Number of inputs. *)
   n_outputs : int;  (** Number of outputs. *)
-  nier : (Ir.Nier_simple.GFloat.t, string) Result.t;
-    (** Intermediate representation. *)
+  nir : (Nir.Ngraph.t, string) Result.t;  (** Intermediate representation. *)
 }
 (** ONNX model metadata and intermediate representation. *)
 
-val parse : string -> (t, string) Result.t
-(** Parse an ONNX file into a NIER. *)
-
-val write : Ir.Nier_simple.GFloat.t -> string -> unit
-(** Write a NIER into an ONNX file. *)
+val from_file : string -> (t, string) Result.t
+(** Parse an ONNX file. *)
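For illustration, the result type separates protobuf-level failures from
IR-construction failures, so callers typically match both layers (a sketch;
"model.onnx" is a placeholder path):

  let () =
    match Onnx.Reader.from_file "model.onnx" with
    | Error msg -> prerr_endline msg
    | Ok { n_inputs; n_outputs; nir } ->
      Printf.printf "inputs: %d, outputs: %d\n" n_inputs n_outputs;
      (match nir with
       | Ok _graph -> print_endline "NIR built"
       | Error msg -> prerr_endline msg)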
diff --git a/lib/onnx/tests/print.ml b/lib/onnx/tests/print.ml
index 12cd6e4f87f5013bcbb53c41f536de3095f07e82..166927b701182cfb950ea25e547757824c856886 100644
--- a/lib/onnx/tests/print.ml
+++ b/lib/onnx/tests/print.ml
@@ -9,14 +9,14 @@ let temporary_file = "out/test.onnx"
 let () =
   match Onnx.Reader.from_file file with
   | Error s -> print_endline s
-  | Ok { nier = Error s; _ } -> print_endline s
-  | Ok { nier = Ok g; _ } -> (
+  | Ok { nir = Error s; _ } -> print_endline s
+  | Ok { nir = Ok g; _ } -> (
     print_endline "ok";
     Onnx.Writer.to_file g temporary_file;
     match Onnx.Reader.from_file temporary_file with
     | Error s -> print_endline s
-    | Ok { nier = Error s; _ } -> print_endline s
-    | Ok { nier = Ok _; _ } -> print_endline "ok")
+    | Ok { nir = Error s; _ } -> print_endline s
+    | Ok { nir = Ok _; _ } -> print_endline "ok")
 
 let () =
   let pid =
diff --git a/lib/onnx/writer.ml b/lib/onnx/writer.ml
new file mode 100644
index 0000000000000000000000000000000000000000..6b8d6f3eff756c70a96d77dcdc5deea99e86e0fe
--- /dev/null
+++ b/lib/onnx/writer.ml
@@ -0,0 +1,157 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2023                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+
+open Base
+module Format = Stdlib.Format
+module Fun = Stdlib.Fun
+module Oproto = Onnx_protoc (* Autogenerated during compilation *)
+module Oprotom = Oproto.Onnx.ModelProto
+
+let value_info_from_tensor_shape ~name ~shape =
+  let open Oproto.Onnx in
+  let dim =
+    List.map (Nir.Shape.to_list shape) ~f:(fun i ->
+      TensorShapeProto.Dimension.make ~value:(`Dim_value (Int64.of_int i)) ())
+  in
+  let shape = TensorShapeProto.make ~dim () in
+  let value =
+    `Tensor_type
+      (TypeProto.Tensor.make
+         ~elem_type:TensorProto.DataType.(to_int FLOAT)
+         ~shape ())
+  in
+  let type' = TypeProto.make ~value () in
+  let value_info = ValueInfoProto.make ~name ~type' () in
+  value_info
+
+let convert_into_tensor ?name (t : Nir.Gentensor.t) =
+  let mk data_type =
+    Oproto.Onnx.TensorProto.make
+      ~data_type:Oproto.Onnx.TensorProto.DataType.(to_int data_type)
+      ~dims:
+        (List.map ~f:Int64.of_int @@ Nir.Shape.to_list @@ Nir.Gentensor.shape t)
+      ?name
+  in
+  match t with
+  | Float t -> mk FLOAT ~float_data:(Nir.Tensor.flatten t) ()
+  | Int64 t -> mk INT64 ~int64_data:(Nir.Tensor.flatten t) ()
+
+let default_opset_import =
+  let open Oproto.Onnx in
+  let onnx_domain = "" in
+  OperatorSetIdProto.make ~domain:onnx_domain ~version:13L ()
+
+let nir_to_onnx_protoc (nir : Nir.Ngraph.t) =
+  let open Oproto.Onnx in
+  let get_name (v : Nir.Node.t) = Int.to_string v.id in
+  let protocs, (input, input_shape) =
+    let acc = Queue.create () in
+    let g_input = ref None in
+    let vertex_to_protoc (v : Nir.Node.t) =
+      let name = get_name v in
+      let input = List.map ~f:get_name (Nir.Node.preds v) in
+      let output = [ name ] in
+      let make op_type attribute =
+        Queue.enqueue acc
+          (Oproto.Onnx.NodeProto.make ~input ~output ~name ~op_type ~attribute
+             ~doc_string:"" ())
+      in
+      let mk_int name i =
+        AttributeProto.make ~name ~type':INT ~i:(Int64.of_int i) ()
+      in
+      let mk_ints name ints =
+        AttributeProto.make ~name ~type':INTS
+          ~ints:(List.map ~f:Int64.of_int ints)
+          ()
+      in
+      let mk_float name f = AttributeProto.make ~name ~type':FLOAT ~f () in
+      let mk_tensor name t = AttributeProto.make ~name ~type':TENSOR ~t () in
+      match v.descr with
+      | Gemm _ | LogSoftmax | Transpose _ | Squeeze _ | MaxPool | Conv
+      | Identity _ | RW_Linearized_ReLu | GatherND _ | ReduceSum _ ->
+        Caisar_logging.Logging.not_implemented_yet (fun m ->
+          m "Operator %a not implemented yet." Nir.Node.pp_descr v.descr)
+      | Reshape _ -> make "Reshape" []
+      | Flatten { axis; _ } -> make "Flatten" [ mk_int "axis" axis ]
+      | Constant { data } ->
+        let data = convert_into_tensor data in
+        make "Constant" [ mk_tensor "value" data ]
+      | Add _ -> make "Add" []
+      | Sub _ -> make "Sub" []
+      | Mul _ -> make "Mul" []
+      | Div _ -> make "Div" []
+      | Matmul _ -> make "MatMul" []
+      | ReLu _ -> make "Relu" []
+      | Input { shape } -> g_input := Some (v, shape)
+      | Concat { axis; _ } -> make "Concat" [ mk_int "axis" axis ]
+      | Gather { axis; _ } -> make "Gather" [ mk_int "axis" axis ]
+      | Abs _ -> make "Abs" []
+      | Log _ -> make "Log" []
+      | RandomNormal { dtype; mean; scale; seed; shape } ->
+        make "RandomNormal"
+          [
+            mk_int "dtype" dtype;
+            mk_float "mean" mean;
+            mk_float "scale" scale;
+            mk_float "seed" seed;
+            mk_ints "shape" (Array.to_list shape);
+          ]
+    in
+    Nir.Ngraph.iter_vertex vertex_to_protoc nir;
+    (Queue.to_list acc, Option.value_exn !g_input)
+  in
+  let docstr =
+    "This ONNX model was generated from the Neural Intermediate Representation \
+     of CAISAR"
+  in
+  let input =
+    [ value_info_from_tensor_shape ~name:(get_name input) ~shape:input_shape ]
+  in
+  let output =
+    let output = Nir.Ngraph.output nir in
+    [
+      value_info_from_tensor_shape ~name:(get_name output)
+        ~shape:(Nir.Node.compute_shape output);
+    ]
+  in
+  let protog =
+    GraphProto.make ~name:"ONNX CAISAR Export" ~node:protocs ~initializer':[]
+      ~sparse_initializer:[] ~doc_string:"ONNX graph generated from CAISAR NIR"
+      ~input ~output ~value_info:[] ~quantization_annotation:[] ()
+  in
+  let protom =
+    Oproto.Onnx.ModelProto.make ~ir_version:8L
+      ~opset_import:[ default_opset_import ] ~producer_name:"CAISAR"
+      ~producer_version:"1.0" ~domain:"" ~model_version:(-1L) ~doc_string:docstr
+      ~graph:protog ~metadata_props:[] ~training_info:[] ~functions:[] ()
+  in
+  let writer = Oprotom.to_proto protom in
+  Ocaml_protoc_plugin.Writer.contents writer
+
+let write_to_onnx nir out_channel =
+  let onnx = nir_to_onnx_protoc nir in
+  Stdio.Out_channel.output_string out_channel onnx
+
+let to_file nir filename =
+  let out_chan = Stdlib.open_out filename in
+  write_to_onnx nir out_chan;
+  Stdlib.close_out out_chan
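Together with the reader, this module provides the round-trip exercised by
lib/onnx/tests/print.ml; a minimal driver under the same assumptions (file
names are placeholders):

  let () =
    match Onnx.Reader.from_file "in.onnx" with
    | Error msg | Ok { nir = Error msg; _ } -> prerr_endline msg
    | Ok { nir = Ok g; _ } -> Onnx.Writer.to_file g "out.onnx"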
diff --git a/lib/onnx/writer.mli b/lib/onnx/writer.mli
new file mode 100644
index 0000000000000000000000000000000000000000..56945117a968f34ac5998025664d9f31375d5fa8
--- /dev/null
+++ b/lib/onnx/writer.mli
@@ -0,0 +1,24 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2023                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+
+val to_file : Nir.Ngraph.t -> string -> unit
+(** Write a NIR into an ONNX file. *)
diff --git a/src/dune b/src/dune
index ab868447a863fc15039b8c8b72eae6e0bbefb0cc..3afbd8c904c6c573cbaf7bb2378a4875921f17cb 100644
--- a/src/dune
+++ b/src/dune
@@ -24,6 +24,7 @@
   fpath
   zarith
   caisar_logging
+  caisar.nir
   caisar.xgboost)
  (preprocess
   (pps
diff --git a/src/language.ml b/src/language.ml
index 2e9728c2cbb055985e5ee01bccbf0ced7537c386..960216d13db020663c574df469d5a83fcb12522e 100644
--- a/src/language.ml
+++ b/src/language.ml
@@ -32,7 +32,7 @@ type nn_shape = {
   nb_outputs : int;
   ty_data : Ty.ty;
   filename : string;
-  nier : Ir.Nier_simple.t option;
+  nir : Nir.Ngraph.t option;
 }
 
 type svm_shape = {
@@ -46,7 +46,7 @@ let loaded_svms = Term.Hls.create 10
 let lookup_loaded_nets = Term.Hls.find_opt loaded_nets
 let lookup_loaded_svms = Term.Hls.find_opt loaded_svms
 
-let register_nn_as_tuple env nb_inputs nb_outputs filename ?nier mstr =
+let register_nn_as_tuple env nb_inputs nb_outputs filename ?nir mstr =
   let name = "AsTuple" in
   let th_uc = Pmodule.create_module env (Ident.id_fresh name) in
   let nn = Pmodule.read_module env [ "caisar"; "caisar" ] "NN" in
@@ -62,14 +62,14 @@ let register_nn_as_tuple env nb_inputs nb_outputs filename ?nier mstr =
       (Ty.ty_tuple (List.init nb_outputs ~f))
   in
   Term.Hls.add loaded_nets ls_nn_apply
-    { filename; nb_inputs; nb_outputs; ty_data; nier };
+    { filename; nb_inputs; nb_outputs; ty_data; nir };
   let th_uc =
     Pmodule.add_pdecl ~vc:false th_uc
       (Pdecl.create_pure_decl (Decl.create_param_decl ls_nn_apply))
   in
   Wstdlib.Mstr.add name (Pmodule.close_module th_uc) mstr
 
-let register_nn_as_array env nb_inputs nb_outputs filename ?nier mstr =
+let register_nn_as_array env nb_inputs nb_outputs filename ?nir mstr =
   let name = "AsArray" in
   let th_uc = Pmodule.create_module env (Ident.id_fresh name) in
   let nn =
@@ -81,7 +81,7 @@ let register_nn_as_array env nb_inputs nb_outputs filename ?nier mstr =
   in
   let ls_model = Term.create_fsymbol (Ident.id_fresh "model") [] ty_data in
   Term.Hls.add loaded_nets ls_model
-    { filename; nb_inputs; nb_outputs; ty_data; nier };
+    { filename; nb_inputs; nb_outputs; ty_data; nir };
   let th_uc =
     Pmodule.add_pdecl ~vc:false th_uc
       (Pdecl.create_pure_decl (Decl.create_param_decl ls_model))
@@ -122,21 +122,21 @@ let onnx_parser =
   Env.Wenv.memoize 13 (fun env ->
     let h = Hashtbl.create (module String) in
     Hashtbl.findi_or_add h ~default:(fun filename ->
-      let model = Onnx.Simple.parse filename in
+      let model = Onnx.Reader.from_file filename in
       match model with
       | Error s -> Loc.errorm "%s" s
-      | Ok { n_inputs; n_outputs; nier } ->
-        let nier =
-          match nier with
+      | Ok { n_inputs; n_outputs; nir } ->
+        let nir =
+          match nir with
           | Error msg ->
             Logs.warn (fun m ->
               m "Cannot build network intermediate representation:@ %s" msg);
             None
-          | Ok nier -> Some nier
+          | Ok nir -> Some nir
         in
         Wstdlib.Mstr.empty
-        |> register_nn_as_tuple env n_inputs n_outputs filename ?nier
-        |> register_nn_as_array env n_inputs n_outputs filename ?nier))
+        |> register_nn_as_tuple env n_inputs n_outputs filename ?nir
+        |> register_nn_as_array env n_inputs n_outputs filename ?nir))
 
 let ovo_parser =
   Env.Wenv.memoize 13 (fun env ->
@@ -205,7 +205,7 @@ type nn = {
 
 and nn_format =
   | NNet
-  | ONNX of Ir.Nier_simple.t option [@printer fun fmt _ -> Fmt.pf fmt "<nier>"]
+  | ONNX of Nir.Ngraph.t option [@printer fun fmt _ -> Fmt.pf fmt "<nir>"]
 [@@deriving show]
 
 let nets = Term.Hls.create 10
@@ -236,24 +236,24 @@ let create_nn_nnet env filename =
     }
 
 let create_nn_onnx env filename =
-  let model = Onnx.Simple.parse filename in
+  let model = Onnx.Reader.from_file filename in
   match model with
   | Error s -> Loc.errorm "%s" s
-  | Ok { n_inputs; n_outputs; nier } ->
-    let nier =
-      match nier with
+  | Ok { n_inputs; n_outputs; nir } ->
+    let nir =
+      match nir with
       | Error msg ->
         Logs.warn (fun m ->
           m "Cannot build network intermediate representation:@ %s" msg);
         None
-      | Ok nier -> Some nier
+      | Ok nir -> Some nir
     in
     {
       nn_nb_inputs = n_inputs;
       nn_nb_outputs = n_outputs;
       nn_ty_elt = ty_float64_t env;
       nn_filename = filename;
-      nn_format = ONNX nier;
+      nn_format = ONNX nir;
     }
 
 let create_nn =
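The onnx_parser above stacks two cache layers: Env.Wenv.memoize keys on the
Why3 environment, and the inner string table ensures each filename is read
once per environment (Hashtbl.findi_or_add h ~default:... is partially
applied, yielding a function that awaits the filename). A stripped-down
sketch of the inner layer alone, assuming open Base:

  let cached_parse =
    let h = Hashtbl.create (module String) in
    fun filename ->
      Hashtbl.findi_or_add h filename ~default:(fun filename ->
        Onnx.Reader.from_file filename)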
diff --git a/src/language.mli b/src/language.mli
index 72be9f3f461a984bfbd9a6351f335b03c69929ee..a2ea56e51e4564e9fedb7d537f7eedc59b2aeed6 100644
--- a/src/language.mli
+++ b/src/language.mli
@@ -27,7 +27,7 @@ type nn_shape = private {
   nb_outputs : int;
   ty_data : Ty.ty;
   filename : string;
-  nier : Ir.Nier_simple.t option;
+  nir : Nir.Ngraph.t option;
 }
 
 type svm_shape = private {
@@ -84,7 +84,7 @@ type nn = private {
 
 and nn_format =
   | NNet
-  | ONNX of Ir.Nier_simple.t option
+  | ONNX of Nir.Ngraph.t option
 [@@deriving show]
 
 val create_nn : Env.env -> [ `NNet | `ONNX ] -> string -> Term.lsymbol
diff --git a/src/main.ml b/src/main.ml
index d9b804129873fa285e5bd4ff8bc9ce7626a274b5..906e9712ce9d3c6faf8d1cf8fdd1b2db6a47f708 100644
--- a/src/main.ml
+++ b/src/main.ml
@@ -265,7 +265,7 @@ let verify_cmd =
     Arg.(value & opt (some file) None & info [ "dataset" ] ~doc ~docv:"FILE")
   in
   let onnx_out_dir =
-    let doc = "Write NIER as ONNX file in $(docv)." in
+    let doc = "Write NIR as an ONNX file in $(docv)." in
     Arg.(
       value
       & opt (some string) None
diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index 322bcf3074aebc4231ff3d2f53ad54466fc4b965..22a140e58791d436dedef6a2b50b5bd9977eef29 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -46,7 +46,7 @@ let tempfile =
     | None -> Stdlib.Filename.temp_file "caisar" ".onnx"
 
 let create_new_nn env input_vars outputs : string =
-  let module IR = Ir.Nier_simple in
+  let module IR = Nir in
   let th_f64 = Symbols.Float64.create env in
   let th_model = Symbols.Model.create env in
   let th_nn = Symbols.NN.create env in
@@ -58,7 +58,7 @@ let create_new_nn env input_vars outputs : string =
   let get_input =
     Why3.Term.Hls.memo 10 (fun ls ->
       let i = Why3.Term.Mls.find_exn UnknownLogicSymbol ls input_vars in
-      Ir.Nier_simple.Node.gather_int input i)
+      IR.Node.gather_int input i)
   in
   let cache = Why3.Term.Hterm.create 17 in
   let nn_cache = Stdlib.Hashtbl.create 17 in
@@ -68,7 +68,7 @@ let create_new_nn env input_vars outputs : string =
     let converted_args = List.map ~f:convert_term old_nn_args in
     let id =
       ( old_nn.Language.nn_filename,
-        List.map converted_args ~f:(fun n -> n.Ir.Nier_simple.id) )
+        List.map converted_args ~f:(fun n -> n.Nir.Node.id) )
     in
     match Stdlib.Hashtbl.find_opt nn_cache id with
     | None ->
@@ -77,15 +77,15 @@ let create_new_nn env input_vars outputs : string =
       node_nn
     | Some node_nn -> node_nn
   and convert_old_nn_aux old_nn converted_args =
-    let old_nn_nier =
-      match Onnx.Simple.parse old_nn.Language.nn_filename with
+    let old_nn_nir =
+      match Onnx.Reader.from_file old_nn.Language.nn_filename with
       | Error s ->
         Logging.code_error ~src (fun m ->
           m "No parsed NN '%s': %s" old_nn.nn_filename s)
-      | Ok { nier = Error s; _ } ->
+      | Ok { nir = Error s; _ } ->
         Logging.code_error ~src (fun m ->
-          m "No NIER available for NN '%s': %s" old_nn.nn_filename s)
-      | Ok { nier = Ok g; _ } -> g
+          m "No NIR available for NN '%s': %s" old_nn.nn_filename s)
+      | Ok { nir = Ok g; _ } -> g
     in
     (* Create the graph to replace the old input of the nn *)
     let input () =
@@ -93,18 +93,17 @@ let create_new_nn env input_vars outputs : string =
       let node =
         IR.Node.create (Concat { inputs = converted_args; axis = 0 })
       in
-      Ir.Nier_simple.Node.reshape (IR.input_shape old_nn_nier) node
+      IR.Node.reshape (IR.Ngraph.input_shape old_nn_nir) node
     in
     let out =
-      IR.Node.replace_input input (IR.output old_nn_nier)
-      |> Ir.Nier_simple.Node.reshape
-           (Ir.Nier_simple.Shape.of_array [| old_nn.nn_nb_outputs |])
+      IR.Node.replace_input input (IR.Ngraph.output old_nn_nir)
+      |> IR.Node.reshape (Nir.Shape.of_array [| old_nn.nn_nb_outputs |])
     in
     out
   and convert_old_nn_at_old_index old_nn old_index old_nn_args =
     let out = convert_old_nn old_nn old_nn_args in
     let old_index = Why3.Number.to_small_integer old_index in
-    Ir.Nier_simple.Node.gather_int out old_index
+    Nir.Node.gather_int out old_index
   and convert_term term =
     match Why3.Term.Hterm.find_opt cache term with
     | None ->
@@ -112,7 +111,7 @@ let create_new_nn env input_vars outputs : string =
       Why3.Term.Hterm.add cache term n;
       n
     | Some n -> n
-  and convert_term_aux term : IR.GFloat.Node.t =
+  and convert_term_aux term : IR.Node.t =
     if not (Why3.Ty.ty_equal (Option.value_exn term.Why3.Term.t_ty) th_f64.ty)
     then
       Logging.user_error ?loc:term.t_loc (fun m ->
@@ -122,7 +121,7 @@ let create_new_nn env input_vars outputs : string =
       IR.Node.create
         (Constant
            {
-             data = IR.GenTensor.create_1_float (Utils.float_of_real_constant r);
+             data = IR.Gentensor.create_1_float (Utils.float_of_real_constant r);
            })
     | Tapp (ls, []) -> get_input ls
     | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.add ->
@@ -135,12 +134,12 @@ let create_new_nn env input_vars outputs : string =
       match b.t_node with
       | Tconst (Why3.Constant.ConstReal r) ->
         let f = Utils.float_of_real_constant r in
-        Ir.Nier_simple.Node.div_float (convert_term a) f
+        Nir.Node.div_float (convert_term a) f
       | _ ->
         IR.Node.create
           (Div { input1 = convert_term a; input2 = convert_term b }))
     | Tapp (ls, [ a ]) when Why3.Term.ls_equal ls th_f64.neg ->
-      Ir.Nier_simple.Node.mul_float (convert_term a) (-1.)
+      Nir.Node.mul_float (convert_term a) (-1.)
     | Tapp
         ( ls_get (* [ ] *),
           [
@@ -195,11 +194,11 @@ let create_new_nn env input_vars outputs : string =
   let output = IR.Node.concat_0 outputs in
   assert (
     IR.Shape.equal output.shape (IR.Shape.of_array [| List.length outputs |]));
-  let nn = IR.create output in
-  Logs.debug ~src:src_show (fun f -> f "@.%s@." (IR.grapheasy nn));
-  Logs.debug ~src (fun f -> f "@[<v>%a@]@." IR.pp_debug nn);
+  let nn = IR.Ngraph.create output in
+  Logs.debug ~src:src_show (fun f -> f "@.%s@." (IR.Ngraph.grapheasy nn));
+  Logs.debug ~src (fun f -> f "@[<v>%a@]@." IR.Ngraph.pp_debug nn);
   let filename = tempfile () in
-  Onnx.Simple.write nn filename;
+  Onnx.Writer.to_file nn filename;
   filename
 
 (* Choose the term pattern for starting the conversion to ONNX. *)
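The nn_cache above implements structural sharing: applying the same network
twice to the same converted arguments must reuse one subgraph, so the key
pairs the model filename with the ids of the argument nodes. A reduced
sketch of the scheme, assuming open Base (convert_cached is a hypothetical
name and build stands for convert_old_nn_aux):

  let nn_cache : (string * int list, Nir.Node.t) Stdlib.Hashtbl.t =
    Stdlib.Hashtbl.create 17

  let convert_cached ~filename ~args ~build =
    let key = (filename, List.map args ~f:(fun (n : Nir.Node.t) -> n.id)) in
    match Stdlib.Hashtbl.find_opt nn_cache key with
    | Some node -> node
    | None ->
      let node : Nir.Node.t = build () in
      Stdlib.Hashtbl.add nn_cache key node;
      node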
diff --git a/src/transformations/nn2smt.ml b/src/transformations/nn2smt.ml
index 56f24a1324b4d7077d5de3e414b7d134f98959f1..03cba10ecf8a45c67332315556c3a6435d1e9626 100644
--- a/src/transformations/nn2smt.ml
+++ b/src/transformations/nn2smt.ml
@@ -21,7 +21,7 @@
 (**************************************************************************)
 
 open Base
-module IR = Ir.Nier_simple
+module IR = Nir.Ngraph
 
 let src =
   Logs.Src.create "NN2SMT" ~doc:"Encoding of neural networks into SMT-LIB"
@@ -49,13 +49,13 @@ let relu env expr =
   in
   Why3.Term.t_app_infer relu_s [ expr ]
 
-let rec terms_of_nier m d node index ty_inputs env input_vars input_terms =
+let rec terms_of_nir m d node index ty_inputs env input_vars input_terms =
   let mi = Option.value ~default:(Map.empty (module Int)) @@ Map.find !m node in
   match Map.find mi index with
   | Some ls -> ls
   | None ->
     let ls, decl =
-      terms_of_nier_aux m d node index ty_inputs env input_vars input_terms
+      terms_of_nir_aux m d node index ty_inputs env input_vars input_terms
     in
     Queue.enqueue d (Why3.Theory.create_decl decl);
     Queue.enqueue d
@@ -64,14 +64,15 @@ let rec terms_of_nier m d node index ty_inputs env input_vars input_terms =
     m := Map.set !m ~key:node ~data:mi;
     ls
 
-and t_app_of_nier m d node index ty_inputs env input_vars input_terms =
-  let ls = terms_of_nier m d node index ty_inputs env input_vars input_terms in
+and t_app_of_nir m d node index ty_inputs env input_vars input_terms =
+  let ls = terms_of_nir m d node index ty_inputs env input_vars input_terms in
   Why3.Term.fs_app ls
     (if List.is_empty input_vars then [] else input_terms)
     ty_inputs
 
-and terms_of_nier_aux m d node index ty_inputs env input_vars input_terms =
-  let preid = Why3.Ident.id_fresh (Fmt.str "n%i_%i" node.IR.id index) in
+and terms_of_nir_aux m d (node : Nir.Node.t) index ty_inputs env input_vars
+  input_terms =
+  let preid = Why3.Ident.id_fresh (Fmt.str "n%i_%i" node.id index) in
   let ls =
     Why3.Term.create_fsymbol preid
       (List.map input_vars ~f:(fun _ -> ty_inputs))
@@ -82,13 +83,13 @@ and terms_of_nier_aux m d node index ty_inputs env input_vars input_terms =
        Paxiom ps t *)
     match node.descr with
     | Constant { data = Float ba } ->
-      let id = IR.Shape.unrow_major node.shape index in
-      Utils.term_of_float env (IR.Tensor.get ba id)
+      let id = Nir.Shape.unrow_major node.shape index in
+      Utils.term_of_float env (Nir.Tensor.get ba id)
     | Matmul { input1; input2 } ->
       (* [|... ; _; n |], [| ...; n; _ |] *)
-      let id = IR.Shape.unrow_major node.shape index in
+      let id = Nir.Shape.unrow_major node.shape index in
       let broadcast shape id_dst =
-        let len = IR.Shape.rank shape in
+        let len = Nir.Shape.rank shape in
         let id_src = Array.create ~len 0 in
         for i = 0 to len - 3 do
           id_src.(i) <- id_dst.(i)
@@ -101,18 +102,18 @@ and terms_of_nier_aux m d node index ty_inputs env input_vars input_terms =
       id2.(Array.length id2 - 1) <- id.(Array.length id - 1);
       let acc =
         Sequence.init
-          (IR.Shape.get input1.shape (IR.Shape.rank input1.shape - 1))
+          (Nir.Shape.get input1.shape (Nir.Shape.rank input1.shape - 1))
           ~f:(fun i ->
             (* id1 = [...; _; i ] id2 = [...; i; _ ] *)
             id1.(Array.length id1 - 1) <- i;
             id2.(Array.length id2 - 2) <- i;
-            let i1 = IR.Shape.row_major input1.shape id1 in
-            let i2 = IR.Shape.row_major input2.shape id2 in
+            let i1 = Nir.Shape.row_major input1.shape id1 in
+            let i2 = Nir.Shape.row_major input2.shape id2 in
             let a1 =
-              t_app_of_nier m d input1 i1 ty_inputs env input_vars input_terms
+              t_app_of_nir m d input1 i1 ty_inputs env input_vars input_terms
             in
             let a2 =
-              t_app_of_nier m d input2 i2 ty_inputs env input_vars input_terms
+              t_app_of_nir m d input2 i2 ty_inputs env input_vars input_terms
             in
             mul a1 a2 env)
       in
@@ -121,29 +122,29 @@ and terms_of_nier_aux m d node index ty_inputs env input_vars input_terms =
     | Input _ -> List.nth_exn input_terms index
     | Add { input1; input2 } ->
       let t1 =
-        t_app_of_nier m d input1 index ty_inputs env input_vars input_terms
+        t_app_of_nir m d input1 index ty_inputs env input_vars input_terms
       in
       let t2 =
-        t_app_of_nier m d input2 index ty_inputs env input_vars input_terms
+        t_app_of_nir m d input2 index ty_inputs env input_vars input_terms
       in
       sum t1 t2 env
     | Div { input1; input2 } ->
       let t1 =
-        t_app_of_nier m d input1 index ty_inputs env input_vars input_terms
+        t_app_of_nir m d input1 index ty_inputs env input_vars input_terms
       in
       let t2 =
-        t_app_of_nier m d input2 index ty_inputs env input_vars input_terms
+        t_app_of_nir m d input2 index ty_inputs env input_vars input_terms
       in
       div t1 t2 env
     | ReLu { input } ->
       let t =
-        t_app_of_nier m d input index ty_inputs env input_vars input_terms
+        t_app_of_nir m d input index ty_inputs env input_vars input_terms
       in
       relu env t
     | _ ->
       Logging.not_implemented_yet (fun f ->
         f "ONNX operator %a is not implemented for nn2smt transformation"
-          IR.Node.pp node)
+          Nir.Node.pp node)
   in
   ( ls,
     Why3.Decl.create_logic_decl [ Why3.Decl.make_ls_defn ls input_vars v_term ]
@@ -159,20 +160,18 @@ module MTermL = Why3.Extmap.Make (struct
   type t = T.t list [@@deriving ord]
 end)
 
-let app_terms_of_nier_output m d (nn : Language.nn) env index tl =
+let app_terms_of_nir_output m d (nn : Language.nn) env index tl =
   match nn.nn_format with
   | NNet -> Logging.not_implemented_yet (fun f -> f "NNet to SMT conversion")
   | ONNX None -> Logging.code_error ~src (fun f -> f "No ONNX to convert")
   | ONNX (Some g) ->
     let vtl = List.fold tl ~init:Why3.Term.Mvs.empty ~f:Why3.Term.t_freevars in
-    let m' = ref (MTermL.find_def (Map.empty (module IR.Node)) tl !m) in
+    let m' = ref (MTermL.find_def (Map.empty (module Nir.Node)) tl !m) in
     let t =
       if Why3.Term.Mvs.is_empty vtl
       then
         (* use global constant *)
-        let ls =
-          terms_of_nier m' d (IR.output g) index nn.nn_ty_elt env [] tl
-        in
+        let ls = terms_of_nir m' d (IR.output g) index nn.nn_ty_elt env [] tl in
         Why3.Term.fs_app ls [] nn.nn_ty_elt
       else
         (* use global functions *)
@@ -186,7 +185,7 @@ let app_terms_of_nier_output m d (nn : Language.nn) env index tl =
         in
         let input_terms = List.map input_vars ~f:Why3.Term.t_var in
         let ls =
-          terms_of_nier m' d (IR.output g) index nn.nn_ty_elt env input_vars
+          terms_of_nir m' d (IR.output g) index nn.nn_ty_elt env input_vars
             input_terms
         in
         Why3.Term.fs_app ls tl nn.nn_ty_elt
@@ -224,7 +223,7 @@ let actual_nn_flow env =
       match (Language.lookup_nn ls_nn, Language.lookup_vector ls) with
       | Some ({ nn_nb_inputs; _ } as nn), Some vector_length ->
         assert (nn_nb_inputs = vector_length && vector_length = List.length tl);
-        app_terms_of_nier_output m d nn env
+        app_terms_of_nir_output m d nn env
           (Why3.Number.to_small_integer index)
           tl
       | _, _ ->
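
Aside: app_terms_of_nir_output memoizes one generated logic symbol per list
of argument terms (the MTermL map above), so applying the same network to the
same arguments at several output indices reuses a single declaration. Below is
a rough sketch of that caching pattern under simplified, hypothetical types
(string keys and int payloads instead of term lists and Why3 symbols):

open Base

let () =
  let cache = Hashtbl.create (module String) in
  let fresh = ref 0 in
  let generate () = fresh := !fresh + 1; !fresh in
  (* Create the symbol on first request, reuse it afterwards. *)
  let lookup key = Hashtbl.find_or_add cache key ~default:generate in
  let a = lookup "nn_onnx @@ [x0; x1]" in
  let b = lookup "nn_onnx @@ [x0; x1]" in
  assert (a = b)
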
diff --git a/src/transformations/nn2smt.mli b/src/transformations/nn2smt.mli
index 6f92b5a3a81702912bf462acecc37ece5385ebe3..0cb0eb9b33800dc7fbb3ffd5ff858c752b60dfa6 100644
--- a/src/transformations/nn2smt.mli
+++ b/src/transformations/nn2smt.mli
@@ -20,14 +20,14 @@
 (*                                                                        *)
 (**************************************************************************)
 
-(** This module converts a valid NIER into WhyML terms.
+(** This module converts a valid NIR into WhyML terms.
 
-    NIER encapsulate parameters in tensor forms and a computation graph with
+    NIR encapsulates parameters in tensor form and a computation graph with
     various operations. WhyML supports multidimensional arrays as well.
 
-    This module translates NIER data into a list of WhyML terms, describing the
+    This module translates NIR data into a list of WhyML terms, describing the
     control flow of the neural network. Variables are stored in an
-    environment, their shape being either provided by the NIER or inferred with
+    environment, their shape being either provided by the NIR or inferred from
     the expected result of ONNX operations. *)
 
 val trans : Why3.Env.env -> Why3.Task.task Why3.Trans.trans
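
Aside: a toy illustration of the translation described in the doc-comment
above. Each NIR node unrolls, per scalar output index, into a term over the
network inputs, mirroring the Input/Constant/Add/ReLu cases of nn2smt.ml; the
term and node types below are simplified stand-ins, not the Nir or Why3
representations.

(* Simplified stand-in types; real terms are Why3 terms over
   ieee_float, and real nodes carry shapes and tensors. *)
type term =
  | Var of int              (* the i-th input of the network *)
  | Const of float
  | Sum of term * term
  | Mul of term * term
  | Relu of term

type node =
  | Input
  | Constant of float array
  | Add of node * node
  | ReLu of node

(* One term per scalar output index, built recursively over the graph. *)
let rec term_of_node (n : node) (index : int) : term =
  match n with
  | Input -> Var index
  | Constant data -> Const data.(index)
  | Add (n1, n2) -> Sum (term_of_node n1 index, term_of_node n2 index)
  | ReLu n -> Relu (term_of_node n index)
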
diff --git a/src/verification.ml b/src/verification.ml
index e63540bffa5fbf6bcdb2ad4b413fcfb0720c1589..96bd03005badf2f6d4825dbd5202534abdb546e8 100644
--- a/src/verification.ml
+++ b/src/verification.ml
@@ -90,27 +90,27 @@ let create_env loadpath =
       (loadpath @ stdlib @ Whyconf.loadpath (Whyconf.get_main config)),
     config )
 
-let write_nier_as_onnx onnx_out_dir =
+let write_nir_as_onnx onnx_out_dir =
   Language.iter_nn (fun ls nn ->
     match nn.nn_format with
-    | ONNX (Some nn_nier) -> (
+    | ONNX (Some nn_nir) -> (
       try
         if not (Stdlib.Sys.file_exists onnx_out_dir)
         then Stdlib.Sys.mkdir onnx_out_dir 0o755;
         let filename =
-          Fmt.str "%s%s%a.nier.onnx" onnx_out_dir Stdlib.Filename.dir_sep
+          Fmt.str "%s%s%a.nir.onnx" onnx_out_dir Stdlib.Filename.dir_sep
             Pretty.print_ls ls
         in
-        Onnx.Simple.write nn_nier filename;
-        Logs.debug ~src:Logging.src_nier (fun m ->
-          m "@[Wrote NIER as ONNX in file '%s'@]" filename)
+        Onnx.Writer.to_file nn_nir filename;
+        Logs.debug ~src:Logging.src_nir (fun m ->
+          m "@[Wrote NIR as ONNX in file '%s'@]" filename)
       with Sys_error msg ->
         Logging.user_error (fun m ->
-          m "@[Cannot write NIER as ONNX in folder '%s': '%s'@]" onnx_out_dir
-            msg))
+          m "@[Cannot write NIR as ONNX in folder '%s': '%s'@]" onnx_out_dir msg)
+      )
     | _ ->
       Logs.warn (fun m ->
-        m "@[No available NIER to write as ONNX for logic symbol '%a'@]"
+        m "@[No available NIR to write as ONNX for logic symbol '%a'@]"
           Pretty.print_ls ls))
 
 let answer_saver_on_dataset limit config env config_prover ~dataset task =
@@ -488,5 +488,5 @@ let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern
           tasks)
       mstr_theory
   in
-  Option.iter ~f:write_nier_as_onnx onnx_out_dir;
+  Option.iter ~f:write_nir_as_onnx onnx_out_dir;
   verification_result
diff --git a/src/verification.mli b/src/verification.mli
index 935f619fade3bd23f0cf2704bee4150fc0b61551..71d1b7839351e3a3c4c76a7f1be0d78482f87e89 100644
--- a/src/verification.mli
+++ b/src/verification.mli
@@ -76,7 +76,7 @@ val verify :
       is a theory:goals list each identifying only the goals of a theory to
       verify.
     @param onnx_out_dir
-      is the directory in which to write the ONNX files generated from the NIER.
+      is the directory in which to write the ONNX files generated from the NIR.
     @return
       for each theory, an [answer] for each goal of the theory, respecting the
       order of appearance. *)
diff --git a/tests/acasxu.t b/tests/acasxu.t
index 0d19a6594cbffbba1f011a43458c1e182c8ba949..06b8c65a4b99e2268ee7635724e3034f50037634 100644
--- a/tests/acasxu.t
+++ b/tests/acasxu.t
@@ -205,7 +205,7 @@ Test verify on acasxu
                                 nn_ty_elt = t;
                                 nn_filename =
                                 "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx";
-                                nn_format = <nier> }))
+                                nn_format = <nir> }))
   vector,
   5
   [DEBUG]{NN-Gen-Term} new goal:le y (1500.0:t)
@@ -269,14 +269,14 @@ Test verify on acasxu
                                 nn_ty_elt = t;
                                 nn_filename =
                                 "./../examples/acasxu/nets/onnx/ACASXU_1_9.onnx";
-                                nn_format = <nier> }))
+                                nn_format = <nir> }))
   nn_onnx,
   (Interpreter_types.Model
      (Interpreter_types.ONNX, { Language.nn_nb_inputs = 5; nn_nb_outputs = 5;
                                 nn_ty_elt = t;
                                 nn_filename =
                                 "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx";
-                                nn_format = <nier> }))
+                                nn_format = <nir> }))
   vector,
   5
   [DEBUG]{NN-Gen-Term} new goal:le y1 (1500.0:t) /\ le y2 (1500.0:t)
@@ -350,7 +350,7 @@ Test verify on acasxu
                                 nn_ty_elt = t;
                                 nn_filename =
                                 "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx";
-                                nn_format = <nier> }))
+                                nn_format = <nir> }))
   vector,
   5
   [DEBUG]{NN-Gen-Term} new goal:le y3 y4 /\ le y4 y3
@@ -426,7 +426,7 @@ Test verify on acasxu
                                 nn_ty_elt = t;
                                 nn_filename =
                                 "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx";
-                                nn_format = <nier> }))
+                                nn_format = <nir> }))
   vector,
   5
   [DEBUG]{NN-Gen-Term} new goal:le y5 y6 /\ le y6 y5
@@ -504,7 +504,7 @@ Test verify on acasxu
                                 nn_ty_elt = t;
                                 nn_filename =
                                 "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx";
-                                nn_format = <nier> }))
+                                nn_format = <nir> }))
   vector,
   5
   [DEBUG]{NN-Gen-Term} new goal:le y7 y8 /\ le y8 y7
@@ -586,7 +586,7 @@ Test verify on acasxu
                                 nn_ty_elt = t;
                                 nn_filename =
                                 "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx";
-                                nn_format = <nier> }))
+                                nn_format = <nir> }))
   vector,
   5
   [DEBUG]{NN-Gen-Term} new goal:le (0.0:t) y9 /\ le y9 (0.0:t)
diff --git a/tests/acasxu_ci.t b/tests/acasxu_ci.t
index c55a2f7781913e9a97a0e8cc1562040333181606..b9b6cb3c669275ed2e370bf90f2a2fa7fe903e62 100644
--- a/tests/acasxu_ci.t
+++ b/tests/acasxu_ci.t
@@ -166,7 +166,7 @@ Test verify on acasxu
                                 nn_ty_elt = t;
                                 nn_filename =
                                 "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx";
-                                nn_format = <nier> }))
+                                nn_format = <nir> }))
   vector,
   5
   [DEBUG]{ProverSpec} Prover-tailored specification:
@@ -229,7 +229,7 @@ Test verify on acasxu
                                 nn_ty_elt = t;
                                 nn_filename =
                                 "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx";
-                                nn_format = <nier> }))
+                                nn_format = <nir> }))
   vector, 5
   nn_onnx1,
   (Interpreter_types.Model
@@ -237,7 +237,7 @@ Test verify on acasxu
                                 nn_ty_elt = t;
                                 nn_filename =
                                 "./../examples/acasxu/nets/onnx/ACASXU_1_9.onnx";
-                                nn_format = <nier> }))
+                                nn_format = <nir> }))
   [DEBUG]{ProverSpec} Prover-tailored specification:
   0.0 <= x0
   x0 <= 60760.0
@@ -308,7 +308,7 @@ Test verify on acasxu
                                 nn_ty_elt = t;
                                 nn_filename =
                                 "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx";
-                                nn_format = <nier> }))
+                                nn_format = <nir> }))
   vector,
   5
   [DEBUG]{ProverSpec} Prover-tailored specification:
@@ -382,7 +382,7 @@ Test verify on acasxu
                                 nn_ty_elt = t;
                                 nn_filename =
                                 "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx";
-                                nn_format = <nier> }))
+                                nn_format = <nir> }))
   vector,
   5
   [DEBUG]{ProverSpec} Prover-tailored specification:
@@ -498,7 +498,7 @@ Test verify on acasxu
                                 nn_ty_elt = t;
                                 nn_filename =
                                 "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx";
-                                nn_format = <nier> }))
+                                nn_format = <nir> }))
   vector,
   5
   [DEBUG]{ProverSpec} Prover-tailored specification:
diff --git a/tests/nier_to_onnx.t b/tests/nier_to_onnx.t
index d77897c5db91be6e61a94e3b10273908ed802116..9bbd9c34302db2729aa58e6cd47c225b3452ef42 100644
--- a/tests/nier_to_onnx.t
+++ b/tests/nier_to_onnx.t
@@ -4,8 +4,8 @@ Test verify
   $ bin/pyrat.py --version
   PyRAT 1.1
 
-  $ caisar verify --format whyml --prover=PyRAT --ltag=NIER --onnx-out-dir="out" - 2>&1 <<EOF
-  > theory NIER_to_ONNX
+  $ caisar verify --format whyml --prover=PyRAT --ltag=NIR --onnx-out-dir="out" - 2>&1 <<EOF
+  > theory NIR_to_ONNX
   >   use ieee_float.Float64
   >   use caisar.types.Vector
   >   use caisar.model.Model
@@ -20,12 +20,12 @@ Test verify
   >       (0.5:t) .< (nn @@ i)[0] .< (0.5:t)
   > end
   > EOF
-  [DEBUG]{NIER} Wrote NIER as ONNX in file 'out/nn_onnx.nier.onnx'
+  [DEBUG]{NIR} Wrote NIR as ONNX in file 'out/nn_onnx.nir.onnx'
   Goal G: Unknown ()
 
 Data should be 0.135
 
   $ python3 bin/inspect_onnx.py
-  out/nn_onnx.nier.onnx has 1 input nodes
+  out/nn_onnx.nir.onnx has 1 input nodes
   {'name': '0', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '1'}, {'dimValue': '1'}, {'dimValue': '1'}, {'dimValue': '3'}]}}}}
   1 files checked
diff --git a/utils/logging.ml b/utils/logging.ml
index afe6c3d5132e8bebfc474208afec04f9ed67bbc6..c3f129f99bef70369aa905a3a3d200b2c3d9b00e 100644
--- a/utils/logging.ml
+++ b/utils/logging.ml
@@ -30,7 +30,7 @@ let src_prover_call = Logs.Src.create "ProverCall" ~doc:"Prover call"
 let src_interpret_goal =
   Logs.Src.create "InterpretGoal" ~doc:"Goal interpretation"
 
-let src_nier = Logs.Src.create "NIER" ~doc:"Neural Intermediate Representation"
+let src_nir = Logs.Src.create "NIR" ~doc:"Neural Intermediate Representation"
 let src_stack_trace = Logs.Src.create "StackTrace" ~doc:"Print stack trace"
 let all_srcs () = Logs.Src.list ()
 
diff --git a/utils/logging.mli b/utils/logging.mli
index 6410b1a5911380629712fbc874a397d3808b97f6..d05c477a90ee8739fc6d08a980bd7a06b68feb3d 100644
--- a/utils/logging.mli
+++ b/utils/logging.mli
@@ -26,7 +26,7 @@ val src_autodetect : Logs.src
 val src_prover_spec : Logs.src
 val src_prover_call : Logs.src
 val src_interpret_goal : Logs.src
-val src_nier : Logs.src
+val src_nir : Logs.src
 val src_stack_trace : Logs.src
 
 val all_srcs : unit -> Logs.src list