diff --git a/lib/ir/nier_cfg.ml b/lib/ir/nier_cfg.ml deleted file mode 100644 index ad7f5bd5f84df85af6356fcbcd3c28cac917bca2..0000000000000000000000000000000000000000 --- a/lib/ir/nier_cfg.ml +++ /dev/null @@ -1,515 +0,0 @@ -open Base -open Stdio -open Bigarray - -module Tensor = struct - type ('a, 'b) t = ('a, 'b, c_layout) Genarray.t - type shape = int array [@@deriving show] - - type ('a, 'b) t_kind = - | K_int : (int64, int64_elt) t_kind - | K_float : (float, float64_elt) t_kind - - let create : type a b. shape -> (a, b) t_kind -> (a, b) t = - fun shape -> function - | K_float -> Genarray.create float64 c_layout shape - | K_int -> Genarray.create int64 c_layout shape - - let unsqueeze ~sh1 ~sh2 = - let sh1, sh2 = (Array.to_list sh1, Array.to_list sh2) in - let longest, shortest = - match List.length sh1 > List.length sh2 with - | true -> (sh1, sh2) - | false -> (sh2, sh1) - in - (*find the index of the potential additional dimension*) - let where_zero = - match List.nth_exn longest 0 with - | 0 -> Some 0 - | _ -> ( - match List.last_exn longest with - | 0 -> Some (List.length longest - 1) - | _ -> None) - in - match where_zero with - | Some idx -> ( - match List.sub longest ~pos:idx ~len:(List.length shortest) with - | [] -> None - | _ -> Some (Array.of_list longest)) - | None -> None - - let get t idx = Genarray.get t idx - let set t idx v = Genarray.set t idx v - - let all_coords sh = - let sh = Array.to_list sh in - let rec ranges acc shape = - match shape with - | x :: y -> ranges (List.init x ~f:(fun i -> i) :: acc) y - | [] -> acc - (* a list containing a list of all possible indexes, for each dimension *) - in - let xxs = ranges [] sh in - (* add to each element of the list of all possible coordinates all*) - (* * possible indexes ... *) - let aux acc xs = - List.concat - @@ List.map xs ~f:(fun x -> List.map ~f:(fun lt -> x :: lt) acc) - (* ... for each dimension, starting from an empty list of*) - (* * possible coordinates *) - in - List.fold xxs ~init:[ [] ] ~f:aux - - let flatten t = - let shape = Genarray.dims t in - let all_idxs = all_coords shape in - List.init (List.length all_idxs) ~f:(fun i -> - get t (Array.of_list @@ List.nth_exn all_idxs i)) - - let get_shape t = Genarray.dims t - - let equal f t1 t2 = - let t1_sh = get_shape t1 and t2_sh = get_shape t2 in - if Array.equal ( = ) t1_sh t2_sh - then - let all_idxs = all_coords (get_shape t1) in - List.fold - ~f:(fun acc x -> - if acc - then f (get t1 (Array.of_list x)) (get t2 (Array.of_list x)) - else false) - all_idxs ~init:true - else false - - let num_neurons sh = Array.fold ~init:1 ~f:(fun x y -> x * y) sh - - let get_flatnd_idx ~idx ~sh flt = - let sh = Array.to_list sh in - let pop_sh = List.tl_exn sh @ [ 1 ] in - let rec get_factors_from_sh sh_f l = - match sh_f with - | [] -> List.rev l - | _ -> - get_factors_from_sh (List.tl_exn sh_f) - (List.fold ~f:(fun x y -> x * y) ~init:1 sh_f :: l) - in - let factors = get_factors_from_sh pop_sh [] in - let coord_in_data = - match - List.fold2 - ~f:(fun x y z -> x + (y * z)) - ~init:0 (Array.to_list idx) factors - with - | List.Or_unequal_lengths.Ok i -> i - | List.Or_unequal_lengths.Unequal_lengths -> - failwith "Unequal lengths in get_flatnd_idx" - in - List.nth_exn flt coord_in_data - - let transpose_2d _t = assert false -end - -(* TODO: maybe add markers for special nodes, to reflect they are the inputs and - outputs of the neural network? 
*) -module Node = struct - type shape = int array - - let show_shape sh = - let sh = Array.to_list sh in - match sh with - | [] -> "[]" - | x :: y -> - "[" ^ Int.to_string x - ^ List.fold ~init:"" ~f:(fun str e -> str ^ ";" ^ Int.to_string e) y - ^ "]" - - type operator = - | Add - | Sub - | Mul - | Div - | Matmul - | Gemm - | LogSoftmax - | ReLu - | Transpose - | Squeeze - | MaxPool - | Conv - | Reshape - | Flatten - | Identity - | Constant - | NO_OP - | RW_Linearized_ReLu - | Gather - | ReduceSum - | GatherND - | RandomNormal - | Abs - | Log - - let str_op o = - match o with - | Add -> "Add" - | Sub -> "Sub" - | Mul -> "Mul" - | Div -> "Div" - | Matmul -> "Matmul" - | Gemm -> "Gemm" - | LogSoftmax -> "LogSoftmax" - | ReLu -> "ReLu" - | Transpose -> "Transpose" - | Squeeze -> "Squeeze" - | MaxPool -> "MaxPool" - | Conv -> "Conv" - | Reshape -> "Reshape" - | Flatten -> "Flatten" - | Identity -> "Identity" - | Constant -> "Constant" - | NO_OP -> "NO_OP" - | RW_Linearized_ReLu -> "RW_Linearized_ReLu" - | Gather -> "Gather" - | ReduceSum -> "ReduceSum" - | GatherND -> "GatherND" - | RandomNormal -> "RandomNormal" - | Abs -> "Abs" - | Log -> "Log" - - type ksize = Ksize of shape - type stride = Stride of shape - type pads = Pads of shape - type dilations = Dilations of shape - - type operator_parameters = - | Pool_params of (ksize * stride option * pads option * dilations option) - | Conv_params of (ksize * stride option * pads option * dilations option) - | Transpose_params of shape - | RW_Linearized_ReLu_params of - (bool list list * ((string, float) Base.Hashtbl.t list * int)) - | Gather_params of int - | ReduceSum_params of int * int - | RandomNormal_params of int * float * float * float * shape - - let str_op_params p = - match p with - | Transpose_params s -> - let str_sh = show_shape s in - "Transpose params: " ^ str_sh - | Pool_params (Ksize k, s, p, d) | Conv_params (Ksize k, s, p, d) -> - let str_k = show_shape k - and str_s = match s with None -> "" | Some (Stride ss) -> show_shape ss - and str_p = match p with None -> "" | Some (Pads pp) -> show_shape pp - and str_d = - match d with None -> "" | Some (Dilations dd) -> show_shape dd - in - "Pool params: KSIZE: " ^ str_k ^ ", Pads: " ^ str_p ^ ", Stride: " ^ str_s - ^ ", Dilations: " ^ str_d - | RW_Linearized_ReLu_params l -> - (* Only displays the activation scheme on the ReLU node *) - let activations = fst l in - let act' = - List.map - ~f:(fun l1 -> - List.map - ~f:(fun b -> match b with true -> "true" | false -> "false") - l1) - activations - in - let act'' = - List.map ~f:(fun l -> "[" ^ String.concat ~sep:";" l ^ "]") act' - in - let act''' = "[" ^ String.concat ~sep:";" act'' ^ "]" in - "RW_Linearized_ReLu_params: " ^ act''' - | Gather_params a -> "Gather params: " ^ Int.to_string a - | ReduceSum_params (a, b) -> - "ReduceSum params: " ^ Int.to_string a ^ Int.to_string b - | RandomNormal_params (a, b, c, d, s) -> - let str_sh = show_shape s in - "RandomNormal params: " ^ Int.to_string a ^ Float.to_string b - ^ Float.to_string c ^ Float.to_string d ^ str_sh - - type ('a, 'b) t = { - id : int; - name : string option; - shape : shape; - operator : operator; - operator_parameters : operator_parameters option; - pred : string list; - succ : string list; - tensor : ('a, 'b) Tensor.t option; - } - - let compare v1 v2 = Stdlib.compare v1.id v2.id - let hash (v : ('a, 'b) t) = v.id - let equal v1 v2 = v1.id = v2.id - - let create ~id ~name ~sh ~op ~op_p ~pred ~succ ~tensor = - { - id; - name; - shape = sh; - operator = op; - 
operator_parameters = op_p; - pred; - succ; - tensor; - } - - let get_name t = match t.name with Some n -> n | None -> "C_NODE" - let get_shape t = t.shape - let get_op t = t.operator - let get_tensor t = t.tensor - let get_pred_list t = t.pred - let get_succ_list t = t.succ - let is_data_node t = match get_tensor t with None -> false | Some _ -> true - - (* TODO: some flags on the node would be cleaner than this*) - let is_input_node t = List.equal String.equal t.pred [ "NO_INPUT" ] - let is_output_node t = List.equal String.equal t.succ [ "NO_OUTPUT" ] - - let num_neurons t = - match get_shape t with - | [||] -> 0 - | l -> Array.fold ~init:1 ~f:(fun x acc -> x * acc) l - - let show n f = - let id = Int.to_string n.id in - let name = get_name n - and operator = str_op n.operator - and operator_parameters = - match n.operator_parameters with - | Some p -> str_op_params p - | None -> "no parameters" - and shape = show_shape n.shape - and prevs = - List.fold_left ~f:(fun x y -> x ^ "," ^ y) ~init:"" (get_pred_list n) - and nexts = - List.fold_left ~f:(fun x y -> x ^ "," ^ y) ~init:"" (get_succ_list n) - and tensor = - match n.tensor with - (*limit of size for tensor strings, complying with - * dot string size limit of 16Ko *) - | Some t -> - let display_indices = - let all_indices = Tensor.all_coords (Tensor.get_shape t) in - if List.length all_indices > 10 - then - let rec firstk k xs = - match xs with - | [] -> failwith "firstk" - | x :: xs -> if k = 1 then [ x ] else x :: firstk (k - 1) xs - in - firstk 10 all_indices - else all_indices - in - let t_value_string f = - List.fold_left - ~f:(fun acc l -> - acc - ^ show_shape (Array.of_list l) - ^ ": " - ^ f (Tensor.get t (Array.of_list l)) - ^ "\n") - ~init:"" display_indices - in - "Tensor value\n: " ^ t_value_string f ^ "\nShape: " - ^ show_shape (Tensor.get_shape t) - | None -> "No tensor in node" - in - "ID :" ^ id ^ "\nNAME: " ^ name ^ "\nOP: " ^ operator ^ "\nOP PARAMS:" - ^ operator_parameters ^ "\nSHAPE: " ^ shape ^ "\nPREVS: " ^ prevs - ^ "\nNEXTS: " ^ nexts ^ "\nTENSORS INFOS:" ^ tensor -end - -module type VInput = sig - type l - type r - - val convert_f : l -> string -end - -module MakeVertex (I : VInput) = struct - type t = (I.l, I.r) Node.t - - let compare = Node.compare - let hash = Node.hash - let equal = Node.equal - let convert_f = I.convert_f - - type label = string - - let label (n : t) = match n.Node.name with Some n -> n | None -> "" - let create _name = assert false -end - -module Edge = struct - type t = string - - let compare = Stdlib.compare - let equal = phys_equal - let default = "" -end - -module NierCFG (I : VInput) = struct - module Vertex = MakeVertex (I) - include Graph.Imperative.Digraph.ConcreteBidirectionalLabeled (Vertex) (Edge) - - let convert_f = Vertex.convert_f - let vertex_list g = fold_vertex (fun x l -> x :: l) g [] - - let input_nodes g = - let input_criterion (v : ('a, 'b) Node.t) acc = - match v.id with 0 -> Some v | _ -> acc - in - match fold_vertex (fun v acc -> input_criterion v acc) g None with - | Some r -> [ r ] - | None -> failwith "Something strange, no node for describing inputs found" - - let preds g v = pred g v - - let preds_names g v = - let preds_list = pred_e g v in - List.fold ~init:[] ~f:(fun acc (_, n, _) -> n :: acc) preds_list - - let succs_names g v = - let succs_list = succ_e g v in - List.fold ~init:[] ~f:(fun acc (_, n, _) -> n :: acc) succs_list - - let succs g v = succ g v - let init_cfg = create () - - let find_vertices g f = - fold_vertex (fun x l -> if f x then x :: l 
else l) g [] - - let data_node_of n g = - fold_pred (fun v _ -> if Node.is_data_node v then Some v else None) g n None - - let infer_shape g n in_shape ~on_backward = - let op = Node.get_op n in - match op with - | Node.Add -> ( - match data_node_of n g with - | Some d_n -> Node.get_shape d_n - | None -> failwith "Error, Add operator lacks a data node") - | Node.ReLu -> in_shape - | Node.Matmul -> - let pad_left = function - | [] -> failwith "Impossible to pad empty shape" - | [ a ] -> [ 1; a ] - | x -> x - in - let pad_right = function - | [] -> failwith "Impossible to pad empty shape" - | [ a ] -> [ a; 1 ] - | x -> x - in - let rec one_padding l i = - if i <= 0 then l else one_padding (1 :: l) (i - 1) - in - let dn_shape = - match data_node_of n g with - | Some dn -> Node.get_shape dn - | None -> failwith "Error, MatMul operator lacks a data node" - in - (* Expected semantic: - * Matrix multiplication C = AB - * A (shape [n;m]); B (shape [m;p]); C (shape [n;p]) - * shape of b: b_sh - * shape of a: a_sh - * shape of c: c_sh - * It is expected here that B is the shape of the node - * yielding the data tensor in the NIER - *) - let check_matmul_size_ba ~b_sh ~a_sh = - let bdim2 = pad_left b_sh in - let adim2 = pad_right a_sh in - let bdim = one_padding bdim2 (List.length adim2 - List.length bdim2) in - let adim = one_padding adim2 (List.length bdim2 - List.length adim2) in - let rec infer_csize acc ad bd = - match (ad, bd) with - | [ m; n ], [ nn; p ] -> - if nn = n - then (n, List.append (List.rev acc) [ m; p ]) - else failwith "size of matrices not adequate" - | a :: la, b :: lb -> - if a = b - then infer_csize (a :: acc) la lb - else if a = 1 - then infer_csize (b :: acc) la lb - else if b = 1 - then infer_csize (a :: acc) la lb - else failwith "Checking matmul_size failed: one discordance" - | _, _ -> failwith "Checking matmul_size failed" - in - infer_csize [] bdim adim - in - let check_matmul_size_bc ~b_sh ~c_sh = - let bdim2 = pad_left b_sh in - let cdim2 = pad_right c_sh in - let bdim = one_padding bdim2 (List.length cdim2 - List.length bdim2) in - let cdim = one_padding cdim2 (List.length bdim2 - List.length cdim2) in - let rec infer_asize acc bd cd = - match (bd, cd) with - | [ m; p ], [ n; pp ] -> - if pp = p - then (n, List.append (List.rev acc) [ n; m ]) - else failwith "size of matrices not adequate" - | b :: lb, c :: lc -> - if b = c - then infer_asize (b :: acc) lb lc - else if b = 1 - then infer_asize (b :: acc) lb lc - else if c = 1 - then infer_asize (c :: acc) lb lc - else failwith "Checking matmul_size failed: one discordance" - | _, _ -> failwith "Checking matmul_size failed" - in - infer_asize [] bdim cdim - in - if on_backward - then - Array.of_list - @@ snd - (check_matmul_size_bc ~b_sh:(Array.to_list dn_shape) - ~c_sh:(Array.to_list in_shape)) - else - Array.of_list - @@ snd - (check_matmul_size_ba ~b_sh:(Array.to_list in_shape) - ~a_sh:(Array.to_list dn_shape)) - | a -> failwith (Printf.sprintf "operator %s not supported" (Node.str_op a)) -end - -module NierCFGInt = NierCFG (struct - type l = int64 - type r = int64_elt - - let convert_f = Int64.to_string -end) - -module NierCFGFloat = NierCFG (struct - type l = float - type r = float64_elt - - let convert_f = Float.to_string -end) - -module NierCFGDot = Graph.Graphviz.Dot (struct - include NierCFGFloat (* use the graph module from above *) - - let node_label (v : vertex) = Node.show v convert_f - let edge_attributes (_, e, _) = [ `Label e; `Color 4711 ] - let default_edge_attributes _ = [] - let get_subgraph _ = 
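(* no clustering: every vertex is drawn in the root Graphviz graph *)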
None - let vertex_attributes v = [ `Shape `Box; `Label (node_label v) ] - let vertex_name (v : vertex) = Int.to_string v.id - let default_vertex_attributes _ = [] - let graph_attributes _ = [] -end) - -let print_cfg_graph g = NierCFGDot.fprint_graph Stdlib.Format.std_formatter g - -let out_cfg_graph g = - let file = Out_channel.create "cfg.dot" in - NierCFGDot.output_graph file g diff --git a/lib/ir/nier_cfg.mli b/lib/ir/nier_cfg.mli deleted file mode 100644 index acacd98ac6781091b86950638de57e4d9f8b1bc1..0000000000000000000000000000000000000000 --- a/lib/ir/nier_cfg.mli +++ /dev/null @@ -1,283 +0,0 @@ -(** This module defines the structure and interfaces for a Neural IntermediatE - Representation (NIER). - - It is primarily designed as an intermediate step in producing verifiable - terms from an ONNX model. *) - -open Base -open Bigarray - -(** {1 Tensor module} *) - -(** Tensors are multidimensional arrays used to represent numerical data, such - as neural network weights *) - -module Tensor : sig - type ('a, 'b) t = ('a, 'b, c_layout) Genarray.t - type shape = int array [@@deriving show] - - val all_coords : shape -> int list list - - (** [create sh] initializes a tensor with the given shape [sh] with a default - value, depending on the kind of the tensor*) - - type ('a, 'b) t_kind = - | K_int : (int64, int64_elt) t_kind - | K_float : (float, float64_elt) t_kind - - val create : shape -> ('a, 'b) t_kind -> ('a, 'b) t - - (** [get t idx] returns the value in tensor [t] stored at coordinates [idx]. - Throws an error if the coordinates are invalid.*) - - val get : ('a, 'b) t -> shape -> 'a - - (** [set t idx v] sets value [v] for tensor [t] at [idx]. Throws an error - if the coordinates are invalid.*) - - val set : ('a, 'b) t -> shape -> 'a -> unit - - (** [equal f t1 t2] applies [f] to all values of [t1] and [t2], and returns - true if all applications of [f] returned true. *) - - val equal : ('a -> 'a -> bool) -> ('a, 'b) t -> ('a, 'b) t -> bool - - (** [get_shape t] returns the shape of [t]. *) - - val get_shape : ('a, 'b) t -> shape - - (** [flatten t] returns a flattened version of [t]. *) - - val flatten : ('a, 'b) t -> 'a list - - (** [num_neurons sh] returns the total number of neurons given a shape *) - - val num_neurons : shape -> int - - (** [get_flatnd_idx ~idx ~sh flt] returns the value that would be stored at - index [idx] under a tensor of shape [sh], given the flattened version of - this tensor [flt].*) - - val get_flatnd_idx : idx:shape -> sh:shape -> 'a list -> 'a - - (** [transpose_2d t] returns a copy of the tensor [t] with its last two - dimensions exchanged.*) - - val transpose_2d : ('a, 'b) t -> ('a, 'b) t - - (** [unsqueeze sh1 sh2] returns the lowest common shape between [sh1] and - [sh2], and None if there is no common shape. A common shape exists when - the shape of higher rank has only 1s on the dimensions it does not share - with the other. *) - - val unsqueeze : sh1:shape -> sh2:shape -> shape option -end - -(** {1 Modules for graph generation} *) - -module Node : sig - type shape = int array - - type operator = - | Add - | Sub - | Mul - | Div - | Matmul - | Gemm - | LogSoftmax - | ReLu - | Transpose - | Squeeze - | MaxPool - | Conv - | Reshape - | Flatten - | Identity - | Constant - | NO_OP - | RW_Linearized_ReLu - | Gather - | ReduceSum - | GatherND - | RandomNormal - | Abs - | Log - - (** Type describing the different operations handled. Those operations are - inspired by those defined in the ONNX documentation.
- - @see <https://github.com/onnx/onnx/blob/master/docs/Operators.md> - for more information. They are to be coupled with the relevant - operator parameters. *) - - val str_op : operator -> string - val show_shape : shape -> string - - type ksize = Ksize of shape - type stride = Stride of shape - type pads = Pads of shape - type dilations = Dilations of shape - - type operator_parameters = - | Pool_params of (ksize * stride option * pads option * dilations option) - | Conv_params of (ksize * stride option * pads option * dilations option) - | Transpose_params of shape - | RW_Linearized_ReLu_params of - (bool list list * ((string, float) Base.Hashtbl.t list * int)) - | Gather_params of int - | ReduceSum_params of int * int - | RandomNormal_params of int * float * float * float * shape - - val str_op_params : operator_parameters -> string - - type ('a, 'b) t = { - id : int; - name : string option; - shape : shape; - operator : operator; - operator_parameters : operator_parameters option; - pred : string list; - succ : string list; - tensor : ('a, 'b) Tensor.t option; - } - (** Type encapsulating parameters for operations. For Convolution and - Pooling: kernel size, padding, strides. For Transpose: shape. *) - - val compare : ('a, 'b) t -> ('a, 'b) t -> int - val hash : ('a, 'b) t -> int - val equal : ('a, 'b) t -> ('a, 'b) t -> bool - - val create : - id:int -> - name:string option -> - sh:shape -> - op:operator -> - op_p:operator_parameters option -> - pred:string list -> - succ:string list -> - tensor:('a, 'b) Tensor.t option -> - ('a, 'b) t - - val get_name : ('a, 'b) t -> string - val get_shape : ('a, 'b) t -> shape - val get_op : ('a, 'b) t -> operator - val get_pred_list : ('a, 'b) t -> string list - val get_succ_list : ('a, 'b) t -> string list - val get_tensor : ('a, 'b) t -> ('a, 'b) Tensor.t option - val is_data_node : ('a, 'b) t -> bool - val is_input_node : ('a, 'b) t -> bool - val is_output_node : ('a, 'b) t -> bool - val num_neurons : ('a, 'b) t -> int - val show : ('a, 'b) t -> ('a -> string) -> string -end - -module type VInput = sig - type l - type r - - val convert_f : l -> string -end - -module MakeVertex (I : VInput) : sig - include Graph.Sig.VERTEX with type t = (I.l, I.r) Node.t -end - -module Edge : sig - type t = string - - val compare : 'a -> 'a -> int - val equal : 'a -> 'a -> bool - val default : t -end - -(** NIER is a graph {b (V,E)} where {b V} is the set of vertices (nodes) and - {b E} is the set of edges (connections between nodes). Nodes contain the - following information: - - - unique id - - name coming from the original model, if it exists - - shape of the tensor resulting from the application of the node operation, - if it exists - - operation performed - - parameters of the operation - - an optional tensor storing the data - - Note that tensors have their own shapes; these must be equal to the shape of - the NIER node holding them, however.
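For illustration, a constant node could be created as follows (a sketch only: the id, name, shape and tensor contents are made up, and [Tensor.create] fills the tensor with unspecified default values):

{[
  let w : (float, float64_elt) Node.t =
    Node.create ~id:0 ~name:(Some "w") ~sh:[| 2; 2 |] ~op:Node.Constant
      ~op_p:None ~pred:[] ~succ:[]
      ~tensor:(Some (Tensor.create [| 2; 2 |] Tensor.K_float))
]}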
*) - -module NierCFG (I : VInput) : sig - include - Graph.Sig.I - with type V.t = MakeVertex(I).t - and type V.label = MakeVertex(I).t - and type E.t = MakeVertex(I).t * Edge.t * MakeVertex(I).t - and type E.label = Edge.t - - val init_cfg : t - val vertex_list : t -> vertex list - val preds : t -> vertex -> vertex list - - (** [preds_names g v] returns the list of names of predecessor nodes *) - - val preds_names : t -> vertex -> string list - val succs : t -> vertex -> vertex list - - (** [succs_names g v] returns the list of names of successor nodes *) - - val succs_names : t -> vertex -> string list - - (** [input_nodes g] returns the nodes considered as describing the inputs of - the neural network. *) - - val input_nodes : t -> vertex list - val find_vertices : t -> (vertex -> bool) -> vertex list -end - -module NierCFGFloat : sig - include - Graph.Sig.I - with type V.t = (float, Bigarray.float64_elt) Node.t - and type V.label = (float, Bigarray.float64_elt) Node.t - and type E.t = - (float, Bigarray.float64_elt) Node.t - * Edge.t - * (float, Bigarray.float64_elt) Node.t - and type E.label = Edge.t - - val init_cfg : t - val vertex_list : t -> vertex list - val preds : t -> vertex -> vertex list - - (** [preds_names g v] returns the list of names of predecessor nodes *) - - val preds_names : t -> vertex -> string list - val succs : t -> vertex -> vertex list - - (** [succs_names g v] returns the list of names of successor nodes *) - - val succs_names : t -> vertex -> string list - - (** [input_nodes g] returns the nodes considered as describing the inputs of - the neural network. *) - - val input_nodes : t -> vertex list - val find_vertices : t -> (vertex -> bool) -> vertex list - - (** [data_node_of n g] returns one node containing tensor data among the - predecessors of [n] in [g]*) - - val data_node_of : vertex -> t -> vertex option - - (** [infer_shape g n sh o_b] returns the inferred shape of the output of node - [n] in NIER [g] with input shape [sh]. Shape inference is made using the - node operator and its predecessors' shapes. [o_b] is true when performing - backward propagation, to choose which matrix size to consider.
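As a worked illustration (shapes made up): if the data node attached to a [Matmul] node [n] holds a tensor of shape [|5;3|], then

{[
  infer_shape g n [| 1; 5 |] ~on_backward:false (* = [| 1; 3 |] *)
]}

infers the shape of the product, while [infer_shape g n [| 1; 3 |] ~on_backward:true] recovers the shape [|1;5|] of the other operand.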
*) - - val infer_shape : t -> vertex -> Node.shape -> on_backward:bool -> Node.shape -end - -(** {1 Pretty printers} *) - -val print_cfg_graph : NierCFGFloat.t -> unit -val out_cfg_graph : NierCFGFloat.t -> unit diff --git a/lib/ir/nier_simple.ml b/lib/ir/nier_simple.ml deleted file mode 100644 index c179160172e25e44e97ef583e3918a98324f7aff..0000000000000000000000000000000000000000 --- a/lib/ir/nier_simple.ml +++ /dev/null @@ -1,695 +0,0 @@ -open Base - -module Shape : sig - type t [@@deriving show, ord, eq] - - val to_array : t -> int array - val of_array : int array -> t - val to_list : t -> int list - val of_list : int list -> t - val rank : t -> int - val size : t -> int - val get : t -> int -> int - val set : t -> int -> int -> t - val to_array_unsafe : t -> int array - val row_major : t -> int array -> int - val unrow_major : t -> int -> int array -end = struct - type t = int array [@@deriving ord, eq] - - let to_array = Array.copy - let to_array_unsafe x = x - let of_array = Array.copy - let to_list = Array.to_list - let of_list = Array.of_list - let rank = Array.length - let size t = Array.fold t ~f:( * ) ~init:1 - let pp fmt x = Fmt.pf fmt "[%a]" Fmt.(array ~sep:semi int) x - let show s = Fmt.str "%a" pp s - let get = Array.get - - let set t k v = - let t = Array.copy t in - Array.set t k v; - t - - let row_major t a = - assert (Array.length t = Array.length a); - let r = ref 0 in - for i = 0 to Array.length t - 1 do - r := (!r * t.(i)) + a.(i) - done; - !r - - let unrow_major t i = - let r = ref i in - let a = Array.create ~len:(Array.length t) 0 in - for i = Array.length t - 1 downto 0 do - a.(i) <- !r % t.(i); - r := !r / t.(i) - done; - a -end - -module Tensor : sig - type ('a, 'b) t - - val of_tensor : ('a, 'b) Nier_cfg.Tensor.t -> ('a, 'b) t - val to_tensor : ('a, 'b) t -> ('a, 'b) Nier_cfg.Tensor.t - val create_1_float : float -> (float, Bigarray.float64_elt) t - val create_1_int64 : int64 -> (int64, Bigarray.int64_elt) t - val shape : ('a, 'b) t -> Shape.t - val flatten : ('a, 'b) t -> 'a list - - val of_array1 : - Shape.t -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t -> ('a, 'b) t - - val get : ('a, 'b) t -> int array -> 'a -end = struct - type ('a, 'b) t = ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t - - let copy t = - let t' = Bigarray.Genarray.(create (kind t) Bigarray.c_layout (dims t)) in - Bigarray.Genarray.blit t t'; - t' - - let of_tensor = copy - let to_tensor = copy - - let create_1_float v = - let t = - Bigarray.Genarray.(create Bigarray.float64 Bigarray.c_layout [| 1 |]) - in - Bigarray.Genarray.set t [| 0 |] v; - t - - let create_1_int64 v = - let t = - Bigarray.Genarray.(create Bigarray.int64 Bigarray.c_layout [| 1 |]) - in - Bigarray.Genarray.set t [| 0 |] v; - t - - let shape x = Shape.of_array @@ Nier_cfg.Tensor.get_shape x - let flatten = Nier_cfg.Tensor.flatten - - let of_array1 shape t = - Bigarray.reshape - (copy @@ Bigarray.genarray_of_array1 t) - (Shape.to_array_unsafe shape) - - let get = Bigarray.Genarray.get -end - -module GenTensor = struct - type t = - | Float of (float, Bigarray.float64_elt) Tensor.t - | Int64 of (int64, Bigarray.int64_elt) Tensor.t - - let create_1_float f = Float (Tensor.create_1_float f) - let create_1_int64 i = Int64 (Tensor.create_1_int64 i) - - let of_int_array a = - Int64 - (Tensor.of_array1 - (Shape.of_array [| Array.length a |]) - (Bigarray.Array1.of_array Int64 C_layout (Array.map a ~f:Int64.of_int))) - - let shape = function Float f -> Tensor.shape f | Int64 i -> Tensor.shape i -end - -(** TODO: add the 
information needed to compute the shape *) -type descr = - | Constant of { data : GenTensor.t } - | Add of { - input1 : node; - input2 : node; - } - | Sub of { - input1 : node; - input2 : node; - } - | Mul of { - input1 : node; - input2 : node; - } - | Div of { - input1 : node; - input2 : node; - } - | Matmul of { - input1 : node; - input2 : node; - } - | Gemm of { - inputA : node; - inputB : node; - inputC : node option; - alpha : float; - beta : float; - transA : bool; - transB : bool; - } - | LogSoftmax - | ReLu of { input : node } - | Transpose of { - input : node; - (* called "data" in ONNX documentation : - https://onnx.ai/onnx/operators/onnx__Transpose.html*) - perm : int list; - } - | Squeeze of { - data : node; - axes : node option; (* data int64 *) - } - | MaxPool - | Conv - | Reshape of { - input : node; - shape : node; (* data int64 *) - } - | Flatten of { - input : node; - axis : int; - } - | Identity of { input : node } - | Input of { shape : Shape.t } - | RW_Linearized_ReLu - | Concat of { - inputs : node list; - axis : int; - } - | Gather of { - input : node; - indices : node; - axis : int; - } - | ReduceSum of { - input : node; - axes : node option; - keepdims : int; - noop_with_empty_axes : int; - } - | GatherND of { - data : node; - indices : node; - batch_dims : int; - } - | RandomNormal of { - dtype : int; - mean : float; - scale : float; - seed : float; - shape : int array; - } - | Abs of { input : node } - | Log of { input : node } - -and node = { - id : int; - descr : descr; - shape : Shape.t; - ty : ty; -} - -and ty = - | Float - | Int64 - -let pp_descr fmt p = - match p with - | Input { shape } -> Fmt.pf fmt "Input: %a" Shape.pp shape - | Transpose { perm; _ } -> - Fmt.pf fmt "Transpose: [%a]" Fmt.(list ~sep:semi int) perm - | Constant { data = Int64 b } when Shape.size (Tensor.shape b) < 3 -> - Fmt.pf fmt "Constant[%a]" Fmt.(list ~sep:comma int64) (Tensor.flatten b) - | Constant _ -> Fmt.pf fmt "Constant" - | Add _ -> Fmt.pf fmt "Add" - | Sub _ -> Fmt.pf fmt "Sub" - | Mul _ -> Fmt.pf fmt "Mul" - | Div _ -> Fmt.pf fmt "Div" - | Matmul _ -> Fmt.pf fmt "Matmul" - | Gemm _ -> Fmt.pf fmt "Gemm" - | LogSoftmax -> Fmt.pf fmt "LogSoftmax" - | ReLu _ -> Fmt.pf fmt "ReLu" - | Squeeze _ -> Fmt.pf fmt "Squeeze" - | MaxPool -> Fmt.pf fmt "MaxPool" - | Conv -> Fmt.pf fmt "Conv" - | Reshape _ -> Fmt.pf fmt "Reshape" - | Flatten _ -> Fmt.pf fmt "Flatten" - | Identity _ -> Fmt.pf fmt "Identity" - | RW_Linearized_ReLu -> Fmt.pf fmt "RW_Linearized_ReLu" - | Concat { axis; _ } -> Fmt.pf fmt "Concat[%i]" axis - | Gather _ -> Fmt.pf fmt "Gather" - | ReduceSum _ -> Fmt.pf fmt "ReduceSum" - | GatherND _ -> Fmt.pf fmt "GatherND" - | RandomNormal _ -> Fmt.pf fmt "RandomNormal" - | Abs _ -> Fmt.pf fmt "Abs" - | Log _ -> Fmt.pf fmt "Log" - -type t = { - output : node; - succs : (int, node list) Base.Hashtbl.t; -} - -let output t = t.output - -module Node = struct - type t = node - - let compare { id = id1; _ } { id = id2; _ } = Int.compare id1 id2 - let equal { id = id1; _ } { id = id2; _ } = Int.equal id1 id2 - let hash { id; _ } = id - let sexp_of_t node = Base.Int.sexp_of_t node.id - let pp fmt n = Fmt.pf fmt "@[%i: %a@]" n.id pp_descr n.descr - let show n = Fmt.str "%a" pp n - - include Comparator.Make (struct - type nonrec t = t - - let compare = compare - let sexp_of_t = sexp_of_t - end) - - let rec compute_shape n = n.shape - - and compute_shape_descr = function - | Add { input1; _ } - | Div { input1; _ } - | Mul { input1; _ } - | Sub { input1; _ } -> - compute_shape input1 - | 
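(* the element-wise arms above return input1's shape: broadcasting between operands of different ranks is not handled here *)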
Flatten { input; axis } -> - (* (d_0 X d_1 … d_(axis-1), d_axis X d_(axis+1) … X dn). *) - let shape = compute_shape input in - let d1 = ref 1 in - let d2 = ref 1 in - for i = 0 to axis - 1 do - d1 := !d1 * Shape.get shape i - done; - for i = axis to Shape.rank shape - 1 do - d2 := !d1 * Shape.get shape i - done; - Shape.of_list [ !d1; !d2 ] - | Input { shape } -> shape - | ReLu { input } -> compute_shape input - | Transpose { input; perm = [] } -> - compute_shape input |> Shape.to_list |> List.rev |> Shape.of_list - | Transpose { input; perm } -> - let shape = compute_shape input in - let rank = Shape.rank shape in - assert (Int.equal rank (List.length perm)); - let shape' = Array.create ~len:rank 0 in - let shape = Shape.to_array shape in - Base.List.iteri perm ~f:(fun i j -> - Array.set shape' i (Array.get shape j)); - Shape.of_array shape' - | Constant { data } -> GenTensor.shape data - | Concat { inputs; axis } -> - let shapes = List.map ~f:compute_shape inputs in - let shape = List.hd_exn shapes in - let axis = if axis < 0 then Shape.rank shape + axis else axis in - let l = List.map ~f:(fun s -> Shape.get s axis) shapes in - let i = List.reduce_exn ~f:( + ) l in - Shape.set shape axis i - | Gather { input; indices; axis } -> ( - let input_shape = compute_shape input in - let indices_shape = compute_shape indices in - let axis = if axis < 0 then Shape.rank input_shape + axis else axis in - match List.split_n (Shape.to_list input_shape) axis with - | _, [] -> failwith "axis is bigger than shape rank" - | before, _ :: after -> - Shape.of_list (before @ Shape.to_list indices_shape @ after)) - | Matmul { input1; input2 } -> - let pad_left = function - | [] -> failwith "Impossible to pad empty shape" - | [ a ] -> [ 1; a ] - | x -> x - in - let pad_right = function - | [] -> failwith "Impossible to pad empty shape" - | [ a ] -> [ a; 1 ] - | x -> x - in - let rec one_padding l i = - if i <= 0 then l else one_padding (1 :: l) (i - 1) - in - (* Expected semantic: - * Matrix multiplication C = AB - * A (shape [n;m]); B (shape [m;p]); C (shape [n;p]) - * shape of b: b_sh - * shape of a: a_sh - * shape of c: c_sh - *) - let check_matmul_size_ab ~a_sh ~b_sh = - let adim2 = pad_left a_sh in - let bdim2 = pad_right b_sh in - let adim = one_padding adim2 (List.length bdim2 - List.length adim2) in - let bdim = one_padding bdim2 (List.length adim2 - List.length bdim2) in - let rec infer_csize acc ad bd = - match (ad, bd) with - | [ m; n ], [ nn; p ] -> - if nn = n - then List.rev_append acc [ m; p ] - else failwith "size of matrices not adequate" - | a :: la, b :: lb -> - if a = b - then infer_csize (a :: acc) la lb - else if a = 1 - then infer_csize (b :: acc) la lb - else if b = 1 - then infer_csize (a :: acc) la lb - else failwith "Checking matmul_size failed: one discordance" - | _, _ -> failwith "Checking matmul_size failed" - in - infer_csize [] adim bdim - in - (* TODO: in case of pad_left and pad_right remove the added dimension. But - it is not clear what must be done when broadcasting is done *) - Shape.of_list - (check_matmul_size_ab - ~a_sh:(Shape.to_list (compute_shape input1)) - ~b_sh:(Shape.to_list (compute_shape input2))) - | Reshape { input; shape; _ } -> - let shape = - match shape.descr with - | Constant { data = Int64 a } -> - List.map ~f:Int64.to_int_exn (Tensor.flatten a) - | _ -> - (* Some constant propagation could be useful in some cases eg. 
patch-1 - VNNcomp *) - failwith "non-constant shape in reshape not supported" - in - List.iter shape ~f:(function - | -1 | 0 -> failwith "not implemented 0 -1 in shape for reshape" - | _ -> ()); - let out = Shape.of_list shape in - if Shape.size out <> Shape.size input.shape - then - failwith - "Reshape: shape of input and shape given have not the same number of \ - elements"; - out - | Gemm { inputA; inputB; inputC = _; alpha = _; beta = _; transA; transB } - -> - let rank2 i = - match Shape.to_array_unsafe i.shape with - | [| k; n |] -> (k, n) - | _ -> failwith "Gemm input must be of size 2" - in - let tr trans (k, n) = if trans then (n, k) else (k, n) in - let a1, a2 = tr transA @@ rank2 inputA in - let b1, b2 = tr transB @@ rank2 inputB in - if not (Int.equal a2 b1) - then - Fmt.failwith "Gemm (M:%i,K:%i) (K:%i,N:%i) -> (M:%i,N:%i)" a1 a2 b1 b2 - a1 b2; - Shape.of_array [| a1; b2 |] - | ( LogSoftmax | Squeeze _ | MaxPool | Conv | Identity _ - | RW_Linearized_ReLu | ReduceSum _ | GatherND _ | RandomNormal _ | Abs _ - | Log _ ) as n -> - failwith (Fmt.str "todo compute shape : %a" pp_descr n) - - let compute_ty : _ -> ty = function - | Constant { data = Float _ } -> Float - | Constant { data = Int64 _ } -> Int64 - | _ -> Float - - let create = - let c = ref (-1) in - fun descr -> - Int.incr c; - { - id = !c; - descr; - shape = compute_shape_descr descr; - ty = compute_ty descr; - } - - let constant_int_array a = - create (Constant { data = GenTensor.of_int_array a }) - - let reshape shape node = - if Shape.equal node.shape shape - then node - else - create - (Reshape - { input = node; shape = constant_int_array (Shape.to_array shape) }) - - let gather_int_as_matmul input i = - let input1 = - reshape (Shape.of_array [| 1; Shape.size input.shape |]) input - in - let selector = Array.create ~len:(Shape.size input1.shape) Float.zero in - Array.set selector i Float.one; - let selector = - GenTensor.Float - (Tensor.of_array1 - (Shape.of_array [| Array.length selector; 1 |]) - (Bigarray.Array1.of_array Float64 C_layout selector)) - in - let input2 = create (Constant { data = selector }) in - let result = create (Matmul { input1; input2 }) in - reshape (Shape.of_array [| 1 |]) result - - let gather_int ?(encode = true) input i = - if encode - then gather_int_as_matmul input i - else - let indices = - create (Constant { data = GenTensor.create_1_int64 (Int64.of_int i) }) - in - create (Gather { input; indices; axis = 0 }) - - let mul_float input f = - let input1 = reshape (Shape.of_array [| 1; 1 |]) input in - let f = Array.create ~len:1 f in - let f = - GenTensor.Float - (Tensor.of_array1 - (Shape.of_array [| Array.length f; 1 |]) - (Bigarray.Array1.of_array Float64 C_layout f)) - in - let input2 = create (Constant { data = f }) in - let result = create (Matmul { input1; input2 }) in - reshape (Shape.of_array [| 1 |]) result - - let div_float ?(encode = true) input f = - if encode - then - let f = Float.one /. 
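(* dividing by [f] is encoded as multiplying by its reciprocal, reusing the Matmul-based [mul_float] *)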
f in - mul_float input f - else - let input1 = reshape (Shape.of_array [| 1; 1 |]) input in - let f = Array.create ~len:1 f in - let f = - GenTensor.Float - (Tensor.of_array1 - (Shape.of_array [| Array.length f; 1 |]) - (Bigarray.Array1.of_array Float64 C_layout f)) - in - let input2 = create (Constant { data = f }) in - let result = create (Div { input1; input2 }) in - reshape (Shape.of_array [| 1 |]) result - - let concat_0 = function - | [ n ] -> n - | [] -> failwith "empty concat" - | inputs -> create (Concat { inputs; axis = 0 }) - - let preds node = - match node.descr with - | Constant _ | Input _ -> [] - | Add { input1; input2 } - | Sub { input1; input2 } - | Mul { input1; input2 } - | Div { input1; input2 } - | Matmul { input1; input2 } -> - [ input1; input2 ] - | Gather { input; indices; axis = _ } -> [ input; indices ] - | GatherND { data; indices; batch_dims = _ } -> [ data; indices ] - | ReLu { input } | Abs { input } | Log { input } -> [ input ] - | Concat { inputs; axis = _ } -> inputs - | ReduceSum { input; axes = Some x; _ } -> [ input; x ] - | ReduceSum { input; axes = None; _ } -> [ input ] - | RandomNormal _ -> [] - | Transpose { input; _ } -> [ input ] - | Flatten { input; _ } -> [ input ] - | Identity { input } -> [ input ] - | Gemm { inputA; inputB; inputC = Some x; _ } -> [ inputA; inputB; x ] - | Gemm { inputA; inputB; inputC = None; _ } -> [ inputA; inputB ] - | Squeeze { data; _ } -> [ data ] - | Reshape { input; shape; _ } -> [ input; shape ] - | LogSoftmax | MaxPool | Conv | RW_Linearized_ReLu -> [] - - let map f n = - match n.descr with - | Constant _ | Input _ -> n - | Add { input1; input2 } -> - create (Add { input1 = f input1; input2 = f input2 }) - | Sub { input1; input2 } -> - create (Sub { input1 = f input1; input2 = f input2 }) - | Mul { input1; input2 } -> - create (Mul { input1 = f input1; input2 = f input2 }) - | Div { input1; input2 } -> - create (Div { input1 = f input1; input2 = f input2 }) - | Matmul { input1; input2 } -> - create (Matmul { input1 = f input1; input2 = f input2 }) - | ReLu { input } -> create (ReLu { input = f input }) - | Abs { input } -> create (Abs { input = f input }) - | Log { input } -> create (Log { input = f input }) - | RandomNormal _ as descr -> create descr - | ReduceSum { input; axes; keepdims; noop_with_empty_axes } -> - create - (ReduceSum { input = f input; axes; keepdims; noop_with_empty_axes }) - | Gather { input; indices; axis } -> - create (Gather { input = f input; indices = f indices; axis }) - | GatherND { data; indices; batch_dims } -> - create (GatherND { data = f data; indices = f indices; batch_dims }) - | Transpose t -> create (Transpose { t with input = f t.input }) - | Flatten t -> create (Flatten { t with input = f t.input }) - | Identity { input } -> create (Identity { input = f input }) - | Concat { inputs; axis } -> - create (Concat { inputs = List.map ~f inputs; axis }) - | Gemm t -> - create - (Gemm - { - t with - inputA = f t.inputA; - inputB = f t.inputB; - inputC = Base.Option.map t.inputC ~f; - }) - | Squeeze t -> create (Squeeze { t with data = f t.data }) - | Reshape t -> create (Reshape { t with input = f t.input }) - | LogSoftmax | MaxPool | Conv | RW_Linearized_ReLu -> n (* todo *) - - (* let map_rec f node = let h = Base.Hashtbl.create (module Base.Int) in let - rec aux n = Base.Hashtbl.find_or_add h n.id ~default:(fun () -> f (map aux - n)) in aux node *) - - let replace_input f node = - let h = Base.Hashtbl.create (module Base.Int) in - let rec aux n = - Base.Hashtbl.find_or_add h n.id 
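(* memoized on the unique node id, so shared subgraphs are rewritten only once *)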
~default:(fun () -> - match n.descr with Input _ -> f () | _ -> map aux n) - in - aux node - - (* iter on the nodes accessible from [node] ([node] comprised) without - repetition *) - let map_rec f node = - let h = Base.Hashtbl.create (module Base.Int) in - let rec aux n = - Base.Hashtbl.find_or_add h n.id ~default:(fun () -> f (map aux n)) - in - aux node - - let iter_rec f node = - let h = Base.Hashtbl.create (module Base.Int) in - let rec aux n = - Base.Hashtbl.find_or_add h n.id ~default:(fun () -> - List.iter ~f:aux (preds n); - f n) - in - aux node -end - -(** TODO: some other invariants must be checked e.g only one input *) -let create output = - let succs = Base.Hashtbl.create (module Base.Int) in - let check_node node = - List.iter - ~f:(fun p -> Base.Hashtbl.add_multi succs ~key:p.id ~data:node) - (Node.preds node) - in - (* Add the key of the output nodes so that all the nodes are in succs *) - Base.Hashtbl.add_exn succs ~key:output.id ~data:[]; - Node.iter_rec check_node output; - { output; succs } - -let input_shape g = - let r = ref None in - let check_node n = - match n.descr with Input { shape } -> r := Some shape | _ -> () - in - Node.iter_rec check_node g.output; - Option.value_exn !r - -let succs t node = Base.Hashtbl.find_exn t.succs node.id -let iter_vertex f t = Node.iter_rec f t.output -let iter_succ f t node = List.iter ~f (Base.Hashtbl.find_multi t.succs node.id) -let pp fmt t = iter_vertex (fun v -> Fmt.pf fmt "@[%a@]@ " pp_descr v.descr) t - -let pp_debug fmt t = - iter_vertex - (fun v -> - Fmt.pf fmt "@[%i: %a(%a) : %a@]@ " v.id pp_descr v.descr - Fmt.(list ~sep:comma (using (fun x -> x.id) int)) - (Node.preds v) Shape.pp v.shape) - t - -let nodes t = - let l = ref [] in - iter_vertex (fun v -> l := v :: !l) t; - !l - -module M = Graph.Topological.Make - -module GFloat = struct - type nonrec t = t - - let iter_edges_e f t = - iter_vertex (fun n -> List.iter ~f:(fun n' -> f (n', n)) (Node.preds n)) t - - module Node = struct - type t = Node.t - - let compare = Node.compare - let equal = Node.equal - let hash = Node.hash - let sexp_of_t = Node.sexp_of_t - let create = Node.create - let show = Node.show - end - - module V = Node - - module E = struct - type t = V.t * V.t - - let src = fst - let dst = snd - end - - let iter_vertex = iter_vertex - let iter_succ = iter_succ -end - -module Dot = Graph.Graphviz.Dot (struct - include GFloat (* use the graph module from above *) - - let node_label (v : Node.t) = Node.show v - let edge_attributes (_, _) = [] - let default_edge_attributes _ = [] - let get_subgraph _ = None - let vertex_attributes v = [ `Shape `Box; `Label (node_label v) ] - let vertex_name (v : Node.t) = Int.to_string v.id - let default_vertex_attributes _ = [] - let graph_attributes _ = [] -end) - -let grapheasy g = - try - let cin, cout = - Unix.open_process_args "graph-easy" - [| "graph-easy"; "--from=graphviz"; "--as=boxart" |] - in - Dot.output_graph cout g; - Stdlib.close_out cout; - let ascii = Stdio.In_channel.input_all cin in - ignore (Unix.close_process (cin, cout)); - ascii - with exn -> - Fmt.str "Error graph-easy call: %s" (Stdlib.Printexc.to_string exn) diff --git a/lib/ir/nier_simple.mli b/lib/ir/nier_simple.mli deleted file mode 100644 index d1a6d322c5af7ffdd7e237f928a0cf974b2b8aff..0000000000000000000000000000000000000000 --- a/lib/ir/nier_simple.mli +++ /dev/null @@ -1,251 +0,0 @@ -module Shape : sig - type t [@@deriving show, ord, eq] - - val to_array : t -> int array - val of_array : int array -> t - val to_list : t -> int list - val 
of_list : int list -> t - val rank : t -> int - val size : t -> int - val get : t -> int -> int - val set : t -> int -> int -> t - val row_major : t -> int array -> int - val unrow_major : t -> int -> int array -end - -module Tensor : sig - (** Immutable tensors *) - - type ('a, 'b) t - - val of_tensor : ('a, 'b) Nier_cfg.Tensor.t -> ('a, 'b) t - val to_tensor : ('a, 'b) t -> ('a, 'b) Nier_cfg.Tensor.t - val create_1_float : float -> (float, Bigarray.float64_elt) t - val create_1_int64 : int64 -> (int64, Bigarray.int64_elt) t - val shape : ('a, 'b) t -> Shape.t - val flatten : ('a, 'b) t -> 'a list - - val of_array1 : - Shape.t -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t -> ('a, 'b) t - - val get : ('a, 'b) t -> int array -> 'a -end - -module GenTensor : sig - type t = - | Float of (float, Bigarray.float64_elt) Tensor.t - | Int64 of (int64, Bigarray.int64_elt) Tensor.t - - val create_1_float : float -> t - val create_1_int64 : int64 -> t - val shape : t -> Shape.t - val of_int_array : int array -> t -end - -type descr = - | Constant of { data : GenTensor.t } - | Add of { - input1 : node; - input2 : node; - } - | Sub of { - input1 : node; - input2 : node; - } - | Mul of { - input1 : node; - input2 : node; - } - | Div of { - input1 : node; - input2 : node; - } - | Matmul of { - input1 : node; - input2 : node; - } - | Gemm of { - inputA : node; - inputB : node; - inputC : node option; - alpha : float; - beta : float; - transA : bool; - transB : bool; - } - | LogSoftmax - | ReLu of { input : node } - | Transpose of { - input : node; - (* called "data" in ONNX documentation : - https://onnx.ai/onnx/operators/onnx__Transpose.html*) - perm : int list; - } - | Squeeze of { - data : node; - axes : node option; (* int64 *) - } - | MaxPool - | Conv - | Reshape of { - input : node; - shape : node; (* int64 *) - } - | Flatten of { - input : node; - axis : int; - } - | Identity of { input : node } - | Input of { shape : Shape.t } - | RW_Linearized_ReLu - | Concat of { - inputs : node list; - axis : Base.int; - } - | Gather of { - input : node; - indices : node; - axis : int; - } - | ReduceSum of { - input : node; - axes : node option; - keepdims : int; - noop_with_empty_axes : int; - } - | GatherND of { - data : node; - indices : node; - batch_dims : int; - } - | RandomNormal of { - dtype : int; - mean : float; - scale : float; - seed : float; - shape : int array; - } - | Abs of { input : node } - | Log of { input : node } - -and node = private { - id : int; (* unique identifier *) - descr : descr; - shape : Shape.t; - ty : ty; -} - -and ty = - | Float - | Int64 - -module Node : sig - type t = node [@@deriving show] - - val equal : t -> t -> bool - - include Base.Hashtbl.Key.S with type t := t - include Base.Comparator.S with type t := t - - val create : descr -> t - - val gather_int : ?encode:bool -> t -> int -> t - (** create a node by selection at a given index. *) - (* Implemented via a [Matmul] if [encode] (true by default). - - TODO: [encode] should be not be a parameter, rather depend on prover. *) - - val mul_float : t -> float -> t - (* Implemented via a [Matmul]. *) - - val div_float : ?encode:bool -> t -> float -> t - (* Implemented via a [Matmul] if [encode] (true by default). - - TODO: [encode] should be not be a parameter, rather depend on prover. 
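For instance, with the default [encode], [div_float n 4.0] behaves like [mul_float n 0.25]: the division is turned into a multiplication by the reciprocal, expressed as a [Matmul]; with [~encode:false] a genuine [Div] node is emitted instead.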
*) - - val constant_int_array : int array -> t - (** create a node for a constant array *) - - val reshape : Shape.t -> t -> t - (** create if necessary a reshape node *) - - val concat_0 : t list -> t - (** create if necessary a concat node for the first axis *) - - val map : (t -> t) -> t -> t - (** [map f n] replace the direct inputs [i] of n by [f i] *) - - val map_rec : (t -> t) -> t -> t - (** [map_rec f n] replace top-bottom the nodes [i] accessible from [n] by - [f i] *) - - val replace_input : (unit -> t) -> t -> t - (** [replace_input f n] replace the input in [n] by [f ()] *) - - val preds : t -> t list - (** Direct predecessors of a node *) - - val iter_rec : (t -> unit) -> t -> unit - (** Iterate on the predecessors of a node and itself. Repect topological - order. *) - - val compute_shape : t -> Shape.t -end - -type t - -val pp : t Fmt.t -val pp_debug : t Fmt.t - -val create : node -> t -(** Create a network from its output node, it must have only one input *) - -val output : t -> node -(** Output node of the network *) - -val input_shape : t -> Shape.t -(** Input shape of the network *) - -val succs : t -> node -> node list -(** successors of a node *) - -val iter_vertex : (Node.t -> unit) -> t -> unit -val iter_succ : (Node.t -> unit) -> t -> Node.t -> unit -val nodes : t -> Node.t list - -(** Respect some OcamlGraph signature *) -module GFloat : sig - type nonrec t = t - - module Node : sig - type t = node - - val equal : t -> t -> bool - val compare : t -> t -> int - val hash : t -> int - val create : descr -> t - val sexp_of_t : t -> Sexplib0.Sexp.t - end - - module V = Node - - module E : sig - type t = V.t * V.t - - val src : t -> V.t - val dst : t -> V.t - end - - val iter_vertex : (V.t -> unit) -> t -> unit - val iter_succ : (V.t -> unit) -> t -> V.t -> unit - val iter_edges_e : (E.t -> unit) -> t -> unit -end - -module Dot : sig - val fprint_graph : Format.formatter -> GFloat.t -> unit - val output_graph : out_channel -> GFloat.t -> unit -end - -val grapheasy : t -> string -(** @return - ASCII representation of the graph using "graph-easy" external command *) diff --git a/lib/ir/dune b/lib/nir/dune similarity index 88% rename from lib/ir/dune rename to lib/nir/dune index 14a9bd69afbf4d9e5a36e6dc4b88b4ef669ddcc7..7a7030df19a39b1f7a548ed832a1ce43d79aa9b0 100644 --- a/lib/ir/dune +++ b/lib/nir/dune @@ -1,6 +1,6 @@ (library - (name ir) - (public_name caisar.ir) + (name nir) + (public_name caisar.nir) (preprocess (pps ppx_inline_test diff --git a/lib/nir/gentensor.ml b/lib/nir/gentensor.ml new file mode 100644 index 0000000000000000000000000000000000000000..5a2eae003d00794fab5a6b7283f1692a7a177f21 --- /dev/null +++ b/lib/nir/gentensor.ml @@ -0,0 +1,37 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2023 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). 
*) +(* *) +(**************************************************************************) +open Base + +type t = + | Float of (float, Bigarray.float64_elt) Tensor.t + | Int64 of (int64, Bigarray.int64_elt) Tensor.t + +let create_1_float f = Float (Tensor.create_1_float f) +let create_1_int64 i = Int64 (Tensor.create_1_int64 i) + +let of_int_array a = + Int64 + (Tensor.of_array1 + (Shape.of_array [| Array.length a |]) + (Bigarray.Array1.of_array Int64 C_layout (Array.map a ~f:Int64.of_int))) + +let shape = function Float f -> Tensor.shape f | Int64 i -> Tensor.shape i diff --git a/lib/onnx/onnx.mli b/lib/nir/gentensor.mli similarity index 77% rename from lib/onnx/onnx.mli rename to lib/nir/gentensor.mli index d7e52c2872088defca6c63cf1ffc5ed3bb566788..9c9e1c31e4406063d4bb7767b9f2daa6742cee72 100644 --- a/lib/onnx/onnx.mli +++ b/lib/nir/gentensor.mli @@ -20,19 +20,11 @@ (* *) (**************************************************************************) -module G = Ir.Nier_cfg.NierCFGFloat +type t = + | Float of (float, Bigarray.float64_elt) Tensor.t + | Int64 of (int64, Bigarray.int64_elt) Tensor.t -type t = private { - n_inputs : int; (** Number of inputs. *) - n_outputs : int; (** Number of outputs. *) - nier : (G.t, string) Result.t; (** Intermediate representation. *) -} -(** ONNX model metadata and intermediate representation. *) - -val parse : string -> (t, string) Result.t -(** Parse an ONNX file into a NIER. *) - -val write : G.t -> string -> unit -(** Write a NIER into an ONNX file. *) - -module Simple = Simple +val create_1_float : float -> t +val create_1_int64 : int64 -> t +val of_int_array : int array -> t +val shape : t -> Shape.t diff --git a/lib/nir/ngraph.ml b/lib/nir/ngraph.ml new file mode 100644 index 0000000000000000000000000000000000000000..9abfd51034d82b82ff9d98c64e82bc67e6534bc6 --- /dev/null +++ b/lib/nir/ngraph.ml @@ -0,0 +1,120 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2023 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). 
*) +(* *) +(**************************************************************************) +open Base + +type t = { + output : Node.t; + succs : (int, Node.t list) Base.Hashtbl.t; +} + +let output t = t.output + +(** TODO: some other invariants must be checked e.g only one input *) +let create output = + let succs = Base.Hashtbl.create (module Base.Int) in + let check_node node = + List.iter + ~f:(fun p -> Base.Hashtbl.add_multi succs ~key:p.id ~data:node) + (Node.preds node) + in + (* Add the key of the output nodes so that all the nodes are in succs *) + Base.Hashtbl.add_exn succs ~key:output.Node.id ~data:[]; + Node.iter_rec check_node output; + { output; succs } + +let input_shape g = + let r = ref None in + let check_node n = + match n.Node.descr with Input { shape } -> r := Some shape | _ -> () + in + Node.iter_rec check_node g.output; + Option.value_exn !r + +let succs t node = Base.Hashtbl.find_exn t.succs node.Node.id +let iter_vertex f t = Node.iter_rec f t.output + +let iter_succ f t node = + List.iter ~f (Base.Hashtbl.find_multi t.succs node.Node.id) + +let pp fmt t = + iter_vertex (fun v -> Fmt.pf fmt "@[%a@]@ " Node.pp_descr v.descr) t + +let pp_debug fmt t = + iter_vertex + (fun v -> + Fmt.pf fmt "@[%i: %a(%a) : %a@]@ " v.id Node.pp_descr v.descr + Fmt.(list ~sep:comma (using (fun x -> x.Node.id) int)) + (Node.preds v) Shape.pp v.shape) + t + +let nodes t = + let l = ref [] in + iter_vertex (fun v -> l := v :: !l) t; + !l + +module M = Graph.Topological.Make + +module GFloat = struct + type nonrec t = t + + let iter_edges_e f t = + iter_vertex (fun n -> List.iter ~f:(fun n' -> f (n', n)) (Node.preds n)) t + + module V = Node + + module E = struct + type t = V.t * V.t + + let src = fst + let dst = snd + end + + let iter_vertex = iter_vertex + let iter_succ = iter_succ +end + +module Dot = Graph.Graphviz.Dot (struct + include GFloat (* use the graph module from above *) + + let node_label (v : Node.t) = Node.show v + let edge_attributes (_, _) = [] + let default_edge_attributes _ = [] + let get_subgraph _ = None + let vertex_attributes v = [ `Shape `Box; `Label (node_label v) ] + let vertex_name (v : Node.t) = Int.to_string v.id + let default_vertex_attributes _ = [] + let graph_attributes _ = [] +end) + +let grapheasy g = + try + let cin, cout = + Unix.open_process_args "graph-easy" + [| "graph-easy"; "--from=graphviz"; "--as=boxart" |] + in + Dot.output_graph cout g; + Stdlib.close_out cout; + let ascii = Stdio.In_channel.input_all cin in + ignore (Unix.close_process (cin, cout)); + ascii + with exn -> + Fmt.str "Error graph-easy call: %s" (Stdlib.Printexc.to_string exn) diff --git a/lib/nir/ngraph.mli b/lib/nir/ngraph.mli new file mode 100644 index 0000000000000000000000000000000000000000..a96ea8cdf229ef48482adc9524c22612f6ace6b4 --- /dev/null +++ b/lib/nir/ngraph.mli @@ -0,0 +1,74 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2023 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. 
*) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +(** {1 Neural Intermediate Representation (NIR)} *) + +(** NIR is a graph describing a machine learning model's control flow. + + A graph is described starting from its output node. *) + +type t + +val pp : t Fmt.t +val pp_debug : t Fmt.t + +val create : Node.t -> t +(** Create a network from its output node *) + +val output : t -> Node.t +(** Output node of the network *) + +val nodes : t -> Node.t list +(** All the nodes of the network *) + +val input_shape : t -> Shape.t +(** Input shape of the network *) + +val succs : t -> Node.t -> Node.t list +(** Successors of a node *) + +val iter_vertex : (Node.t -> unit) -> t -> unit +val iter_succ : (Node.t -> unit) -> t -> Node.t -> unit +val grapheasy : t -> string + +(** Respect some OcamlGraph signature *) +module GFloat : sig + type nonrec t = t + + module V = Node + + module E : sig + type t = V.t * V.t + + val src : t -> V.t + val dst : t -> V.t + end + + val iter_vertex : (V.t -> unit) -> t -> unit + val iter_succ : (V.t -> unit) -> t -> V.t -> unit + val iter_edges_e : (E.t -> unit) -> t -> unit +end + +module Dot : sig + val fprint_graph : Format.formatter -> GFloat.t -> unit + val output_graph : out_channel -> GFloat.t -> unit +end diff --git a/lib/nir/node.ml b/lib/nir/node.ml new file mode 100644 index 0000000000000000000000000000000000000000..656d928f12fe52ed3a5d1d5c0b804622fb148b4b --- /dev/null +++ b/lib/nir/node.ml @@ -0,0 +1,479 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2023 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1).
*) +(* *) +(**************************************************************************) +open Base + +type ty = + | Float + | Int64 +[@@deriving show] + +(** TODO: add the information needed to compute the shape *) +type descr = + | Constant of { data : Gentensor.t } + | Add of { + input1 : t; + input2 : t; + } + | Sub of { + input1 : t; + input2 : t; + } + | Mul of { + input1 : t; + input2 : t; + } + | Div of { + input1 : t; + input2 : t; + } + | Matmul of { + input1 : t; + input2 : t; + } + | Gemm of { + inputA : t; + inputB : t; + inputC : t option; + alpha : float; + beta : float; + transA : bool; + transB : bool; + } + | LogSoftmax + | ReLu of { input : t } + | Transpose of { + input : t; + (* called "data" in ONNX documentation : + https://onnx.ai/onnx/operators/onnx__Transpose.html*) + perm : int list; + } + | Squeeze of { + data : t; + axes : t option; (* data int64 *) + } + | MaxPool + | Conv + | Reshape of { + input : t; + shape : t; (* data int64 *) + } + | Flatten of { + input : t; + axis : int; + } + | Identity of { input : t } + | Input of { shape : Shape.t } + | RW_Linearized_ReLu + | Concat of { + inputs : t list; + axis : int; + } + | Gather of { + input : t; + indices : t; + axis : int; + } + | ReduceSum of { + input : t; + axes : t option; + keepdims : int; + noop_with_empty_axes : int; + } + | GatherND of { + data : t; + indices : t; + batch_dims : int; + } + | RandomNormal of { + dtype : int; + mean : float; + scale : float; + seed : float; + shape : int array; + } + | Abs of { input : t } + | Log of { input : t } + +and t = { + id : int; + descr : descr; [@printer fun fmt d -> pp_descr fmt d] + shape : Shape.t; + ty : ty; +} + +let pp_descr fmt p = + match p with + | Input { shape } -> Fmt.pf fmt "Input: %a" Shape.pp shape + | Transpose { perm; _ } -> + Fmt.pf fmt "Transpose: [%a]" Fmt.(list ~sep:semi int) perm + | Constant { data = Int64 b } when Shape.size (Tensor.shape b) < 3 -> + Fmt.pf fmt "Constant[%a]" Fmt.(list ~sep:comma int64) (Tensor.flatten b) + | Constant _ -> Fmt.pf fmt "Constant" + | Add _ -> Fmt.pf fmt "Add" + | Sub _ -> Fmt.pf fmt "Sub" + | Mul _ -> Fmt.pf fmt "Mul" + | Div _ -> Fmt.pf fmt "Div" + | Matmul _ -> Fmt.pf fmt "Matmul" + | Gemm _ -> Fmt.pf fmt "Gemm" + | LogSoftmax -> Fmt.pf fmt "LogSoftmax" + | ReLu _ -> Fmt.pf fmt "ReLu" + | Squeeze _ -> Fmt.pf fmt "Squeeze" + | MaxPool -> Fmt.pf fmt "MaxPool" + | Conv -> Fmt.pf fmt "Conv" + | Reshape _ -> Fmt.pf fmt "Reshape" + | Flatten _ -> Fmt.pf fmt "Flatten" + | Identity _ -> Fmt.pf fmt "Identity" + | RW_Linearized_ReLu -> Fmt.pf fmt "RW_Linearized_ReLu" + | Concat { axis; _ } -> Fmt.pf fmt "Concat[%i]" axis + | Gather _ -> Fmt.pf fmt "Gather" + | ReduceSum _ -> Fmt.pf fmt "ReduceSum" + | GatherND _ -> Fmt.pf fmt "GatherND" + | RandomNormal _ -> Fmt.pf fmt "RandomNormal" + | Abs _ -> Fmt.pf fmt "Abs" + | Log _ -> Fmt.pf fmt "Log" + +let show_descr t = Fmt.str "%a" pp_descr t +let compare { id = id1; _ } { id = id2; _ } = Int.compare id1 id2 +let equal { id = id1; _ } { id = id2; _ } = Int.equal id1 id2 +let hash { id; _ } = id +let sexp_of_t node = Base.Int.sexp_of_t node.id +let pp fmt n = Fmt.pf fmt "@[%i: %a@]" n.id pp_descr n.descr +let show n = Fmt.str "%a" pp n + +include Base.Comparator.Make (struct + type nonrec t = t + + let compare = compare + let sexp_of_t = sexp_of_t +end) + +let rec compute_shape n = n.shape + +and compute_shape_descr = function + | Add { input1; _ } + | Div { input1; _ } + | Mul { input1; _ } + | Sub { input1; _ } -> + compute_shape input1 + | Flatten { input; axis } 
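+    (* Worked example (illustration only): with input shape [2;3;4;5] and
+       axis = 2, the result is [2 * 3; 4 * 5] = [6; 20]. *)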
->
+    (* Flatten to 2D: (d_0 x d_1 x ... x d_(axis-1), d_axis x ... x d_n). *)
+    let shape = compute_shape input in
+    let d1 = ref 1 in
+    let d2 = ref 1 in
+    for i = 0 to axis - 1 do
+      d1 := !d1 * Shape.get shape i
+    done;
+    for i = axis to Shape.rank shape - 1 do
+      d2 := !d2 * Shape.get shape i
+    done;
+    Shape.of_list [ !d1; !d2 ]
+  | Input { shape } -> shape
+  | ReLu { input } -> compute_shape input
+  | Transpose { input; perm = [] } ->
+    compute_shape input |> Shape.to_list |> List.rev |> Shape.of_list
+  | Transpose { input; perm } ->
+    let shape = compute_shape input in
+    let rank = Shape.rank shape in
+    assert (Int.equal rank (List.length perm));
+    let shape' = Array.create ~len:rank 0 in
+    let shape = Shape.to_array shape in
+    Base.List.iteri perm ~f:(fun i j -> Array.set shape' i (Array.get shape j));
+    Shape.of_array shape'
+  | Constant { data } -> Gentensor.shape data
+  | Concat { inputs; axis } ->
+    let shapes = List.map ~f:compute_shape inputs in
+    let shape = List.hd_exn shapes in
+    let axis = if axis < 0 then Shape.rank shape + axis else axis in
+    let l = List.map ~f:(fun s -> Shape.get s axis) shapes in
+    let i = List.reduce_exn ~f:( + ) l in
+    Shape.set shape axis i
+  | Gather { input; indices; axis } -> (
+    let input_shape = compute_shape input in
+    let indices_shape = compute_shape indices in
+    let axis = if axis < 0 then Shape.rank input_shape + axis else axis in
+    match List.split_n (Shape.to_list input_shape) axis with
+    | _, [] -> failwith "Gather: axis exceeds the rank of the input shape"
+    | before, _ :: after ->
+      Shape.of_list (before @ Shape.to_list indices_shape @ after))
+  | Matmul { input1; input2 } ->
+    let pad_left = function
+      | [] -> failwith "Impossible to pad empty shape"
+      | [ a ] -> [ 1; a ]
+      | x -> x
+    in
+    let pad_right = function
+      | [] -> failwith "Impossible to pad empty shape"
+      | [ a ] -> [ a; 1 ]
+      | x -> x
+    in
+    let rec one_padding l i =
+      if i <= 0 then l else one_padding (1 :: l) (i - 1)
+    in
+    (* Expected semantics:
+     * Matrix multiplication C = AB
+     * A (shape [n;m]); B (shape [m;p]); C (shape [n;p])
+     * shape of b: b_sh
+     * shape of a: a_sh
+     * shape of c: c_sh
+     *)
+    let check_matmul_size_ab ~a_sh ~b_sh =
+      let adim2 = pad_left a_sh in
+      let bdim2 = pad_right b_sh in
+      let adim = one_padding adim2 (List.length bdim2 - List.length adim2) in
+      let bdim = one_padding bdim2 (List.length adim2 - List.length bdim2) in
+      let rec infer_csize acc ad bd =
+        match (ad, bd) with
+        | [ m; n ], [ nn; p ] ->
+          if nn = n
+          then List.rev_append acc [ m; p ]
+          else failwith "Matmul: inner dimensions of matrices do not match"
+        | a :: la, b :: lb ->
+          if a = b
+          then infer_csize (a :: acc) la lb
+          else if a = 1
+          then infer_csize (b :: acc) la lb
+          else if b = 1
+          then infer_csize (a :: acc) la lb
+          else failwith "Matmul: broadcast dimension mismatch"
+        | _, _ -> failwith "Matmul: checking matmul_size failed"
+      in
+      infer_csize [] adim bdim
+    in
+    (* TODO: when pad_left or pad_right was applied, the added dimension
+       should be removed from the result; it is unclear what to do when
+       broadcasting is also involved. *)
+    Shape.of_list
+      (check_matmul_size_ab
+         ~a_sh:(Shape.to_list (compute_shape input1))
+         ~b_sh:(Shape.to_list (compute_shape input2)))
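+  (* Worked example of the broadcasting rule above (illustration only): for
+     a_sh = [2;3;4] and b_sh = [4;5], b_sh is one-padded to [1;4;5]; the
+     leading dimensions broadcast (2 against 1) and the trailing matrix
+     dimensions multiply as [3;4] x [4;5], so the result is [2;3;5]. *)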
+  | Reshape { input; shape; _ } ->
+    let shape =
+      match shape.descr with
+      | Constant { data = Int64 a } ->
+        List.map ~f:Int64.to_int_exn (Tensor.flatten a)
+      | _ ->
+        (* Some constant propagation could be useful in some cases, e.g.
+           patch-1 in VNN-Comp. *)
+        failwith "non-constant shape in Reshape not supported"
+    in
+    List.iter shape ~f:(function
+      | -1 | 0 -> failwith "0 and -1 in Reshape target shapes not implemented"
+      | _ -> ());
+    let out = Shape.of_list shape in
+    if Shape.size out <> Shape.size input.shape
+    then
+      failwith
+        "Reshape: input shape and target shape do not have the same number \
+         of elements";
+    out
+  | Gemm { inputA; inputB; inputC = _; alpha = _; beta = _; transA; transB } ->
+    let rank2 i =
+      match Shape.to_array_unsafe i.shape with
+      | [| k; n |] -> (k, n)
+      | _ -> failwith "Gemm inputs must have rank 2"
+    in
+    let tr trans (k, n) = if trans then (n, k) else (k, n) in
+    let a1, a2 = tr transA @@ rank2 inputA in
+    let b1, b2 = tr transB @@ rank2 inputB in
+    if not (Int.equal a2 b1)
+    then
+      Fmt.failwith "Gemm (M:%i,K:%i) (K:%i,N:%i) -> (M:%i,N:%i)" a1 a2 b1 b2 a1
+        b2;
+    Shape.of_array [| a1; b2 |]
+  | ( LogSoftmax | Squeeze _ | MaxPool | Conv | Identity _ | RW_Linearized_ReLu
+    | ReduceSum _ | GatherND _ | RandomNormal _ | Abs _ | Log _ ) as n ->
+    failwith (Fmt.str "todo: compute shape for %a" pp_descr n)
+
+let compute_ty : _ -> ty = function
+  | Constant { data = Float _ } -> Float
+  | Constant { data = Int64 _ } -> Int64
+  | _ -> Float
+
+let create =
+  let c = ref (-1) in
+  fun descr ->
+    Int.incr c;
+    { id = !c; descr; shape = compute_shape_descr descr; ty = compute_ty descr }
+
+let constant_int_array a = create (Constant { data = Gentensor.of_int_array a })
+
+let reshape shape node =
+  if Shape.equal node.shape shape
+  then node
+  else
+    create
+      (Reshape
+         { input = node; shape = constant_int_array (Shape.to_array shape) })
+
+let gather_int_as_matmul input i =
+  let input1 = reshape (Shape.of_array [| 1; Shape.size input.shape |]) input in
+  let selector = Array.create ~len:(Shape.size input1.shape) Float.zero in
+  Array.set selector i Float.one;
+  let selector =
+    Gentensor.Float
+      (Tensor.of_array1
+         (Shape.of_array [| Array.length selector; 1 |])
+         (Bigarray.Array1.of_array Float64 C_layout selector))
+  in
+  let input2 = create (Constant { data = selector }) in
+  let result = create (Matmul { input1; input2 }) in
+  reshape (Shape.of_array [| 1 |]) result

+let gather_int ?(encode = true) input i =
+  if encode
+  then gather_int_as_matmul input i
+  else
+    let indices =
+      create (Constant { data = Gentensor.create_1_int64 (Int64.of_int i) })
+    in
+    create (Gather { input; indices; axis = 0 })
+
+let mul_float input f =
+  let input1 = reshape (Shape.of_array [| 1; 1 |]) input in
+  let f = Array.create ~len:1 f in
+  let f =
+    Gentensor.Float
+      (Tensor.of_array1
+         (Shape.of_array [| Array.length f; 1 |])
+         (Bigarray.Array1.of_array Float64 C_layout f))
+  in
+  let input2 = create (Constant { data = f }) in
+  let result = create (Matmul { input1; input2 }) in
+  reshape (Shape.of_array [| 1 |]) result
+
+let div_float ?(encode = true) input f =
+  if encode
+  then
+    let f = Float.one /.
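+      (* encode = true rewrites x / f as x * (1 / f); a caveat worth noting is
+         that multiplying by the reciprocal may differ from a true division in
+         the last bit for some floating-point values. *)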
f in + mul_float input f + else + let input1 = reshape (Shape.of_array [| 1; 1 |]) input in + let f = Array.create ~len:1 f in + let f = + Gentensor.Float + (Tensor.of_array1 + (Shape.of_array [| Array.length f; 1 |]) + (Bigarray.Array1.of_array Float64 C_layout f)) + in + let input2 = create (Constant { data = f }) in + let result = create (Div { input1; input2 }) in + reshape (Shape.of_array [| 1 |]) result + +let concat_0 = function + | [ n ] -> n + | [] -> failwith "empty concat" + | inputs -> create (Concat { inputs; axis = 0 }) + +let preds node = + match node.descr with + | Constant _ | Input _ -> [] + | Add { input1; input2 } + | Sub { input1; input2 } + | Mul { input1; input2 } + | Div { input1; input2 } + | Matmul { input1; input2 } -> + [ input1; input2 ] + | Gather { input; indices; axis = _ } -> [ input; indices ] + | GatherND { data; indices; batch_dims = _ } -> [ data; indices ] + | ReLu { input } | Abs { input } | Log { input } -> [ input ] + | Concat { inputs; axis = _ } -> inputs + | ReduceSum { input; axes = Some x; _ } -> [ input; x ] + | ReduceSum { input; axes = None; _ } -> [ input ] + | RandomNormal _ -> [] + | Transpose { input; _ } -> [ input ] + | Flatten { input; _ } -> [ input ] + | Identity { input } -> [ input ] + | Gemm { inputA; inputB; inputC = Some x; _ } -> [ inputA; inputB; x ] + | Gemm { inputA; inputB; inputC = None; _ } -> [ inputA; inputB ] + | Squeeze { data; _ } -> [ data ] + | Reshape { input; shape; _ } -> [ input; shape ] + | LogSoftmax | MaxPool | Conv | RW_Linearized_ReLu -> [] + +let map f n = + match n.descr with + | Constant _ | Input _ -> n + | Add { input1; input2 } -> + create (Add { input1 = f input1; input2 = f input2 }) + | Sub { input1; input2 } -> + create (Sub { input1 = f input1; input2 = f input2 }) + | Mul { input1; input2 } -> + create (Mul { input1 = f input1; input2 = f input2 }) + | Div { input1; input2 } -> + create (Div { input1 = f input1; input2 = f input2 }) + | Matmul { input1; input2 } -> + create (Matmul { input1 = f input1; input2 = f input2 }) + | ReLu { input } -> create (ReLu { input = f input }) + | Abs { input } -> create (Abs { input = f input }) + | Log { input } -> create (Log { input = f input }) + | RandomNormal _ as descr -> create descr + | ReduceSum { input; axes; keepdims; noop_with_empty_axes } -> + create (ReduceSum { input = f input; axes; keepdims; noop_with_empty_axes }) + | Gather { input; indices; axis } -> + create (Gather { input = f input; indices = f indices; axis }) + | GatherND { data; indices; batch_dims } -> + create (GatherND { data = f data; indices = f indices; batch_dims }) + | Transpose t -> create (Transpose { t with input = f t.input }) + | Flatten t -> create (Flatten { t with input = f t.input }) + | Identity { input } -> create (Identity { input = f input }) + | Concat { inputs; axis } -> + create (Concat { inputs = List.map ~f inputs; axis }) + | Gemm t -> + create + (Gemm + { + t with + inputA = f t.inputA; + inputB = f t.inputB; + inputC = Base.Option.map t.inputC ~f; + }) + | Squeeze t -> create (Squeeze { t with data = f t.data }) + | Reshape t -> create (Reshape { t with input = f t.input }) + | LogSoftmax | MaxPool | Conv | RW_Linearized_ReLu -> n (* todo *) + +(* let map_rec f node = let h = Base.Hashtbl.create (module Base.Int) in let rec + aux n = Base.Hashtbl.find_or_add h n.id ~default:(fun () -> f (map aux n)) in + aux node *) + +let replace_input f node = + let h = Base.Hashtbl.create (module Base.Int) in + let rec aux n = + Base.Hashtbl.find_or_add h n.id 
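+      (* memoisation on node ids: each node is rebuilt at most once, so the
+         sharing of subgraphs in the DAG is preserved *)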
~default:(fun () ->
+        match n.descr with Input _ -> f () | _ -> map aux n)
+  in
+  aux node
+
+(* The recursive traversals below visit the nodes reachable from [node]
+   ([node] included) exactly once, handling inputs before the nodes that use
+   them. *)
+let map_rec f node =
+  let h = Base.Hashtbl.create (module Base.Int) in
+  let rec aux n =
+    Base.Hashtbl.find_or_add h n.id ~default:(fun () -> f (map aux n))
+  in
+  aux node
+
+let iter_rec f node =
+  let h = Base.Hashtbl.create (module Base.Int) in
+  let rec aux n =
+    Base.Hashtbl.find_or_add h n.id ~default:(fun () ->
+        List.iter ~f:aux (preds n);
+        f n)
+  in
+  aux node
diff --git a/lib/nir/node.mli b/lib/nir/node.mli
new file mode 100644
index 0000000000000000000000000000000000000000..e23d5a57ffaf9929c7112607eecb1b43a456dacb
--- /dev/null
+++ b/lib/nir/node.mli
@@ -0,0 +1,171 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2023                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+
+(** {1 Node descriptions} *)
+
+(** A node is composed of
+
+    - a unique [id] of type int
+    - a node description of type [descr]
+
+    [descr] describes several operations. When an operation shares its name
+    with an ONNX operation, it follows the ONNX IR v8 and ONNX Opset v13
+    standards, described at: https://onnx.ai/onnx/operators/index.html.
+
+    Nodes only record their inputs: a node is assumed to return exactly one
+    value. *)
+
+open Base
+
+type ty =
+  | Float
+  | Int64
+[@@deriving show]
+
+type descr =
+  | Constant of { data : Gentensor.t }
+      (** A constant tensor, used to store non-varying parameters during
+          inference. *)
+  | Add of {
+      input1 : t;
+      input2 : t;
+    }
+  | Sub of {
+      input1 : t;
+      input2 : t;
+    }
+  | Mul of {
+      input1 : t;
+      input2 : t;
+    }
+  | Div of {
+      input1 : t;
+      input2 : t;
+    }
+  | Matmul of {
+      input1 : t;
+      input2 : t;
+    }
+  | Gemm of {
+      inputA : t;
+      inputB : t;
+      inputC : t option;
+      alpha : float;
+      beta : float;
+      transA : bool;
+      transB : bool;
+    }
+  | LogSoftmax
+  | ReLu of { input : t }
+  | Transpose of {
+      input : t;
+          (** Called "data" in the ONNX documentation:
+              https://onnx.ai/onnx/operators/onnx__Transpose.html *)
+      perm : int list;
+    }
+  | Squeeze of {
+      data : t;
+      axes : t option; (* Expects an int64 tensor. *)
+    }
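+      (* Per ONNX Squeeze-13, when [axes] is absent all dimensions of size 1
+         are removed. *)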
+  | MaxPool
+  | Conv
+  | Reshape of {
+      input : t;
+      shape : t; (* Expects an int64 tensor *)
+    }
+  | Flatten of {
+      input : t;
+      axis : int;
+    }
+  | Identity of { input : t }
+  | Input of { shape : Shape.t }
+  | RW_Linearized_ReLu
+  | Concat of {
+      inputs : t list;
+      axis : Base.int;
+    }
+  | Gather of {
+      input : t;
+      indices : t;
+      axis : int;
+    }
+  | ReduceSum of {
+      input : t;
+      axes : t option;
+      keepdims : int;
+      noop_with_empty_axes : int;
+    }
+  | GatherND of {
+      data : t;
+      indices : t;
+      batch_dims : int;
+    }
+  | RandomNormal of {
+      dtype : int;
+      mean : float;
+      scale : float;
+      seed : float;
+      shape : int array;
+    }
+  | Abs of { input : t }
+  | Log of { input : t }
+[@@deriving show]
+
+and t = private {
+  id : int;
+  descr : descr;
+  shape : Shape.t;  (** Shape of the result of the node computation. *)
+  ty : ty;  (** Element type of the result of the node computation. *)
+}
+
+val equal : t -> t -> bool
+
+include Base.Hashtbl.Key.S with type t := t
+include Base.Comparator.S with type t := t
+
+val create : descr -> t
+(** [create descr] returns a node with a fresh id and with its shape computed
+    according to the ONNX semantics. *)
+
+val gather_int : ?encode:bool -> t -> int -> t
+
+val map : (t -> t) -> t -> t
+(** [map f n] replaces each direct input [i] of [n] by [f i]. *)
+
+val map_rec : (t -> t) -> t -> t
+(** [map_rec f n] replaces each node [i] reachable from [n] by [f i], visiting
+    inputs before the nodes that use them. *)
+
+val replace_input : (unit -> t) -> t -> t
+(** [replace_input f n] replaces each input node reachable from [n] by
+    [f ()]. *)
+
+val preds : t -> t list
+(** Direct predecessors of a node. *)
+
+val iter_rec : (t -> unit) -> t -> unit
+(** Iterates on a node and its predecessors, each exactly once, respecting
+    topological order (predecessors first). *)
+
+val compute_shape : t -> Shape.t
+val mul_float : t -> float -> t
+val div_float : ?encode:bool -> t -> float -> t
+val concat_0 : t list -> t
+val reshape : Shape.t -> t -> t
diff --git a/lib/nir/shape.ml b/lib/nir/shape.ml
new file mode 100644
index 0000000000000000000000000000000000000000..b07e361922a9549b1834803fc43b00418abe736e
--- /dev/null
+++ b/lib/nir/shape.ml
@@ -0,0 +1,58 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2023                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).
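+(* A small usage sketch of the Node interface above (illustrative only; the
+   constant tensor is built with the same Gentensor.Float/Tensor.of_array1
+   pattern used in node.ml):
+
+   let x = Nir.Node.create (Input { shape = Nir.Shape.of_list [ 1; 2 ] }) in
+   let w =
+     Nir.Node.create
+       (Constant
+          { data =
+              Nir.Gentensor.Float
+                (Nir.Tensor.of_array1
+                   (Nir.Shape.of_array [| 2; 1 |])
+                   (Bigarray.Array1.of_array Float64 C_layout [| 0.5; -0.5 |]))
+          })
+   in
+   let y = Nir.Node.create (Matmul { input1 = x; input2 = w }) in
+   assert (Nir.Shape.to_list y.shape = [ 1; 1 ])
+*)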
*) +(* *) +(**************************************************************************) + +open Base + +type t = int array [@@deriving ord, eq] + +let to_array = Array.copy +let to_array_unsafe x = x +let of_array = Array.copy +let to_list = Array.to_list +let of_list = Array.of_list +let rank = Array.length +let size t = Array.fold t ~f:( * ) ~init:1 +let pp fmt x = Fmt.pf fmt "[%a]" Fmt.(array ~sep:semi int) x +let show s = Fmt.str "%a" pp s +let get = Array.get + +let set t k v = + let t = Array.copy t in + Array.set t k v; + t + +let row_major t a = + assert (Array.length t = Array.length a); + let r = ref 0 in + for i = 0 to Array.length t - 1 do + r := (!r * t.(i)) + a.(i) + done; + !r + +let unrow_major t i = + let r = ref i in + let a = Array.create ~len:(Array.length t) 0 in + for i = Array.length t - 1 downto 0 do + a.(i) <- !r % t.(i); + r := !r / t.(i) + done; + a diff --git a/lib/nir/shape.mli b/lib/nir/shape.mli new file mode 100644 index 0000000000000000000000000000000000000000..79934bb842614e376ae7799cdf1ac21888b8c108 --- /dev/null +++ b/lib/nir/shape.mli @@ -0,0 +1,35 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2023 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +type t [@@deriving show, ord, eq] + +val to_array : t -> int array +val of_array : int array -> t +val to_list : t -> int list +val of_list : int list -> t +val rank : t -> int +val size : t -> int +val get : t -> int -> int +val set : t -> int -> int -> t +val row_major : t -> int array -> int +val unrow_major : t -> int -> int array +val to_array_unsafe : t -> int array diff --git a/lib/nir/tensor.ml b/lib/nir/tensor.ml new file mode 100644 index 0000000000000000000000000000000000000000..1eb1034a0eb664761d5f121abb32abd0f67897e7 --- /dev/null +++ b/lib/nir/tensor.ml @@ -0,0 +1,57 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2023 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). 
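+(* Worked example for Shape.row_major and Shape.unrow_major above
+   (illustration only): with t = [| 2; 3; 4 |] and a = [| 1; 2; 3 |],
+   row_major t a = ((0 * 2 + 1) * 3 + 2) * 4 + 3 = 23, and
+   unrow_major t 23 recovers [| 1; 2; 3 |]. *)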
*)
+(*                                                                        *)
+(**************************************************************************)
+open Base
+
+type ('a, 'b) t = ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t
+
+let copy t =
+  let t' = Bigarray.Genarray.(create (kind t) Bigarray.c_layout (dims t)) in
+  Bigarray.Genarray.blit t t';
+  t'
+
+let of_tensor = copy
+let to_tensor = copy
+
+let create_1_float v =
+  let t =
+    Bigarray.Genarray.(create Bigarray.float64 Bigarray.c_layout [| 1 |])
+  in
+  Bigarray.Genarray.set t [| 0 |] v;
+  t
+
+let create_1_int64 v =
+  let t = Bigarray.Genarray.(create Bigarray.int64 Bigarray.c_layout [| 1 |]) in
+  Bigarray.Genarray.set t [| 0 |] v;
+  t
+
+let shape x = Shape.of_array @@ Bigarray.Genarray.dims x
+
+let flatten t =
+  let a = Bigarray.reshape_1 t (Shape.size (shape t)) in
+  List.init (Bigarray.Array1.dim a) ~f:(fun i -> Bigarray.Array1.get a i)
+
+let of_array1 shape t =
+  Bigarray.reshape
+    (copy @@ Bigarray.genarray_of_array1 t)
+    (Shape.to_array_unsafe shape)
+
+let get = Bigarray.Genarray.get
diff --git a/lib/nir/tensor.mli b/lib/nir/tensor.mli
new file mode 100644
index 0000000000000000000000000000000000000000..0d29ee42e95200586e6fb95dc473ef87fea36dbf
--- /dev/null
+++ b/lib/nir/tensor.mli
@@ -0,0 +1,63 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2023                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+
+(** {1 Immutable tensor module} *)
+
+(** Tensors are multidimensional arrays used to represent numerical data,
+    such as the parameters of a neural network.
+
+    This library relies on Bigarray.Genarray to instantiate tensors. *)
+
+type ('a, 'b) t
+
+val of_tensor : ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t -> ('a, 'b) t
+val to_tensor : ('a, 'b) t -> ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t
+
+val create_1_float : float -> (float, Bigarray.float64_elt) t
+(** [create_1_float f] returns a one-dimensional tensor holding the single
+    floating-point value [f]. *)
+
+val create_1_int64 : int64 -> (int64, Bigarray.int64_elt) t
+(** [create_1_int64 i] returns a one-dimensional tensor holding the single
+    int64 value [i]. *)
+
+val shape : ('a, 'b) t -> Shape.t
+
+val flatten : ('a, 'b) t -> 'a list
+(** [flatten t] returns all values stored in [t] as a flat list. *)
+
+val of_array1 :
+  Shape.t -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t -> ('a, 'b) t
+(** [of_array1 sh a] copies the one-dimensional array [a] into a fresh tensor
+    and reshapes it to [sh]. *)
+
+val get : ('a, 'b) t -> int array -> 'a
+(** [get t sh] returns the value stored at coordinates [sh] in [t].
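+    (* For instance, get (create_1_float 3.0) [| 0 |] evaluates to 3.0. *)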
+ + @raise Invalid_argument + if [sh] does not exactly match the shape of [t], or if [sh] is + out-of-bounds. *) diff --git a/lib/onnx/dune b/lib/onnx/dune index 1dda954451b84ed7d64565d34978d2541cc04fa0..782b1e9e717874a9b68a8d74b9c4946c67544ef7 100644 --- a/lib/onnx/dune +++ b/lib/onnx/dune @@ -6,7 +6,7 @@ stdio ocaml-protoc-plugin ocplib-endian - caisar.ir + caisar.nir caisar_logging) (synopsis "ONNX parser for CAISAR")) diff --git a/lib/onnx/onnx.ml b/lib/onnx/onnx.ml deleted file mode 100644 index a647157b3620cccba1e603083b080d20b190be01..0000000000000000000000000000000000000000 --- a/lib/onnx/onnx.ml +++ /dev/null @@ -1,694 +0,0 @@ -(**************************************************************************) -(* *) -(* This file is part of CAISAR. *) -(* *) -(* Copyright (C) 2023 *) -(* CEA (Commissariat à l'énergie atomique et aux énergies *) -(* alternatives) *) -(* *) -(* You can redistribute it and/or modify it under the terms of the GNU *) -(* Lesser General Public License as published by the Free Software *) -(* Foundation, version 2.1. *) -(* *) -(* It is distributed in the hope that it will be useful, *) -(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) -(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) -(* GNU Lesser General Public License for more details. *) -(* *) -(* See the GNU Lesser General Public License version 2.1 *) -(* for more details (enclosed in the file licenses/LGPLv2.1). *) -(* *) -(**************************************************************************) - -open Base -module Format = Stdlib.Format -module Fun = Stdlib.Fun -module Oproto = Onnx_protoc (* Autogenerated during compilation *) -module Oprotom = Oproto.Onnx.ModelProto -module NCFG = Ir.Nier_cfg -module G = NCFG.NierCFGFloat -module NSimple = Ir.Nier_simple -module GFloat = Ir.Nier_simple.GFloat -module Dot = Ir.Nier_simple.Dot - -exception ParseError of string - -type t = { - n_inputs : int; (* Number of inputs. *) - n_outputs : int; (* Number of outputs. *) - nier : (G.t, string) Result.t; (* Intermediate representation. *) -} - -(* ONNX format handling. *) -type op_attribute = Oproto.Onnx.AttributeProto.t - -type tensordata = - | Raw of bytes - | Float of float list - -let (no_attr : op_attribute) = - { - name = None; - ref_attr_name = None; - doc_string = None; - type' = None; - f = None; - i = None; - s = None; - t = None; - g = None; - floats = []; - ints = []; - strings = []; - tensors = []; - graphs = []; - sparse_tensor = None; - tp = None; - sparse_tensors = []; - type_protos = []; - } - -let get_nested_dims (s : Oproto.Onnx.ValueInfoProto.t list) = - match List.nth s 0 with - | Some { type' = Some { value = `Tensor_type { shape = Some v; _ }; _ }; _ } - -> - v - | _ -> [] - -let flattened_dim (dim : Oproto.Onnx.TensorShapeProto.Dimension.t list) = - List.fold ~init:1 dim ~f:(fun acc x -> - match x.value with - | `Dim_value v -> acc * Int64.to_int_exn v - | `Dim_param _ -> acc - | `not_set -> acc) - -let get_input_output_dim (model : Oprotom.t) = - let input_shape, output_shape = - match model.graph with - | Some g -> (get_nested_dims g.input, get_nested_dims g.output) - | _ -> ([], []) - in - (* TODO: here we only get the flattened dimension of inputs and outputs, but - more interesting parsing could be done later on. 
*) - let input_flat_dim = flattened_dim input_shape in - let output_flat_dim = flattened_dim output_shape in - (input_flat_dim, output_flat_dim) - -let produce_cfg (g : Oproto.Onnx.GraphProto.t) = - let open Oproto.Onnx in - let nodes = g.node - and inputs = g.input - and outputs = g.output - and initi = g.initializer' in - let fold_vip_names acc n = - match n.ValueInfoProto.name with - | Some str -> Some str :: acc - | None -> None :: acc - in - let i_nodes, o_nodes = - ( List.fold inputs ~init:[] ~f:fold_vip_names, - List.fold outputs ~init:[] ~f:fold_vip_names ) - and c_nodes = List.init (List.length nodes) ~f:(fun _ -> None) in - let fold_nodes_ops_cfg ns = - let get_node_operator_cfg x = - match x.NodeProto.op_type with - | None -> NCFG.Node.NO_OP - | Some o -> ( - match o with - | "Add" -> NCFG.Node.Add - | "Sub" -> NCFG.Node.Sub - | "Mul" -> NCFG.Node.Mul - | "Div" -> NCFG.Node.Div - | "Relu" -> NCFG.Node.ReLu - | "MatMul" -> NCFG.Node.Matmul - | "Gemm" -> NCFG.Node.Gemm - | "LogSoftmax" -> NCFG.Node.LogSoftmax - | "Transpose" -> NCFG.Node.Transpose - | "Squeeze" -> NCFG.Node.Squeeze - | "MaxPool" -> NCFG.Node.MaxPool - | "Constant" -> NCFG.Node.Constant - | "Conv" -> NCFG.Node.Conv - | "Reshape" -> NCFG.Node.Reshape - | "Flatten" -> NCFG.Node.Flatten - | "Identity" -> NCFG.Node.Identity - | "Gather" -> NCFG.Node.Gather - (* | "ReduceSum" -> NCFG.Node.ReduceSum | "GatherND" -> - NCFG.Node.GatherND | "RandomNormal" -> NCFG.Node.RandomNormal | "Abs" - -> NCFG.Node.Abs | "Log" -> NCFG.Node.Log *) - | _ -> raise (ParseError ("Unsupported ONNX operator " ^ o))) - in - List.fold ~f:(fun acc n -> get_node_operator_cfg n :: acc) ~init:[] ns - in - let c_ops = List.rev @@ fold_nodes_ops_cfg nodes - and i_ops, o_ops = - ( List.init ~f:(fun _ -> NCFG.Node.NO_OP) (List.length i_nodes), - List.init ~f:(fun _ -> NCFG.Node.NO_OP) (List.length o_nodes) ) - in - let fold_nodes_attr ns = - let get_node_attr n = n.NodeProto.attribute in - List.fold ~f:(fun acc n -> get_node_attr n :: acc) ~init:[] ns - in - - let c_attr = List.rev @@ fold_nodes_attr nodes - and i_attr, o_attr = - ( List.init ~f:(fun _ -> [ no_attr ]) (List.length i_nodes), - List.init ~f:(fun _ -> [ no_attr ]) (List.length o_nodes) ) - in - let c_nodes_inputs, c_nodes_outputs = - List.unzip - @@ List.fold - ~f:(fun acc n -> (n.NodeProto.input, n.NodeProto.output) :: acc) - ~init:[] (List.rev nodes) - and i_nodes_inputs, i_nodes_outputs, o_nodes_inputs, o_nodes_outputs = - ( List.init ~f:(fun _ -> [ "NO_INPUT" ]) (List.length i_nodes), - List.init ~f:(fun _ -> [ "" ]) (List.length i_nodes), - List.init ~f:(fun _ -> [ "" ]) (List.length o_nodes), - List.init ~f:(fun _ -> [ "NO_OUTPUT" ]) (List.length o_nodes) ) - in - let data_dict = - let dict_tensors_cfg ts = - let get_float_from_index index data sh = - let index = Array.to_list index and sh = Array.to_list sh in - let pop_sh = List.tl_exn sh @ [ 1 ] in - (* Returns the factors by which multiply each coordinate *) - let rec get_factors_from_sh sh_f l = - match sh_f with - | [] -> List.rev l - | _ -> - get_factors_from_sh (List.tl_exn sh_f) - (List.fold ~f:(fun x y -> x * y) ~init:1 sh_f :: l) - in - let factors = get_factors_from_sh pop_sh [] in - let coord_in_data = - List.fold2_exn ~f:(fun x y z -> x + (y * z)) ~init:0 index factors - in - match data with - | Raw raw -> - let offset = 4 * coord_in_data in - (* Each float is coded on 4 bytes*) - let res = EndianBytes.LittleEndian.get_float raw offset in - res - | Float f -> List.nth_exn f coord_in_data - in - let 
build_tensor_from_data sh data = - let open NCFG.Tensor in - let sh = Array.of_list @@ sh in - let tensor = create sh K_float in - let coords = all_coords (get_shape tensor) in - let rec init_tensor t idx r = - match idx with - | x :: y -> - let value = - get_float_from_index (Array.of_list x) r (get_shape t) - in - set t (Array.of_list x) value; - init_tensor t y r - | [] -> t - in - init_tensor tensor coords data - in - let t_name x = - match x.TensorProto.name with Some n -> n | None -> "C_NODE" - in - let t_dim x = List.map ~f:Int64.to_int_exn x.TensorProto.dims in - let t_data x = - match x.TensorProto.raw_data with - | Some rd -> Some (build_tensor_from_data (t_dim x) (Raw rd)) - | None -> ( - match x.TensorProto.float_data with - | [] -> None - | f -> Some (build_tensor_from_data (t_dim x) (Float f))) - in - List.fold - ~f:(fun m x -> Map.add_exn m ~key:(t_name x) ~data:(t_data x)) - ~init:(Map.empty (module String)) - ts - in - dict_tensors_cfg initi - in - let unpack v = - match v with - | Some v -> v - | None -> failwith "Unpack found an unexpected None" - in - let tensor_list = - List.init - ~f:(fun i -> - match Map.find data_dict (unpack (List.nth_exn i_nodes i)) with - | Some v -> v - | None -> None) - (List.length i_nodes) - in - let tensor_list_full = Map.to_alist data_dict in - let tensor_list_rev = List.rev tensor_list in - let vip_dims v = - let val_t = - match v.ValueInfoProto.type' with - | Some t -> t - | None -> failwith "No type in value info" - in - let tns_t = - match val_t.TypeProto.value with - | `Tensor_type t -> t - | `not_set -> - failwith "No tensor type in value info" - (* TODO: support more tensor types *) - | _ -> raise (ParseError "Unknown tensor type") - in - let tns_s = - match tns_t.shape with - | Some s -> s - | None -> failwith "No tensor shape in value info" - in - List.rev - @@ List.fold tns_s ~init:[] ~f:(fun acc d -> - match d.value with - | `Dim_value d -> d :: acc - | `not_set | _ -> 0L :: acc) - in - - let c_tensordim_list = List.init (List.length nodes) ~f:(fun _ -> []) - and c_tensorraw_list = List.init (List.length nodes) ~f:(fun _ -> None) - and o_tensordim_list = - List.fold ~f:(fun acc n -> vip_dims n :: acc) ~init:[] outputs - and o_tensorraw_list = List.init (List.length o_nodes) ~f:(fun _ -> None) - and i_tensordim_list = - List.fold ~f:(fun acc n -> vip_dims n :: acc) ~init:[] inputs - and i_tensorraw_list = tensor_list_rev in - let nodes_names = i_nodes @ c_nodes @ o_nodes in - let ops = i_ops @ c_ops @ o_ops in - let attrs = i_attr @ c_attr @ o_attr in - let prevs_list = i_nodes_inputs @ c_nodes_inputs @ o_nodes_inputs in - let nexts_list = i_nodes_outputs @ c_nodes_outputs @ o_nodes_outputs in - let tensor_dims = - List.map - ~f:(List.map ~f:Int64.to_int_exn) - (i_tensordim_list @ c_tensordim_list @ o_tensordim_list) - in - let tensors = i_tensorraw_list @ c_tensorraw_list @ o_tensorraw_list in - let operator_parameters (attr : AttributeProto.t list) op = - match op with - | NCFG.Node.Transpose -> - let ints_params = - Array.map ~f:Int64.to_int_exn - @@ Array.of_list (List.nth_exn attr 0).ints - in - Some (NCFG.Node.Transpose_params ints_params) - (*TODO: maxpool and conv operators: match attr.name in attributes to - * create the correct value for each attribute*) - (* | NCFG.Vertex.MaxPool -> *) - (* | NCFG.Vertex.Conv -> *) - | _ -> None - in - let rec build_op_param_list attrs ops l = - match (attrs, ops) with - | a :: b, c :: d -> build_op_param_list b d (operator_parameters a c :: l) - | [], [] -> - List.rev l - (*All other 
list constructions are folding right, so we need to put a - final revert *) - | _ -> - raise (ParseError "Operator and attribute lists have not the same size") - in - let op_params_cfg = build_op_param_list attrs ops [] in - let cfg = G.init_cfg in - (* adding inputs, outputs and cnodes to the cfg *) - let unkerasize l = List.map ~f:(fun x -> if x = 0 then 1 else x) l in - for i = 0 to List.length nodes_names - 1 do - let (v : G.V.t) = - NCFG.Node.create ~id:i - ~name:(List.nth_exn nodes_names i) - ~sh:(Array.of_list @@ unkerasize (List.nth_exn tensor_dims i)) - ~op:(List.nth_exn ops i) - ~op_p:(List.nth_exn op_params_cfg i) - ~pred:(List.nth_exn prevs_list i) - ~succ:(List.nth_exn nexts_list i) - ~tensor:(List.nth_exn tensors i) - in - G.add_vertex cfg v - done; - (* Adding edges between vertices *) - (* For each unnamed vertex (= a calculation node) in the cfg ... *) - (* return true if l1 has at least one common element wih l2 *) - let rec shared_elm l1 l2 = - match l1 with - | x :: y -> List.mem l2 x ~equal:String.equal || shared_elm y l2 - | [] -> false - in - List.iter - ~f:(fun (v : G.V.t) -> - match v.name with - | None -> - let pred = v.pred and succ = v.succ in - let prev_v = - (* ... get all vertices in cfg that have the current vertex preds - * in their succ (at least one of their succ is inside our preds )*) - G.find_vertices cfg (fun (x : G.V.t) -> - if shared_elm pred x.succ then true else false) - (* ... get all vertices in cfg that have the current vertex preds - * in their name (they are named the same as one of our preds )*) - and named_pred = - G.find_vertices cfg (fun (x : G.V.t) -> - match x.name with - | Some name -> if shared_elm pred [ name ] then true else false - | None -> false) - (* ... get all vertices in cfg that have the current vertex succ - * in their name (they are named the same as one of our succs )*) - and named_succ = - G.find_vertices cfg (fun (x : G.V.t) -> - match x.name with - | Some name -> if shared_elm succ [ name ] then true else false - | None -> false) - (* get all vertices in cfg that have the current vertex succs - * in their preds (at least one of their preds is inside our succ )*) - and next_v = - G.find_vertices cfg (fun (x : G.V.t) -> - if shared_elm succ x.pred then true else false) - in - (* add edges between current vertex and identified preds and succs *) - let v_predecessors = prev_v @ named_pred - and v_successors = next_v @ named_succ in - let unpack_tname (x : G.V.t) = - match x.NCFG.Node.name with Some n -> n | None -> "" - in - List.iter - ~f:(fun (x : G.V.t) -> - let label = - match List.nth x.succ 0 with - | Some "NO_OUTPUT" -> - let pred_name = unpack_tname x in - if List.mem ~equal:String.equal v.NCFG.Node.pred pred_name - then pred_name - else "" - | Some l -> l - | None -> "" - in - G.add_edge_e cfg (x, label, v)) - v_predecessors; - (* add successors edges after filtering those - * that are already an edge*) - List.iter - ~f:(fun (x : G.V.t) -> - let all_preds = G.preds cfg x and all_succs = G.succs cfg x in - if List.mem ~equal:NCFG.Node.equal all_preds v - || List.mem ~equal:NCFG.Node.equal all_succs v - then () - else - let label = - match List.nth_exn x.pred 0 with - | "NO_INPUT" -> - let succ_name = unpack_tname x in - if List.mem ~equal:String.equal v.NCFG.Node.succ succ_name - then succ_name - else "" - | l -> l - in - G.add_edge_e cfg (v, label, x)) - v_successors - | _ -> ()) - (G.vertex_list cfg); - - (*rationale of the following: - * PyTorch stores network nodes in the field "inputs" of - * the ONNX graph g, 
and network parameters as a list of tensors - * in the ONNX initializer_. - * To make the two correspond, elements of g.inputs and g.initializer_ - * share the same value in the field "name". - * In Keras, elements of g.initializer_ have a name, but they do not - * correspond to any name in g.inputs. - * What we did before was then to create the actual nier cfg following the - * PyTorch way. - * Below, we complete the cfg with keras data by doing the following: - * * create a node for NIER for each tensor in onnx initializer_ - * * for each NIER node, check if there is a node sharing the same name - * pred - * * if yes, remove the one with highest ID (those are initi nodes, but since - * there is already a node in CFG with this name we do not - * need those) - * * if not, for each NIER node, chck if there is a node - * which name is contained in prevs. add it to the prev - * *) - - (* adding initi vertices to the cfg *) - for i = 0 to List.length tensor_list_full - 1 do - let shape = - match snd (List.nth_exn tensor_list_full i) with - | Some t -> unkerasize (Array.to_list @@ NCFG.Tensor.get_shape t) - | None -> [] - in - let (v : G.V.t) = - NCFG.Node.create - ~id:(i + List.length nodes_names) - ~name:(Some (fst (List.nth_exn tensor_list_full i))) - ~sh:(Array.of_list @@ unkerasize shape) - ~op:NO_OP ~op_p:None ~pred:[] ~succ:[] - ~tensor:(snd (List.nth_exn tensor_list_full i)) - in - G.add_vertex cfg v - done; - (* build a list of nodes - * sharing name but with different ids *) - let same_name_diff_ids = - let aux (x : G.V.t) = - G.fold_vertex - (fun y acc -> - match (x.name, y.name) with - | Some xa, Some ya -> - if (not (y.id = x.id)) && String.equal xa ya - then (x, y) :: acc - else acc - | _ -> acc) - cfg [] - in - G.fold_vertex (fun x l -> aux x :: l) cfg [] - in - let highest_ids = - List.fold - ~f:(fun acc x -> - match x with - | a :: _ -> - let maxval = max (fst a).NCFG.Node.id (snd a).NCFG.Node.id in - maxval :: acc - | [] -> acc) - ~init:[] same_name_diff_ids - in - (* (* removing nodes with highest id, those are the*) (* * ones we just added - *)*) - List.iter - ~f:(fun x -> - match x with - | l :: _ -> - let v1 = fst l in - if List.mem ~equal:( = ) highest_ids v1.NCFG.Node.id - then - (* Printf.printf "Removing id %d \n%!" *) - (* v1.NCFG.Vertex.id; *) - G.remove_vertex cfg v1 - else () - | [] -> ()) - same_name_diff_ids; - (* Now it is Keras time. 
- * Look for nodes sharing name and preds, - * then create edge *) - let shared_name_preds = - let aux (x : G.V.t) = - match x.name with - (* look in other vertices if name is among - * predecessors *) - | Some n -> G.find_vertices cfg (fun x -> shared_elm [ n ] x.pred) - | None -> [] - in - G.fold_vertex (fun x l -> (x, aux x) :: l) cfg [] - in - List.iter - ~f:(fun x -> - let orgn = fst x and to_edge = snd x in - List.iter - ~f:(fun t -> - if not (G.mem_edge cfg orgn t) - then G.add_edge_e cfg (orgn, unpack orgn.NCFG.Node.name, t) - else ()) - to_edge) - shared_name_preds; - (* else (); *) - cfg - -let nier_of_onnx_protoc (model : Oprotom.t) = - match model.graph with - | Some g -> produce_cfg g - | None -> raise (ParseError "No graph in ONNX input file found") - -let default_opset_info = - let open Oproto.Onnx in - let onnx_domain = "" in - OperatorSetIdProto.make ~domain:onnx_domain ~version:8L () - -let nier_to_onnx_protoc nier = - (* TODO: get tensor data, and operator params *) - let vertices = G.vertex_list nier in - let open NCFG.Node in - let protocs = - (* match on names of NO_OP nodes and add their outputs to corresponding - * C_NODEs inputs *) - let vertex_to_protoc v = - let name = get_name v in - let input, output = (get_pred_list v, get_succ_list v) in - let node, initi = - match get_op v with - | NO_OP | RW_Linearized_ReLu -> - (* ONNX initializers are named ONNX Tensor. - * If an initializer's name matches an existing - * ONNX node input name, the initializer will be assigned as - * the input of the node. *) - let initi = - match get_tensor v with - | None -> None - | Some t -> - Some - (Oproto.Onnx.TensorProto.make ~data_type:1 - ~dims: - (Array.to_list @@ Array.map ~f:Int64.of_int - @@ NCFG.Tensor.get_shape t) - ~float_data:(NCFG.Tensor.flatten t) ~name ()) - in - let node = None in - (node, initi) - | _ -> - let op_type = str_op (get_op v) in - let attribute = - match v.operator_parameters with - | None | Some (RW_Linearized_ReLu_params _) -> [] - | Some - (Pool_params - (Ksize k, Some (Stride s), Some (Pads p), Some (Dilations d))) - | Some - (Conv_params - (Ksize k, Some (Stride s), Some (Pads p), Some (Dilations d))) - -> - let ksize = - Oproto.Onnx.AttributeProto.make ~name:"ksize" - ~ints:(Array.to_list @@ Array.map ~f:Int64.of_int k) - () - in - let stride = - Oproto.Onnx.AttributeProto.make ~name:"stride" - ~ints:(Array.to_list @@ Array.map ~f:Int64.of_int s) - () - in - let pads = - Oproto.Onnx.AttributeProto.make ~name:"pads" - ~ints:(Array.to_list @@ Array.map ~f:Int64.of_int p) - () - in - let dilations = - Oproto.Onnx.AttributeProto.make ~name:"dilations" - ~ints:(Array.to_list @@ Array.map ~f:Int64.of_int d) - () - in - [ ksize; stride; pads; dilations ] - | Some (Transpose_params s) -> - [ - Oproto.Onnx.AttributeProto.make ~name:"perms" - ~ints:(Array.to_list @@ Array.map ~f:Int64.of_int s) - (); - ] - | Some (Gather_params a) -> - [ - Oproto.Onnx.AttributeProto.make ~name:"axis" - ~ints:[ Int64.of_int a ] - (); - ] - | Some (ReduceSum_params (a, b)) -> - [ - Oproto.Onnx.AttributeProto.make ~name:"keepdims" - ~i:(Int64.of_int a) (); - Oproto.Onnx.AttributeProto.make ~name:"noop_with_empty_axes" - ~i:(Int64.of_int b) (); - ] - | Some (RandomNormal_params (a, b, c, d, s)) -> - [ - Oproto.Onnx.AttributeProto.make ~name:"dtype" - ~i:(Int64.of_int a) (); - Oproto.Onnx.AttributeProto.make ~name:"mean" ~f:b (); - Oproto.Onnx.AttributeProto.make ~name:"scale" ~f:c (); - Oproto.Onnx.AttributeProto.make ~name:"seed" ~f:d (); - Oproto.Onnx.AttributeProto.make 
~name:"shape" - ~ints:(Array.to_list @@ Array.map ~f:Int64.of_int s) - (); - ] - | _ -> [] - in - let node = - Some - (Oproto.Onnx.NodeProto.make ~input ~output ~name ~op_type - ~attribute ~doc_string:"" ()) - in - let initi = None in - (node, initi) - in - (node, initi) - in - List.fold ~init:([], []) - ~f:(fun (accn, acci) v -> - let node, initi = vertex_to_protoc v in - match (node, initi) with - | Some n, Some t -> (n :: accn, t :: acci) - | Some n, None -> (n :: accn, acci) - | None, Some t -> (accn, t :: acci) - | None, None -> (accn, acci)) - vertices - in - let docstr = - "This ONNX model was generated from the Neural Intermediate Representation \ - of CAISAR" - in - let protog = - Oproto.Onnx.GraphProto.make ~name:"ONNX CAISAR Export" ~node:(fst protocs) - ~initializer':(snd protocs) ~sparse_initializer:[] - ~doc_string:"ONNX graph generated from CAISAR NIER" ~input:[] ~output:[] - ~value_info:[] ~quantization_annotation:[] () - in - let protom = - Oproto.Onnx.ModelProto.make ~ir_version:8L - ~opset_import:[ default_opset_info ] ~producer_name:"CAISAR" - ~producer_version:"1.0" ~domain:"" ~model_version:(-1L) ~doc_string:docstr - ~graph:protog ~metadata_props:[] ~training_info:[] ~functions:[] () - in - let writer = Oprotom.to_proto protom in - Ocaml_protoc_plugin.Writer.contents writer - -let write_nier_to_onnx nier out_channel = - let onnx = nier_to_onnx_protoc nier in - Stdio.Out_channel.output_string out_channel onnx - -let parse_in_channel in_channel = - let open Result in - try - let buf = Stdio.In_channel.input_all in_channel in - let reader = Ocaml_protoc_plugin.Reader.create buf in - match Oprotom.from_proto reader with - | Ok r -> - let n_inputs, n_outputs = get_input_output_dim r in - let nier = - try Ok (nier_of_onnx_protoc r) with - | ParseError s | Sys_error s -> Error s - | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg) - in - Ok { n_inputs; n_outputs; nier } - | _ -> Error "Cannot read protobuf" - with - | Sys_error s -> Error s - | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg) - -let parse filename = - let in_channel = Stdlib.open_in filename in - Fun.protect - ~finally:(fun () -> Stdlib.close_in in_channel) - (fun () -> parse_in_channel in_channel) - -let write nier filename = - let out_chan = Stdlib.open_out filename in - write_nier_to_onnx nier out_chan; - Stdlib.close_out out_chan - -module Simple = Simple diff --git a/lib/onnx/simple.ml b/lib/onnx/reader.ml similarity index 55% rename from lib/onnx/simple.ml rename to lib/onnx/reader.ml index 268ac5469fe49e3a76f6206fdd669f2cc5683dc7..739c0c7663655e8baa0cc94b7cd7f4b49360d223 100644 --- a/lib/onnx/simple.ml +++ b/lib/onnx/reader.ml @@ -25,24 +25,22 @@ module Format = Stdlib.Format module Fun = Stdlib.Fun module Oproto = Onnx_protoc (* Autogenerated during compilation *) module Oprotom = Oproto.Onnx.ModelProto -module NCFG = Ir.Nier_simple -module G = NCFG.GFloat exception ParseError of string type t = { n_inputs : int; (* Number of inputs. *) n_outputs : int; (* Number of outputs. *) - nier : (G.t, string) Result.t; (* Intermediate representation. *) + nir : (Nir.Ngraph.t, string) Result.t; (* Intermediate representation. *) } (* ONNX format handling. 
*) module Convert : sig - val nier_of_onnx_protoc : Oproto.Onnx.ModelProto.t -> G.t + val nir_of_onnx_protoc : Oproto.Onnx.ModelProto.t -> Nir.Ngraph.t val get_input_output_dim : Oproto.Onnx.ModelProto.t -> int * int end = struct let get_shape_of_dims (s : Oproto.Onnx.TensorShapeProto.t) = - Ir.Nier_simple.Shape.of_list + Nir.Shape.of_list @@ List.map s ~f:(function | { value = `Dim_value v; _ } -> Int64.to_int_exn v | { value = `Dim_param _; _ } -> failwith "Parameteric shape" @@ -62,7 +60,7 @@ end = struct | _ -> [] let flattened_dim (s : Oproto.Onnx.TensorShapeProto.Dimension.t list) = - Ir.Nier_simple.Shape.size (get_shape_of_dims s) + Nir.Shape.size (get_shape_of_dims s) let get_input_output_dim (model : Oprotom.t) = let input_shape, output_shape = @@ -76,13 +74,10 @@ end = struct let output_flat_dim = flattened_dim output_shape in (input_flat_dim, output_flat_dim) - let convert_tensor (ts : Oproto.Onnx.TensorProto.t) : - Ir.Nier_simple.GenTensor.t = + let convert_tensor (ts : Oproto.Onnx.TensorProto.t) : Nir.Gentensor.t = let open Oproto.Onnx in - let dims = - Ir.Nier_simple.Shape.of_list @@ List.map ~f:Int64.to_int_exn ts.dims - in - let size = Ir.Nier_simple.Shape.size dims in + let dims = Nir.Shape.of_list @@ List.map ~f:Int64.to_int_exn ts.dims in + let size = Nir.Shape.size dims in let read_raw ~get kind = match ts.raw_data with | None -> @@ -93,7 +88,7 @@ end = struct let v = get data i in Bigarray.Array1.set t i v done; - Ir.Nier_simple.Tensor.of_array1 dims t + Nir.Tensor.of_array1 dims t in let read_gen ~get elt_size kind custom_data = match custom_data with @@ -106,7 +101,7 @@ end = struct | l -> let t = Bigarray.(Array1.create kind c_layout size) in List.iteri l ~f:(fun i f -> Bigarray.Array1.set t i f); - Ir.Nier_simple.Tensor.of_array1 dims t + Nir.Tensor.of_array1 dims t in match Option.map ~f:TensorProto.DataType.from_int ts.data_type with | None -> failwith "TensorProto should have a type" @@ -173,7 +168,7 @@ end = struct | _ -> failwith "graph with more than one input node (unsupported)" in Hashtbl.add_exn converted ~key:input_name - ~data:(Ir.Nier_simple.Node.create (Input { shape = input_shape })); + ~data:(Nir.Node.create (Input { shape = input_shape })); (* converter *) let rec convert output = Hashtbl.findi_or_add ~default:convert_aux converted output @@ -194,50 +189,37 @@ end = struct (module String) (List.map ~f:(fun a -> (Option.value_exn a.name, a)) n.attribute) in - let get_attr ?default name m = - match Hashtbl.find attrs name with - | Some v -> m v - | None -> ( - match default with - | Some v -> v - | None -> Fmt.failwith "Required attribute %s missing" name) - in - let get_float ?default name : float = - get_attr ?default name (function - | { type' = Some AttributeProto.AttributeType.FLOAT; f = Some f; _ } - -> - f - | _ -> failwith "Attribute wrongly typed") + let get_float name : float = + match Hashtbl.find_exn attrs name with + | { type' = Some AttributeProto.AttributeType.FLOAT; f = Some f; _ } + -> + f + | _ -> failwith "Attribute wrongly typed" in - let get_int ?default name : int = - get_attr ?default name (function - | { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ } - -> - Int64.to_int_exn i - | _ -> failwith "Attribute wrongly typed") + let get_int name : int = + match Hashtbl.find_exn attrs name with + | { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ } -> + Int64.to_int_exn i + | _ -> failwith "Attribute wrongly typed" in - let get_ints ?default name : int list = - get_attr ?default name (function - | { type' 
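+        (* Caveat: Hashtbl.find_exn requires every attribute to be present;
+           attributes that the ONNX spec makes optional with a default, e.g.
+           alpha and beta of Gemm default to 1.0, will now make parsing fail
+           when omitted. *)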
= Some AttributeProto.AttributeType.INTS; ints = l; _ } -> - List.map ~f:Int64.to_int_exn l - | _ -> failwith "Attribute wrongly typed") + let get_ints name : int list = + match Hashtbl.find_exn attrs name with + | { type' = Some AttributeProto.AttributeType.INTS; ints = l; _ } -> + List.map ~f:Int64.to_int_exn l + | _ -> failwith "Attribute wrongly typed" in - let get_bool ?default name : bool = - get_attr ?default name (function - | { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ } - -> - not (Int64.equal i 0L) - | _ -> failwith "Attribute wrongly typed") + let get_bool name : bool = + match Hashtbl.find_exn attrs name with + | { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ } -> + not (Int64.equal i 0L) + | _ -> failwith "Attribute wrongly typed" in - let get_tensor ?default name : Ir.Nier_simple.GenTensor.t = - get_attr ?default name (function - | { - type' = Some AttributeProto.AttributeType.TENSOR; - t = Some t; - _; - } -> - convert_tensor t - | _ -> failwith "Attribute wrongly typed") + let get_tensor name : Nir.Gentensor.t = + match Hashtbl.find_exn attrs name with + | { type' = Some AttributeProto.AttributeType.TENSOR; t = Some t; _ } + -> + convert_tensor t + | _ -> failwith "Attribute wrongly typed" in let n' = match n.op_type with @@ -246,26 +228,22 @@ end = struct match s with | "Add" -> let input1, input2 = two_arg n.input in - Ir.Nier_simple.Add - { input1 = convert input1; input2 = convert input2 } + Nir.Node.Add { input1 = convert input1; input2 = convert input2 } | "Sub" -> let input1, input2 = two_arg n.input in - Ir.Nier_simple.Sub - { input1 = convert input1; input2 = convert input2 } + Nir.Node.Sub { input1 = convert input1; input2 = convert input2 } | "Mul" -> let input1, input2 = two_arg n.input in - Ir.Nier_simple.Mul - { input1 = convert input1; input2 = convert input2 } + Nir.Node.Mul { input1 = convert input1; input2 = convert input2 } | "Div" -> let input1, input2 = two_arg n.input in - Ir.Nier_simple.Div - { input1 = convert input1; input2 = convert input2 } + Nir.Node.Div { input1 = convert input1; input2 = convert input2 } | "Relu" -> let input1 = one_arg n.input in - Ir.Nier_simple.ReLu { input = convert input1 } + Nir.Node.ReLu { input = convert input1 } | "MatMul" -> let input1, input2 = two_arg n.input in - Ir.Nier_simple.Matmul + Nir.Node.Matmul { input1 = convert input1; input2 = convert input2 } | "Gemm" -> let inputA, inputB, inputC = @@ -274,19 +252,19 @@ end = struct | [ inputA; inputB; inputC ] -> (inputA, inputB, Some inputC) | _ -> failwith "Gemm must have 2 or 3 inputs" in - Ir.Nier_simple.Gemm + Nir.Node.Gemm { inputA = convert inputA; inputB = convert inputB; inputC = Option.map ~f:convert inputC; - alpha = get_float ~default:1.0 "alpha"; - beta = get_float ~default:1.0 "beta"; - transA = get_bool ~default:false "transA"; - transB = get_bool ~default:false "transB"; + alpha = get_float "alpha"; + beta = get_float "beta"; + transA = get_bool "transA"; + transB = get_bool "transB"; } - | "LogSoftmax" -> Ir.Nier_simple.LogSoftmax + | "LogSoftmax" -> Nir.Node.LogSoftmax | "Transpose" -> - Ir.Nier_simple.Transpose + Nir.Node.Transpose { input = convert (one_arg n.input); perm = get_ints "perm" } | "Squeeze" -> let data, axes = @@ -295,7 +273,7 @@ end = struct | [ data; axes ] -> (convert data, Some (convert axes)) | _ -> failwith "Squeeze must have 1 or 2 inputs" in - Ir.Nier_simple.Squeeze { data; axes } + Nir.Node.Squeeze { data; axes } | "MaxPool" -> MaxPool | "Constant" -> Constant { data = get_tensor "value" } | 
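+        (* MaxPool and Conv are parsed as bare placeholders for now: kernel
+           shape, strides and pads attributes are not yet captured in NIR. *)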
"Conv" -> Conv @@ -304,10 +282,10 @@ end = struct { input = convert @@ one_arg n.input; axis = get_int "axis" } (* | "Reshape" -> NCFG.Node.Reshape | "Identity" -> NCFG.Node.Identity | "Gather" -> NCFG.Node.Gather *) - | "Abs" -> Ir.Nier_simple.Abs { input = convert @@ one_arg n.input } - | "Log" -> Ir.Nier_simple.Log { input = convert @@ one_arg n.input } + | "Abs" -> Nir.Node.Abs { input = convert @@ one_arg n.input } + | "Log" -> Nir.Node.Log { input = convert @@ one_arg n.input } | "RandomNormal" -> - Ir.Nier_simple.RandomNormal + Nir.Node.RandomNormal { dtype = get_int "dtype"; mean = get_float "mean"; @@ -318,15 +296,14 @@ end = struct (* TODO: ReduceSum, GatherND *) | s -> failwith (Printf.sprintf "Unknown operators %s" s)) in - Ir.Nier_simple.Node.create n' - | Tensor t -> - Ir.Nier_simple.Node.create (Constant { data = convert_tensor t }) + Nir.Node.create n' + | Tensor t -> Nir.Node.create (Constant { data = convert_tensor t }) in let output' = convert output in - assert (Ir.Nier_simple.Shape.equal output'.shape output_shape); - Ir.Nier_simple.create output' + assert (Nir.Shape.equal output'.shape output_shape); + Nir.Ngraph.create output' - let nier_of_onnx_protoc (model : Oprotom.t) = + let nir_of_onnx_protoc (model : Oprotom.t) = (match model.ir_version with | None -> failwith "IR version not specified" | Some (3L | 4L | 5L | 6L | 7L | 8L) -> () @@ -351,151 +328,19 @@ let parse_in_channel in_channel = match Oprotom.from_proto reader with | Ok r -> let n_inputs, n_outputs = Convert.get_input_output_dim r in - let nier = - try Ok (Convert.nier_of_onnx_protoc r) with + let nir = + try Ok (Convert.nir_of_onnx_protoc r) with | ParseError s | Sys_error s -> Error s | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg) in - Ok { n_inputs; n_outputs; nier } + Ok { n_inputs; n_outputs; nir } | _ -> Error "Cannot read protobuf" with | Sys_error s -> Error s | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg) -let parse filename = +let from_file filename = let in_channel = Stdlib.open_in filename in Fun.protect ~finally:(fun () -> Stdlib.close_in in_channel) (fun () -> parse_in_channel in_channel) - -let value_info_from_tensor_shape ~name (n : Ir.Nier_simple.Node.t) = - let open Oproto.Onnx in - let dim = - List.map (Ir.Nier_simple.Shape.to_list n.shape) ~f:(fun i -> - TensorShapeProto.Dimension.make ~value:(`Dim_value (Int64.of_int i)) ()) - in - let shape = TensorShapeProto.make ~dim () in - let ty : TensorProto.DataType.t = - match n.ty with Float -> FLOAT | Int64 -> INT64 - in - let value = - `Tensor_type - (TypeProto.Tensor.make - ~elem_type:TensorProto.DataType.(to_int ty) - ~shape ()) - in - let type' = TypeProto.make ~value () in - let value_info = ValueInfoProto.make ~name ~type' () in - value_info - -let convert_into_tensor ?name (t : Ir.Nier_simple.GenTensor.t) = - let mk data_type = - Oproto.Onnx.TensorProto.make - ~data_type:Oproto.Onnx.TensorProto.DataType.(to_int data_type) - ~dims: - (List.map ~f:Int64.of_int @@ Ir.Nier_simple.Shape.to_list - @@ Ir.Nier_simple.GenTensor.shape t) - ?name - in - match t with - | Float t -> mk FLOAT ~float_data:(Ir.Nier_simple.Tensor.flatten t) () - | Int64 t -> mk INT64 ~int64_data:(Ir.Nier_simple.Tensor.flatten t) () - -let default_opset_import = - let open Oproto.Onnx in - let onnx_domain = "" in - OperatorSetIdProto.make ~domain:onnx_domain ~version:13L () - -let nier_simple_to_onnx_protoc (nier_simple : Ir.Nier_simple.GFloat.t) = - let open Oproto.Onnx in - let get_name (v : Ir.Nier_simple.Node.t) = 
Int.to_string v.id in - let protocs, input = - let acc = Queue.create () in - let g_input = ref None in - let vertex_to_protoc (v : Ir.Nier_simple.Node.t) = - let name = get_name v in - let input = List.map ~f:get_name (Ir.Nier_simple.Node.preds v) in - let output = [ name ] in - let make op_type attribute = - Queue.enqueue acc - (Oproto.Onnx.NodeProto.make ~input ~output ~name ~op_type ~attribute - ~doc_string:"" ()) - in - let mk_int name i = - AttributeProto.make ~name ~type':INT ~i:(Int64.of_int i) () - in - let mk_ints name ints = - AttributeProto.make ~name ~type':INTS - ~ints:(List.map ~f:Int64.of_int ints) - () - in - let mk_float name f = AttributeProto.make ~name ~type':FLOAT ~f () in - let mk_tensor name t = AttributeProto.make ~name ~type':TENSOR ~t () in - match v.descr with - | Gemm _ | LogSoftmax | Transpose _ | Squeeze _ | MaxPool | Conv - | Identity _ | RW_Linearized_ReLu | GatherND _ | ReduceSum _ -> - failwith (Fmt.str "Not implemented export: %a" Ir.Nier_simple.Node.pp v) - | Flatten { axis; _ } -> make "Flatten" [ mk_int "axis" axis ] - | Reshape _ -> make "Reshape" [] - | Constant { data } -> - let data = convert_into_tensor data in - make "Constant" [ mk_tensor "value" data ] - | Add _ -> make "Add" [] - | Sub _ -> make "Sub" [] - | Mul _ -> make "Mul" [] - | Div _ -> make "Div" [] - | Matmul _ -> make "MatMul" [] - | ReLu _ -> make "Relu" [] - | Input _ -> g_input := Some v - | Concat { axis; _ } -> make "Concat" [ mk_int "axis" axis ] - | Gather { axis; _ } -> make "Gather" [ mk_int "axis" axis ] - | Abs _ -> make "Abs" [] - | Log _ -> make "Log" [] - | RandomNormal { dtype; mean; scale; seed; shape } -> - make "RandomNormal" - [ - mk_int "dtype" dtype; - mk_float "mean" mean; - mk_float "scale" scale; - mk_float "seed" seed; - mk_ints "shape" (Array.to_list shape); - ] - in - Ir.Nier_simple.iter_vertex vertex_to_protoc nier_simple; - (Queue.to_list acc, Option.value_exn !g_input) - in - let docstr = - "This ONNX model was generated from the Neural Intermediate Representation \ - of CAISAR" - in - let input = [ value_info_from_tensor_shape ~name:(get_name input) input ] in - let output = - let output = Ir.Nier_simple.output nier_simple in - [ value_info_from_tensor_shape ~name:(get_name output) output ] - in - let value_info = - List.map (Ir.Nier_simple.nodes nier_simple) ~f:(fun v -> - value_info_from_tensor_shape ~name:(get_name v) v) - in - let protog = - GraphProto.make ~name:"ONNX CAISAR Export" ~node:protocs ~initializer':[] - ~sparse_initializer:[] ~doc_string:"ONNX graph generated from CAISAR NIER" - ~input ~output ~value_info ~quantization_annotation:[] () - in - let protom = - Oproto.Onnx.ModelProto.make ~ir_version:8L - ~opset_import:[ default_opset_import ] ~producer_name:"CAISAR" - ~producer_version:"1.0" ~domain:"" ~model_version:(-1L) ~doc_string:docstr - ~graph:protog ~metadata_props:[] ~training_info:[] ~functions:[] () - in - let writer = Oprotom.to_proto protom in - Ocaml_protoc_plugin.Writer.contents writer - -let write_to_onnx nier out_channel = - let onnx = nier_simple_to_onnx_protoc nier in - Stdio.Out_channel.output_string out_channel onnx - -let write nier filename = - let out_chan = Stdlib.open_out filename in - write_to_onnx nier out_chan; - Stdlib.close_out out_chan diff --git a/lib/onnx/simple.mli b/lib/onnx/reader.mli similarity index 86% rename from lib/onnx/simple.mli rename to lib/onnx/reader.mli index f2a0528ea8843bb93b8038406063d98145242c67..18e8ddb1925b31480a1498b27f238de049a3c1f1 100644 --- a/lib/onnx/simple.mli +++ 
b/lib/onnx/reader.mli
@@ -23,13 +23,9 @@
 type t = private {
 n_inputs : int; (** Number of inputs. *)
 n_outputs : int; (** Number of outputs. *)
- nier : (Ir.Nier_simple.GFloat.t, string) Result.t;
- (** Intermediate representation. *)
+ nir : (Nir.Ngraph.t, string) Result.t; (** Intermediate representation. *)
 }
 (** ONNX model metadata and intermediate representation. *)
-val parse : string -> (t, string) Result.t
-(** Parse an ONNX file into a NIER. *)
-
-val write : Ir.Nier_simple.GFloat.t -> string -> unit
-(** Write a NIER into an ONNX file. *)
+val from_file : string -> (t, string) Result.t
+(** Parse an ONNX file. *)
diff --git a/lib/onnx/tests/print.ml b/lib/onnx/tests/print.ml
index 12cd6e4f87f5013bcbb53c41f536de3095f07e82..166927b701182cfb950ea25e547757824c856886 100644
--- a/lib/onnx/tests/print.ml
+++ b/lib/onnx/tests/print.ml
@@ -9,14 +9,14 @@ let temporary_file = "out/test.onnx"
 let () =
-  match Onnx.Simple.parse file with
+  match Onnx.Reader.from_file file with
 | Error s -> print_endline s
- | Ok { nier = Error s; _ } -> print_endline s
- | Ok { nier = Ok g; _ } -> (
+ | Ok { nir = Error s; _ } -> print_endline s
+ | Ok { nir = Ok g; _ } -> (
 print_endline "ok";
-  Onnx.Simple.write g temporary_file;
-  match Onnx.Simple.parse temporary_file with
+  Onnx.Writer.to_file g temporary_file;
+  match Onnx.Reader.from_file temporary_file with
 | Error s -> print_endline s
- | Ok { nier = Error s; _ } -> print_endline s
- | Ok { nier = Ok _; _ } -> print_endline "ok")
+ | Ok { nir = Error s; _ } -> print_endline s
+ | Ok { nir = Ok _; _ } -> print_endline "ok")
 let () =
 let pid =
diff --git a/lib/onnx/writer.ml b/lib/onnx/writer.ml
new file mode 100644
index 0000000000000000000000000000000000000000..6b8d6f3eff756c70a96d77dcdc5deea99e86e0fe
--- /dev/null
+++ b/lib/onnx/writer.ml
@@ -0,0 +1,157 @@
+(**************************************************************************)
+(* *)
+(* This file is part of CAISAR. *)
+(* *)
+(* Copyright (C) 2023 *)
+(* CEA (Commissariat à l'énergie atomique et aux énergies *)
+(* alternatives) *)
+(* *)
+(* You can redistribute it and/or modify it under the terms of the GNU *)
+(* Lesser General Public License as published by the Free Software *)
+(* Foundation, version 2.1. *)
+(* *)
+(* It is distributed in the hope that it will be useful, *)
+(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
+(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
+(* GNU Lesser General Public License for more details. *)
+(* *)
+(* See the GNU Lesser General Public License version 2.1 *)
+(* for more details (enclosed in the file licenses/LGPLv2.1). 
*) +(* *) +(**************************************************************************) + +open Base +module Format = Stdlib.Format +module Fun = Stdlib.Fun +module Oproto = Onnx_protoc (* Autogenerated during compilation *) +module Oprotom = Oproto.Onnx.ModelProto + +let value_info_from_tensor_shape ~name ~shape = + let open Oproto.Onnx in + let dim = + List.map (Nir.Shape.to_list shape) ~f:(fun i -> + TensorShapeProto.Dimension.make ~value:(`Dim_value (Int64.of_int i)) ()) + in + let shape = TensorShapeProto.make ~dim () in + let value = + `Tensor_type + (TypeProto.Tensor.make + ~elem_type:AttributeProto.AttributeType.(to_int FLOAT) + ~shape ()) + in + let type' = TypeProto.make ~value () in + let value_info = ValueInfoProto.make ~name ~type' () in + value_info + +let convert_into_tensor ?name (t : Nir.Gentensor.t) = + let mk data_type = + Oproto.Onnx.TensorProto.make + ~data_type:Oproto.Onnx.TensorProto.DataType.(to_int data_type) + ~dims: + (List.map ~f:Int64.of_int @@ Nir.Shape.to_list @@ Nir.Gentensor.shape t) + ?name + in + match t with + | Float t -> mk FLOAT ~float_data:(Nir.Tensor.flatten t) () + | Int64 t -> mk INT64 ~int64_data:(Nir.Tensor.flatten t) () + +let default_opset_import = + let open Oproto.Onnx in + let onnx_domain = "" in + OperatorSetIdProto.make ~domain:onnx_domain ~version:13L () + +let nir_to_onnx_protoc (nir : Nir.Ngraph.t) = + let open Oproto.Onnx in + let get_name (v : Nir.Node.t) = Int.to_string v.id in + let protocs, (input, input_shape) = + let acc = Queue.create () in + let g_input = ref None in + let vertex_to_protoc (v : Nir.Node.t) = + let name = get_name v in + let input = List.map ~f:get_name (Nir.Node.preds v) in + let output = [ name ] in + let make op_type attribute = + Queue.enqueue acc + (Oproto.Onnx.NodeProto.make ~input ~output ~name ~op_type ~attribute + ~doc_string:"" ()) + in + let mk_int name i = + AttributeProto.make ~name ~type':INT ~i:(Int64.of_int i) () + in + let mk_ints name ints = + AttributeProto.make ~name ~type':INTS + ~ints:(List.map ~f:Int64.of_int ints) + () + in + let mk_float name f = AttributeProto.make ~name ~type':FLOAT ~f () in + let mk_tensor name t = AttributeProto.make ~name ~type':TENSOR ~t () in + match v.descr with + | Gemm _ | LogSoftmax | Transpose _ | Squeeze _ | MaxPool | Conv + | Identity _ | RW_Linearized_ReLu | GatherND _ | ReduceSum _ -> + Caisar_logging.Logging.not_implemented_yet (fun m -> + m "Operator %a not implemented yet." 
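(* Reviewer note: in [value_info_from_tensor_shape] above, [elem_type] is
   computed with [AttributeProto.AttributeType.(to_int FLOAT)], but the field
   semantically holds a [TensorProto.DataType]. The two enums happen to agree
   on FLOAT = 1, so the emitted model is still well-formed, yet the deleted
   writer used [TensorProto.DataType.(to_int ty)] and also covered Int64
   tensors. If float-only output is intended, the more defensive form would
   presumably be:

     ~elem_type:TensorProto.DataType.(to_int FLOAT)
*)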
Nir.Node.pp_descr v.descr) + | Reshape _ -> make "Reshape" [] + | Flatten { axis; _ } -> make "Flatten" [ mk_int "axis" axis ] + | Constant { data } -> + let data = convert_into_tensor data in + make "Constant" [ mk_tensor "value" data ] + | Add _ -> make "Add" [] + | Sub _ -> make "Sub" [] + | Mul _ -> make "Mul" [] + | Div _ -> make "Div" [] + | Matmul _ -> make "MatMul" [] + | ReLu _ -> make "Relu" [] + | Input { shape } -> g_input := Some (v, shape) + | Concat { axis; _ } -> make "Concat" [ mk_int "axis" axis ] + | Gather { axis; _ } -> make "Gather" [ mk_int "axis" axis ] + | Abs _ -> make "Abs" [] + | Log _ -> make "Log" [] + | RandomNormal { dtype; mean; scale; seed; shape } -> + make "RandomNormal" + [ + mk_int "dtype" dtype; + mk_float "mean" mean; + mk_float "scale" scale; + mk_float "seed" seed; + mk_ints "shape" (Array.to_list shape); + ] + in + Nir.Ngraph.iter_vertex vertex_to_protoc nir; + (Queue.to_list acc, Option.value_exn !g_input) + in + let docstr = + "This ONNX model was generated from the Neural Intermediate Representation \ + of CAISAR" + in + let input = + [ value_info_from_tensor_shape ~name:(get_name input) ~shape:input_shape ] + in + let output = + let output = Nir.Ngraph.output nir in + [ + value_info_from_tensor_shape ~name:(get_name output) + ~shape:(Nir.Node.compute_shape output); + ] + in + let protog = + GraphProto.make ~name:"ONNX CAISAR Export" ~node:protocs ~initializer':[] + ~sparse_initializer:[] ~doc_string:"ONNX graph generated from CAISAR NIR" + ~input ~output ~value_info:[] ~quantization_annotation:[] () + in + let protom = + Oproto.Onnx.ModelProto.make ~ir_version:8L + ~opset_import:[ default_opset_import ] ~producer_name:"CAISAR" + ~producer_version:"1.0" ~domain:"" ~model_version:(-1L) ~doc_string:docstr + ~graph:protog ~metadata_props:[] ~training_info:[] ~functions:[] () + in + let writer = Oprotom.to_proto protom in + Ocaml_protoc_plugin.Writer.contents writer + +let write_to_onnx nir out_channel = + let onnx = nir_to_onnx_protoc nir in + Stdio.Out_channel.output_string out_channel onnx + +let to_file nir filename = + let out_chan = Stdlib.open_out filename in + write_to_onnx nir out_chan; + Stdlib.close_out out_chan diff --git a/lib/onnx/writer.mli b/lib/onnx/writer.mli new file mode 100644 index 0000000000000000000000000000000000000000..56945117a968f34ac5998025664d9f31375d5fa8 --- /dev/null +++ b/lib/onnx/writer.mli @@ -0,0 +1,24 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2023 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +val to_file : Nir.Ngraph.t -> string -> unit +(** Write a NIR into an ONNX file. 
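The model is emitted with ir_version 8 and a single default opset import of
    version 13, matching [default_opset_import] in writer.ml; writing may
    raise [Sys_error] if the output file cannot be created.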
*) diff --git a/src/dune b/src/dune index ab868447a863fc15039b8c8b72eae6e0bbefb0cc..3afbd8c904c6c573cbaf7bb2378a4875921f17cb 100644 --- a/src/dune +++ b/src/dune @@ -24,6 +24,7 @@ fpath zarith caisar_logging + caisar.nir caisar.xgboost) (preprocess (pps diff --git a/src/language.ml b/src/language.ml index 2e9728c2cbb055985e5ee01bccbf0ced7537c386..960216d13db020663c574df469d5a83fcb12522e 100644 --- a/src/language.ml +++ b/src/language.ml @@ -32,7 +32,7 @@ type nn_shape = { nb_outputs : int; ty_data : Ty.ty; filename : string; - nier : Ir.Nier_simple.t option; + nir : Nir.Ngraph.t option; } type svm_shape = { @@ -46,7 +46,7 @@ let loaded_svms = Term.Hls.create 10 let lookup_loaded_nets = Term.Hls.find_opt loaded_nets let lookup_loaded_svms = Term.Hls.find_opt loaded_svms -let register_nn_as_tuple env nb_inputs nb_outputs filename ?nier mstr = +let register_nn_as_tuple env nb_inputs nb_outputs filename ?nir mstr = let name = "AsTuple" in let th_uc = Pmodule.create_module env (Ident.id_fresh name) in let nn = Pmodule.read_module env [ "caisar"; "caisar" ] "NN" in @@ -62,14 +62,14 @@ let register_nn_as_tuple env nb_inputs nb_outputs filename ?nier mstr = (Ty.ty_tuple (List.init nb_outputs ~f)) in Term.Hls.add loaded_nets ls_nn_apply - { filename; nb_inputs; nb_outputs; ty_data; nier }; + { filename; nb_inputs; nb_outputs; ty_data; nir }; let th_uc = Pmodule.add_pdecl ~vc:false th_uc (Pdecl.create_pure_decl (Decl.create_param_decl ls_nn_apply)) in Wstdlib.Mstr.add name (Pmodule.close_module th_uc) mstr -let register_nn_as_array env nb_inputs nb_outputs filename ?nier mstr = +let register_nn_as_array env nb_inputs nb_outputs filename ?nir mstr = let name = "AsArray" in let th_uc = Pmodule.create_module env (Ident.id_fresh name) in let nn = @@ -81,7 +81,7 @@ let register_nn_as_array env nb_inputs nb_outputs filename ?nier mstr = in let ls_model = Term.create_fsymbol (Ident.id_fresh "model") [] ty_data in Term.Hls.add loaded_nets ls_model - { filename; nb_inputs; nb_outputs; ty_data; nier }; + { filename; nb_inputs; nb_outputs; ty_data; nir }; let th_uc = Pmodule.add_pdecl ~vc:false th_uc (Pdecl.create_pure_decl (Decl.create_param_decl ls_model)) @@ -122,21 +122,21 @@ let onnx_parser = Env.Wenv.memoize 13 (fun env -> let h = Hashtbl.create (module String) in Hashtbl.findi_or_add h ~default:(fun filename -> - let model = Onnx.Simple.parse filename in + let model = Onnx.Reader.from_file filename in match model with | Error s -> Loc.errorm "%s" s - | Ok { n_inputs; n_outputs; nier } -> - let nier = - match nier with + | Ok { n_inputs; n_outputs; nir } -> + let nir = + match nir with | Error msg -> Logs.warn (fun m -> m "Cannot build network intermediate representation:@ %s" msg); None - | Ok nier -> Some nier + | Ok nir -> Some nir in Wstdlib.Mstr.empty - |> register_nn_as_tuple env n_inputs n_outputs filename ?nier - |> register_nn_as_array env n_inputs n_outputs filename ?nier)) + |> register_nn_as_tuple env n_inputs n_outputs filename ?nir + |> register_nn_as_array env n_inputs n_outputs filename ?nir)) let ovo_parser = Env.Wenv.memoize 13 (fun env -> @@ -205,7 +205,7 @@ type nn = { and nn_format = | NNet - | ONNX of Ir.Nier_simple.t option [@printer fun fmt _ -> Fmt.pf fmt "<nier>"] + | ONNX of Nir.Ngraph.t option [@printer fun fmt _ -> Fmt.pf fmt "<nir>"] [@@deriving show] let nets = Term.Hls.create 10 @@ -236,24 +236,24 @@ let create_nn_nnet env filename = } let create_nn_onnx env filename = - let model = Onnx.Simple.parse filename in + let model = Onnx.Reader.from_file filename in match model 
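(* Reviewer note: [Onnx.Reader.from_file] layers two results. The outer one
   reports protobuf-level failures, which abort here with [Loc.errorm]; the
   inner [nir] field reports a failure to build the intermediate
   representation, which the branch below downgrades to a warning and [None],
   so the network stays usable through its metadata alone. *)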
with | Error s -> Loc.errorm "%s" s - | Ok { n_inputs; n_outputs; nier } -> - let nier = - match nier with + | Ok { n_inputs; n_outputs; nir } -> + let nir = + match nir with | Error msg -> Logs.warn (fun m -> m "Cannot build network intermediate representation:@ %s" msg); None - | Ok nier -> Some nier + | Ok nir -> Some nir in { nn_nb_inputs = n_inputs; nn_nb_outputs = n_outputs; nn_ty_elt = ty_float64_t env; nn_filename = filename; - nn_format = ONNX nier; + nn_format = ONNX nir; } let create_nn = diff --git a/src/language.mli b/src/language.mli index 72be9f3f461a984bfbd9a6351f335b03c69929ee..a2ea56e51e4564e9fedb7d537f7eedc59b2aeed6 100644 --- a/src/language.mli +++ b/src/language.mli @@ -27,7 +27,7 @@ type nn_shape = private { nb_outputs : int; ty_data : Ty.ty; filename : string; - nier : Ir.Nier_simple.t option; + nir : Nir.Ngraph.t option; } type svm_shape = private { @@ -84,7 +84,7 @@ type nn = private { and nn_format = | NNet - | ONNX of Ir.Nier_simple.t option + | ONNX of Nir.Ngraph.t option [@@deriving show] val create_nn : Env.env -> [ `NNet | `ONNX ] -> string -> Term.lsymbol diff --git a/src/main.ml b/src/main.ml index d9b804129873fa285e5bd4ff8bc9ce7626a274b5..906e9712ce9d3c6faf8d1cf8fdd1b2db6a47f708 100644 --- a/src/main.ml +++ b/src/main.ml @@ -265,7 +265,7 @@ let verify_cmd = Arg.(value & opt (some file) None & info [ "dataset" ] ~doc ~docv:"FILE") in let onnx_out_dir = - let doc = "Write NIER as ONNX file in $(docv)." in + let doc = "Write NIR as ONNX file in $(docv)." in Arg.( value & opt (some string) None diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 322bcf3074aebc4231ff3d2f53ad54466fc4b965..22a140e58791d436dedef6a2b50b5bd9977eef29 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -46,7 +46,7 @@ let tempfile = | None -> Stdlib.Filename.temp_file "caisar" ".onnx" let create_new_nn env input_vars outputs : string = - let module IR = Ir.Nier_simple in + let module IR = Nir in let th_f64 = Symbols.Float64.create env in let th_model = Symbols.Model.create env in let th_nn = Symbols.NN.create env in @@ -58,7 +58,7 @@ let create_new_nn env input_vars outputs : string = let get_input = Why3.Term.Hls.memo 10 (fun ls -> let i = Why3.Term.Mls.find_exn UnknownLogicSymbol ls input_vars in - Ir.Nier_simple.Node.gather_int input i) + IR.Node.gather_int input i) in let cache = Why3.Term.Hterm.create 17 in let nn_cache = Stdlib.Hashtbl.create 17 in @@ -68,7 +68,7 @@ let create_new_nn env input_vars outputs : string = let converted_args = List.map ~f:convert_term old_nn_args in let id = ( old_nn.Language.nn_filename, - List.map converted_args ~f:(fun n -> n.Ir.Nier_simple.id) ) + List.map converted_args ~f:(fun n -> n.Nir.Node.id) ) in match Stdlib.Hashtbl.find_opt nn_cache id with | None -> @@ -77,15 +77,15 @@ let create_new_nn env input_vars outputs : string = node_nn | Some node_nn -> node_nn and convert_old_nn_aux old_nn converted_args = - let old_nn_nier = - match Onnx.Simple.parse old_nn.Language.nn_filename with + let old_nn_nir = + match Onnx.Reader.from_file old_nn.Language.nn_filename with | Error s -> Logging.code_error ~src (fun m -> m "No parsed NN '%s': %s" old_nn.nn_filename s) - | Ok { nier = Error s; _ } -> + | Ok { nir = Error s; _ } -> Logging.code_error ~src (fun m -> - m "No NIER available for NN '%s': %s" old_nn.nn_filename s) - | Ok { nier = Ok g; _ } -> g + m "No NIR available for NN '%s': %s" old_nn.nn_filename s) + | Ok { nir = Ok g; _ } -> g in (* Create 
the graph to replace the old input of the nn *) let input () = @@ -93,18 +93,17 @@ let create_new_nn env input_vars outputs : string = let node = IR.Node.create (Concat { inputs = converted_args; axis = 0 }) in - Ir.Nier_simple.Node.reshape (IR.input_shape old_nn_nier) node + IR.Node.reshape (IR.Ngraph.input_shape old_nn_nir) node in let out = - IR.Node.replace_input input (IR.output old_nn_nier) - |> Ir.Nier_simple.Node.reshape - (Ir.Nier_simple.Shape.of_array [| old_nn.nn_nb_outputs |]) + IR.Node.replace_input input (IR.Ngraph.output old_nn_nir) + |> IR.Node.reshape (Nir.Shape.of_array [| old_nn.nn_nb_outputs |]) in out and convert_old_nn_at_old_index old_nn old_index old_nn_args = let out = convert_old_nn old_nn old_nn_args in let old_index = Why3.Number.to_small_integer old_index in - Ir.Nier_simple.Node.gather_int out old_index + Nir.Node.gather_int out old_index and convert_term term = match Why3.Term.Hterm.find_opt cache term with | None -> @@ -112,7 +111,7 @@ let create_new_nn env input_vars outputs : string = Why3.Term.Hterm.add cache term n; n | Some n -> n - and convert_term_aux term : IR.GFloat.Node.t = + and convert_term_aux term : IR.Node.t = if not (Why3.Ty.ty_equal (Option.value_exn term.Why3.Term.t_ty) th_f64.ty) then Logging.user_error ?loc:term.t_loc (fun m -> @@ -122,7 +121,7 @@ let create_new_nn env input_vars outputs : string = IR.Node.create (Constant { - data = IR.GenTensor.create_1_float (Utils.float_of_real_constant r); + data = IR.Gentensor.create_1_float (Utils.float_of_real_constant r); }) | Tapp (ls, []) -> get_input ls | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.add -> @@ -135,12 +134,12 @@ let create_new_nn env input_vars outputs : string = match b.t_node with | Tconst (Why3.Constant.ConstReal r) -> let f = Utils.float_of_real_constant r in - Ir.Nier_simple.Node.div_float (convert_term a) f + Nir.Node.div_float (convert_term a) f | _ -> IR.Node.create (Div { input1 = convert_term a; input2 = convert_term b })) | Tapp (ls, [ a ]) when Why3.Term.ls_equal ls th_f64.neg -> - Ir.Nier_simple.Node.mul_float (convert_term a) (-1.) + Nir.Node.mul_float (convert_term a) (-1.) | Tapp ( ls_get (* [ ] *), [ @@ -195,11 +194,11 @@ let create_new_nn env input_vars outputs : string = let output = IR.Node.concat_0 outputs in assert ( IR.Shape.equal output.shape (IR.Shape.of_array [| List.length outputs |])); - let nn = IR.create output in - Logs.debug ~src:src_show (fun f -> f "@.%s@." (IR.grapheasy nn)); - Logs.debug ~src (fun f -> f "@[<v>%a@]@." IR.pp_debug nn); + let nn = IR.Ngraph.create output in + Logs.debug ~src:src_show (fun f -> f "@.%s@." (IR.Ngraph.grapheasy nn)); + Logs.debug ~src (fun f -> f "@[<v>%a@]@." IR.Ngraph.pp_debug nn); let filename = tempfile () in - Onnx.Simple.write nn filename; + Onnx.Writer.to_file nn filename; filename (* Choose the term pattern for starting the conversion to ONNX. 
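In outline: [create_new_nn] converts each Why3 output term to a NIR node,
   memoized per term in [cache] and per embedded network application in
   [nn_cache]; the scalar outputs are concatenated with [concat_0], the
   resulting shape is checked, and the graph is written to a fresh temporary
   ONNX file with [Onnx.Writer.to_file], whose name is returned for the
   prover.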
*) diff --git a/src/transformations/nn2smt.ml b/src/transformations/nn2smt.ml index 56f24a1324b4d7077d5de3e414b7d134f98959f1..03cba10ecf8a45c67332315556c3a6435d1e9626 100644 --- a/src/transformations/nn2smt.ml +++ b/src/transformations/nn2smt.ml @@ -21,7 +21,7 @@ (**************************************************************************) open Base -module IR = Ir.Nier_simple +module IR = Nir.Ngraph let src = Logs.Src.create "NN2SMT" ~doc:"Encoding of neural networks into SMT-LIB" @@ -49,13 +49,13 @@ let relu env expr = in Why3.Term.t_app_infer relu_s [ expr ] -let rec terms_of_nier m d node index ty_inputs env input_vars input_terms = +let rec terms_of_nir m d node index ty_inputs env input_vars input_terms = let mi = Option.value ~default:(Map.empty (module Int)) @@ Map.find !m node in match Map.find mi index with | Some ls -> ls | None -> let ls, decl = - terms_of_nier_aux m d node index ty_inputs env input_vars input_terms + terms_of_nir_aux m d node index ty_inputs env input_vars input_terms in Queue.enqueue d (Why3.Theory.create_decl decl); Queue.enqueue d @@ -64,14 +64,15 @@ let rec terms_of_nier m d node index ty_inputs env input_vars input_terms = m := Map.set !m ~key:node ~data:mi; ls -and t_app_of_nier m d node index ty_inputs env input_vars input_terms = - let ls = terms_of_nier m d node index ty_inputs env input_vars input_terms in +and t_app_of_nir m d node index ty_inputs env input_vars input_terms = + let ls = terms_of_nir m d node index ty_inputs env input_vars input_terms in Why3.Term.fs_app ls (if List.is_empty input_vars then [] else input_terms) ty_inputs -and terms_of_nier_aux m d node index ty_inputs env input_vars input_terms = - let preid = Why3.Ident.id_fresh (Fmt.str "n%i_%i" node.IR.id index) in +and terms_of_nir_aux m d (node : Nir.Node.t) index ty_inputs env input_vars + input_terms = + let preid = Why3.Ident.id_fresh (Fmt.str "n%i_%i" node.id index) in let ls = Why3.Term.create_fsymbol preid (List.map input_vars ~f:(fun _ -> ty_inputs)) @@ -82,13 +83,13 @@ and terms_of_nier_aux m d node index ty_inputs env input_vars input_terms = Paxiom ps t *) match node.descr with | Constant { data = Float ba } -> - let id = IR.Shape.unrow_major node.shape index in - Utils.term_of_float env (IR.Tensor.get ba id) + let id = Nir.Shape.unrow_major node.shape index in + Utils.term_of_float env (Nir.Tensor.get ba id) | Matmul { input1; input2 } -> (* [|... 
; _; n |], [| ...; n; _ |] *) - let id = IR.Shape.unrow_major node.shape index in + let id = Nir.Shape.unrow_major node.shape index in let broadcast shape id_dst = - let len = IR.Shape.rank shape in + let len = Nir.Shape.rank shape in let id_src = Array.create ~len 0 in for i = 0 to len - 3 do id_src.(i) <- id_dst.(i) @@ -101,18 +102,18 @@ and terms_of_nier_aux m d node index ty_inputs env input_vars input_terms = id2.(Array.length id2 - 1) <- id.(Array.length id - 1); let acc = Sequence.init - (IR.Shape.get input1.shape (IR.Shape.rank input1.shape - 1)) + (Nir.Shape.get input1.shape (Nir.Shape.rank input1.shape - 1)) ~f:(fun i -> (* id1 = [...; _; i ] id2 = [...; i; _ ] *) id1.(Array.length id1 - 1) <- i; id2.(Array.length id2 - 2) <- i; - let i1 = IR.Shape.row_major input1.shape id1 in - let i2 = IR.Shape.row_major input2.shape id2 in + let i1 = Nir.Shape.row_major input1.shape id1 in + let i2 = Nir.Shape.row_major input2.shape id2 in let a1 = - t_app_of_nier m d input1 i1 ty_inputs env input_vars input_terms + t_app_of_nir m d input1 i1 ty_inputs env input_vars input_terms in let a2 = - t_app_of_nier m d input2 i2 ty_inputs env input_vars input_terms + t_app_of_nir m d input2 i2 ty_inputs env input_vars input_terms in mul a1 a2 env) in @@ -121,29 +122,29 @@ and terms_of_nier_aux m d node index ty_inputs env input_vars input_terms = | Input _ -> List.nth_exn input_terms index | Add { input1; input2 } -> let t1 = - t_app_of_nier m d input1 index ty_inputs env input_vars input_terms + t_app_of_nir m d input1 index ty_inputs env input_vars input_terms in let t2 = - t_app_of_nier m d input2 index ty_inputs env input_vars input_terms + t_app_of_nir m d input2 index ty_inputs env input_vars input_terms in sum t1 t2 env | Div { input1; input2 } -> let t1 = - t_app_of_nier m d input1 index ty_inputs env input_vars input_terms + t_app_of_nir m d input1 index ty_inputs env input_vars input_terms in let t2 = - t_app_of_nier m d input2 index ty_inputs env input_vars input_terms + t_app_of_nir m d input2 index ty_inputs env input_vars input_terms in div t1 t2 env | ReLu { input } -> let t = - t_app_of_nier m d input index ty_inputs env input_vars input_terms + t_app_of_nir m d input index ty_inputs env input_vars input_terms in relu env t | _ -> Logging.not_implemented_yet (fun f -> f "ONNX operator %a is not implemented for nn2smt transformation" - IR.Node.pp node) + Nir.Node.pp node) in ( ls, Why3.Decl.create_logic_decl [ Why3.Decl.make_ls_defn ls input_vars v_term ] @@ -159,20 +160,18 @@ module MTermL = Why3.Extmap.Make (struct type t = T.t list [@@deriving ord] end) -let app_terms_of_nier_output m d (nn : Language.nn) env index tl = +let app_terms_of_nir_output m d (nn : Language.nn) env index tl = match nn.nn_format with | NNet -> Logging.not_implemented_yet (fun f -> f "NNet to SMT conversion") | ONNX None -> Logging.code_error ~src (fun f -> f "No ONNX to convert") | ONNX (Some g) -> let vtl = List.fold tl ~init:Why3.Term.Mvs.empty ~f:Why3.Term.t_freevars in - let m' = ref (MTermL.find_def (Map.empty (module IR.Node)) tl !m) in + let m' = ref (MTermL.find_def (Map.empty (module Nir.Node)) tl !m) in let t = if Why3.Term.Mvs.is_empty vtl then (* use global constant *) - let ls = - terms_of_nier m' d (IR.output g) index nn.nn_ty_elt env [] tl - in + let ls = terms_of_nir m' d (IR.output g) index nn.nn_ty_elt env [] tl in Why3.Term.fs_app ls [] nn.nn_ty_elt else (* use global functions *) @@ -186,7 +185,7 @@ let app_terms_of_nier_output m d (nn : Language.nn) env index tl = in let input_terms = 
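(* Reviewer note: the recursion above gives every (node, output index) pair
   its own fresh function symbol over the input variables, memoized in [m],
   together with a defining equation enqueued in [d]. Constants become float
   literals, Add/Div map to [sum]/[div] on the operand scalars, ReLu applies
   the [relu] symbol, and Matmul unrolls into a sum of products over the
   shared dimension. Schematically, with hypothetical node ids:

     n42_0 x0 ... xk =
       (n40_0 x0 ... xk) * (n41_0 x0 ... xk)
       + (n40_1 x0 ... xk) * (n41_2 x0 ... xk)
*)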
List.map input_vars ~f:Why3.Term.t_var in
 let ls =
- terms_of_nier m' d (IR.output g) index nn.nn_ty_elt env input_vars
+ terms_of_nir m' d (IR.output g) index nn.nn_ty_elt env input_vars
 input_terms
 in
 Why3.Term.fs_app ls tl nn.nn_ty_elt
@@ -224,7 +223,7 @@ let actual_nn_flow env =
 match (Language.lookup_nn ls_nn, Language.lookup_vector ls) with
 | Some ({ nn_nb_inputs; _ } as nn), Some vector_length ->
 assert (nn_nb_inputs = vector_length && vector_length = List.length tl);
- app_terms_of_nier_output m d nn env
+ app_terms_of_nir_output m d nn env
 (Why3.Number.to_small_integer index)
 tl
 | _, _ ->
diff --git a/src/transformations/nn2smt.mli b/src/transformations/nn2smt.mli
index 6f92b5a3a81702912bf462acecc37ece5385ebe3..0cb0eb9b33800dc7fbb3ffd5ff858c752b60dfa6 100644
--- a/src/transformations/nn2smt.mli
+++ b/src/transformations/nn2smt.mli
@@ -20,14 +20,14 @@
 (* *)
 (**************************************************************************)
-(** This module converts a valid NIER into WhyML terms.
+(** This module converts a valid NIR into WhyML terms.
- NIER encapsulate parameters in tensor forms and a computation graph with
+ NIR encapsulates parameters in tensor forms and a computation graph with
 various operations. WhyML language supports multidimensional arrays as well.
- This module translates NIER data into a list of WhyML terms, describing the
+ This module translates NIR data into a list of WhyML terms, describing the
 control flow of the neural network. Variables are stored inside of an
- environment, their shape being either provided by the NIER or inferred with
+ environment, their shape being either provided by the NIR or inferred from
 the expected result of ONNX operations. *)
 val trans : Why3.Env.env -> Why3.Task.task Why3.Trans.trans
diff --git a/src/verification.ml b/src/verification.ml
index e63540bffa5fbf6bcdb2ad4b413fcfb0720c1589..96bd03005badf2f6d4825dbd5202534abdb546e8 100644
--- a/src/verification.ml
+++ b/src/verification.ml
@@ -90,27 +90,27 @@ let create_env loadpath =
 (loadpath @ stdlib @ Whyconf.loadpath (Whyconf.get_main config)),
 config )
-let write_nier_as_onnx onnx_out_dir =
+let write_nir_as_onnx onnx_out_dir =
 Language.iter_nn (fun ls nn ->
 match nn.nn_format with
- | ONNX (Some nn_nier) -> (
+ | ONNX (Some nn_nir) -> (
 try
 if not (Stdlib.Sys.file_exists onnx_out_dir)
 then Stdlib.Sys.mkdir onnx_out_dir 0o755;
 let filename =
- Fmt.str "%s%s%a.nier.onnx" onnx_out_dir Stdlib.Filename.dir_sep
+ Fmt.str "%s%s%a.nir.onnx" onnx_out_dir Stdlib.Filename.dir_sep
 Pretty.print_ls ls
 in
- Onnx.Simple.write nn_nier filename;
- Logs.debug ~src:Logging.src_nier (fun m ->
- m "@[Wrote NIER as ONNX in file '%s'@]" filename)
+ Onnx.Writer.to_file nn_nir filename;
+ Logs.debug ~src:Logging.src_nir (fun m ->
+ m "@[Wrote NIR as ONNX in file '%s'@]" filename)
 with Sys_error msg ->
 Logging.user_error (fun m ->
- m "@[Cannot write NIER as ONNX in folder '%s': '%s'@]" onnx_out_dir
- msg))
+ m "@[Cannot write NIR as ONNX in folder '%s': '%s'@]" onnx_out_dir msg)
+ )
 | _ ->
 Logs.warn (fun m ->
- m "@[No available NIER to write as ONNX for logic symbol '%a'@]"
+ m "@[No available NIR to write as ONNX for logic symbol '%a'@]"
 Pretty.print_ls ls))
 let answer_saver_on_dataset limit config env config_prover ~dataset task =
@@ -488,5 +488,5 @@ let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern
 tasks)
 mstr_theory
 in
- Option.iter ~f:write_nier_as_onnx onnx_out_dir;
+ Option.iter ~f:write_nir_as_onnx onnx_out_dir;
 verification_result
diff --git 
a/src/verification.mli b/src/verification.mli index 935f619fade3bd23f0cf2704bee4150fc0b61551..71d1b7839351e3a3c4c76a7f1be0d78482f87e89 100644 --- a/src/verification.mli +++ b/src/verification.mli @@ -76,7 +76,7 @@ val verify : is a theory:goals list each identifying only the goals of a theory to verify. @param onnx_out_dir - is the directory in which to write the ONNX files generated from the NIER. + is the directory in which to write the ONNX files generated from the NIR. @return for each theory, an [answer] for each goal of the theory, respecting the order of appearance. *) diff --git a/tests/acasxu.t b/tests/acasxu.t index 0d19a6594cbffbba1f011a43458c1e182c8ba949..06b8c65a4b99e2268ee7635724e3034f50037634 100644 --- a/tests/acasxu.t +++ b/tests/acasxu.t @@ -205,7 +205,7 @@ Test verify on acasxu nn_ty_elt = t; nn_filename = "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx"; - nn_format = <nier> })) + nn_format = <nir> })) vector, 5 [DEBUG]{NN-Gen-Term} new goal:le y (1500.0:t) @@ -269,14 +269,14 @@ Test verify on acasxu nn_ty_elt = t; nn_filename = "./../examples/acasxu/nets/onnx/ACASXU_1_9.onnx"; - nn_format = <nier> })) + nn_format = <nir> })) nn_onnx, (Interpreter_types.Model (Interpreter_types.ONNX, { Language.nn_nb_inputs = 5; nn_nb_outputs = 5; nn_ty_elt = t; nn_filename = "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx"; - nn_format = <nier> })) + nn_format = <nir> })) vector, 5 [DEBUG]{NN-Gen-Term} new goal:le y1 (1500.0:t) /\ le y2 (1500.0:t) @@ -350,7 +350,7 @@ Test verify on acasxu nn_ty_elt = t; nn_filename = "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx"; - nn_format = <nier> })) + nn_format = <nir> })) vector, 5 [DEBUG]{NN-Gen-Term} new goal:le y3 y4 /\ le y4 y3 @@ -426,7 +426,7 @@ Test verify on acasxu nn_ty_elt = t; nn_filename = "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx"; - nn_format = <nier> })) + nn_format = <nir> })) vector, 5 [DEBUG]{NN-Gen-Term} new goal:le y5 y6 /\ le y6 y5 @@ -504,7 +504,7 @@ Test verify on acasxu nn_ty_elt = t; nn_filename = "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx"; - nn_format = <nier> })) + nn_format = <nir> })) vector, 5 [DEBUG]{NN-Gen-Term} new goal:le y7 y8 /\ le y8 y7 @@ -586,7 +586,7 @@ Test verify on acasxu nn_ty_elt = t; nn_filename = "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx"; - nn_format = <nier> })) + nn_format = <nir> })) vector, 5 [DEBUG]{NN-Gen-Term} new goal:le (0.0:t) y9 /\ le y9 (0.0:t) diff --git a/tests/acasxu_ci.t b/tests/acasxu_ci.t index c55a2f7781913e9a97a0e8cc1562040333181606..b9b6cb3c669275ed2e370bf90f2a2fa7fe903e62 100644 --- a/tests/acasxu_ci.t +++ b/tests/acasxu_ci.t @@ -166,7 +166,7 @@ Test verify on acasxu nn_ty_elt = t; nn_filename = "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx"; - nn_format = <nier> })) + nn_format = <nir> })) vector, 5 [DEBUG]{ProverSpec} Prover-tailored specification: @@ -229,7 +229,7 @@ Test verify on acasxu nn_ty_elt = t; nn_filename = "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx"; - nn_format = <nier> })) + nn_format = <nir> })) vector, 5 nn_onnx1, (Interpreter_types.Model @@ -237,7 +237,7 @@ Test verify on acasxu nn_ty_elt = t; nn_filename = "./../examples/acasxu/nets/onnx/ACASXU_1_9.onnx"; - nn_format = <nier> })) + nn_format = <nir> })) [DEBUG]{ProverSpec} Prover-tailored specification: 0.0 <= x0 x0 <= 60760.0 @@ -308,7 +308,7 @@ Test verify on acasxu nn_ty_elt = t; nn_filename = "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx"; - nn_format = <nier> })) + nn_format = <nir> })) vector, 5 [DEBUG]{ProverSpec} Prover-tailored specification: @@ -382,7 +382,7 @@ Test 
verify on acasxu nn_ty_elt = t; nn_filename = "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx"; - nn_format = <nier> })) + nn_format = <nir> })) vector, 5 [DEBUG]{ProverSpec} Prover-tailored specification: @@ -498,7 +498,7 @@ Test verify on acasxu nn_ty_elt = t; nn_filename = "./../examples/acasxu/nets/onnx/ACASXU_1_1.onnx"; - nn_format = <nier> })) + nn_format = <nir> })) vector, 5 [DEBUG]{ProverSpec} Prover-tailored specification: diff --git a/tests/nier_to_onnx.t b/tests/nier_to_onnx.t index d77897c5db91be6e61a94e3b10273908ed802116..9bbd9c34302db2729aa58e6cd47c225b3452ef42 100644 --- a/tests/nier_to_onnx.t +++ b/tests/nier_to_onnx.t @@ -4,8 +4,8 @@ Test verify $ bin/pyrat.py --version PyRAT 1.1 - $ caisar verify --format whyml --prover=PyRAT --ltag=NIER --onnx-out-dir="out" - 2>&1 <<EOF - > theory NIER_to_ONNX + $ caisar verify --format whyml --prover=PyRAT --ltag=NIR --onnx-out-dir="out" - 2>&1 <<EOF + > theory NIR_to_ONNX > use ieee_float.Float64 > use caisar.types.Vector > use caisar.model.Model @@ -20,12 +20,12 @@ Test verify > (0.5:t) .< (nn @@ i)[0] .< (0.5:t) > end > EOF - [DEBUG]{NIER} Wrote NIER as ONNX in file 'out/nn_onnx.nier.onnx' + [DEBUG]{NIR} Wrote NIR as ONNX in file 'out/nn_onnx.nir.onnx' Goal G: Unknown () Data should be 0.135 $ python3 bin/inspect_onnx.py - out/nn_onnx.nier.onnx has 1 input nodes + out/nn_onnx.nir.onnx has 1 input nodes {'name': '0', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '1'}, {'dimValue': '1'}, {'dimValue': '1'}, {'dimValue': '3'}]}}}} 1 files checked diff --git a/utils/logging.ml b/utils/logging.ml index afe6c3d5132e8bebfc474208afec04f9ed67bbc6..c3f129f99bef70369aa905a3a3d200b2c3d9b00e 100644 --- a/utils/logging.ml +++ b/utils/logging.ml @@ -30,7 +30,7 @@ let src_prover_call = Logs.Src.create "ProverCall" ~doc:"Prover call" let src_interpret_goal = Logs.Src.create "InterpretGoal" ~doc:"Goal interpretation" -let src_nier = Logs.Src.create "NIER" ~doc:"Neural Intermediate Representation" +let src_nir = Logs.Src.create "NIR" ~doc:"Neural Intermediate Representation" let src_stack_trace = Logs.Src.create "StackTrace" ~doc:"Print stack trace" let all_srcs () = Logs.Src.list () diff --git a/utils/logging.mli b/utils/logging.mli index 6410b1a5911380629712fbc874a397d3808b97f6..d05c477a90ee8739fc6d08a980bd7a06b68feb3d 100644 --- a/utils/logging.mli +++ b/utils/logging.mli @@ -26,7 +26,7 @@ val src_autodetect : Logs.src val src_prover_spec : Logs.src val src_prover_call : Logs.src val src_interpret_goal : Logs.src -val src_nier : Logs.src +val src_nir : Logs.src val src_stack_trace : Logs.src val all_srcs : unit -> Logs.src list
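Reviewer note: taken together, the patch renames NIER to NIR everywhere and
splits the old Onnx.Simple module into Onnx.Reader and Onnx.Writer. A minimal
round-trip sketch against the new public API (module paths as introduced by
this patch; error handling mirrors lib/onnx/tests/print.ml):

    let round_trip src dst =
      match Onnx.Reader.from_file src with
      | Error msg -> prerr_endline msg
      | Ok { Onnx.Reader.nir = Error msg; _ } -> prerr_endline msg
      | Ok { Onnx.Reader.nir = Ok g; _ } ->
        (* Re-serialize the in-memory NIR graph as a fresh ONNX file. *)
        Onnx.Writer.to_file g dst

One loose end: tests/nier_to_onnx.t keeps its NIER-era file name even though
its contents now use the NIR log tag and the out/nn_onnx.nir.onnx path.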