From b26b9986022fe2ec228ae5dc41218666b35e3820 Mon Sep 17 00:00:00 2001
From: Julien Girard <julien.girard2@cea.fr>
Date: Thu, 22 Sep 2022 10:51:44 +0200
Subject: [PATCH] [NIER] Added utils to NIER: infer matrix size in both forward and backward pass.

---
 Reviewer note: the line breaks and indentation of this patch were
 reconstructed; if context lines do not match, apply it with
 `git apply --ignore-whitespace`. The backward MatMul inference now keeps the
 output's leading dimension when the data node is broadcast (was: 1).

 lib/ir/nier_cfg.ml  | 98 ++++++++++++++++++++++++++++++++++++++++++++-
 lib/ir/nier_cfg.mli |  7 ++++
 2 files changed, 104 insertions(+), 1 deletion(-)

diff --git a/lib/ir/nier_cfg.ml b/lib/ir/nier_cfg.ml
index 0cbde6b..5402666 100644
--- a/lib/ir/nier_cfg.ml
+++ b/lib/ir/nier_cfg.ml
@@ -101,7 +101,8 @@ module Tensor = struct
           ~init:0 (Array.to_list idx) factors
       with
       | List.Or_unequal_lengths.Ok i -> i
-      | List.Or_unequal_lengths.Unequal_lengths -> failwith "Unequal lengths"
+      | List.Or_unequal_lengths.Unequal_lengths ->
+        failwith "Unequal lengths in get_flatnd_idx"
     in
     List.nth_exn flt coord_in_data
 
@@ -347,6 +348,101 @@ module NierCFG (I : VInput) = struct
 
   let data_node_of n g =
     fold_pred (fun v _ -> if Node.is_data_node v then Some v else None) g n None
+
+  let infer_shape g n in_shape ~on_backward =
+    let op = Node.get_op n in
+    match op with
+    | Node.Add -> (
+      match data_node_of n g with
+      | Some d_n -> Node.get_shape d_n
+      | None -> failwith "Error, Add operator lacks a data node")
+    | Node.ReLu -> in_shape
+    | Node.Matmul ->
+      let pad_left = function
+        | [] -> failwith "Impossible to pad empty shape"
+        | [ a ] -> [ 1; a ]
+        | x -> x
+      in
+      let pad_right = function
+        | [] -> failwith "Impossible to pad empty shape"
+        | [ a ] -> [ a; 1 ]
+        | x -> x
+      in
+      let rec one_padding l i =
+        if i <= 0 then l else one_padding (1 :: l) (i - 1)
+      in
+      let dn_shape =
+        match data_node_of n g with
+        | Some dn -> Node.get_shape dn
+        | None -> failwith "Error, MatMul operator lacks a data node"
+      in
+      (* Expected semantics: matrix multiplication C = AB, with
+       * A of shape [n;m], B of shape [m;p] and C of shape [n;p].
+       * 1-D shapes are padded to 2-D, leading dimensions are broadcast.
+       * The known shape [in_shape] and the data node are used as follows:
+       * - forward ([check_matmul_size_ba]): [b_sh] is [in_shape], the left
+       *   operand; [a_sh] is the data node, the right operand;
+       * - backward ([check_matmul_size_bc]): [b_sh] is the data node, the
+       *   right operand B; [c_sh] is [in_shape], the product C.
+       *)
+      let check_matmul_size_ba ~b_sh ~a_sh =
+        let bdim2 = pad_left b_sh in
+        let adim2 = pad_right a_sh in
+        let bdim = one_padding bdim2 (List.length adim2 - List.length bdim2) in
+        let adim = one_padding adim2 (List.length bdim2 - List.length adim2) in
+        let rec infer_csize acc ad bd =
+          match (ad, bd) with
+          | [ m; n ], [ nn; p ] ->
+            if nn = n
+            then (n, List.append (List.rev acc) [ m; p ])
+            else failwith "size of matrices not adequate"
+          | a :: la, b :: lb ->
+            if a = b
+            then infer_csize (a :: acc) la lb
+            else if a = 1
+            then infer_csize (b :: acc) la lb
+            else if b = 1
+            then infer_csize (a :: acc) la lb
+            else failwith "Checking matmul_size failed: one discordance"
+          | _, _ -> failwith "Checking matmul_size failed"
+        in
+        infer_csize [] bdim adim
+      in
+      let check_matmul_size_bc ~b_sh ~c_sh =
+        let bdim2 = pad_left b_sh in
+        let cdim2 = pad_right c_sh in
+        let bdim = one_padding bdim2 (List.length cdim2 - List.length bdim2) in
+        let cdim = one_padding cdim2 (List.length bdim2 - List.length cdim2) in
+        let rec infer_asize acc bd cd =
+          match (bd, cd) with
+          | [ m; p ], [ n; pp ] ->
+            if pp = p
+            then (n, List.append (List.rev acc) [ n; m ])
+            else failwith "size of matrices not adequate"
+          | b :: lb, c :: lc ->
+            if b = c
+            then infer_asize (b :: acc) lb lc
+            else if b = 1
+            then infer_asize (c :: acc) lb lc (* B broadcast: keep [c] *)
+            else if c = 1
+            then infer_asize (c :: acc) lb lc
+            else failwith "Checking matmul_size failed: one discordance"
+          | _, _ -> failwith "Checking matmul_size failed"
+        in
+        infer_asize [] bdim cdim
+      in
+      if on_backward
+      then
+        Array.of_list
+        @@ snd
+             (check_matmul_size_bc ~b_sh:(Array.to_list dn_shape)
+                ~c_sh:(Array.to_list in_shape))
+      else
+        Array.of_list
+        @@ snd
+             (check_matmul_size_ba ~b_sh:(Array.to_list in_shape)
+                ~a_sh:(Array.to_list dn_shape))
+    | a -> failwith (Printf.sprintf "operator %s not supported" (Node.str_op a))
 end
 
 module NierCFGInt = NierCFG (struct
diff --git a/lib/ir/nier_cfg.mli b/lib/ir/nier_cfg.mli
index cfec49c..6f02318 100644
--- a/lib/ir/nier_cfg.mli
+++ b/lib/ir/nier_cfg.mli
@@ -252,6 +252,13 @@ module NierCFGFloat : sig
       predecessors of [n]*)
 
   val data_node_of : vertex -> t -> vertex option
+
+  (** [infer_shape g n sh ~on_backward] infers a shape for node [n] of NIER [g]
+      from the operator of [n], the shape of its data-node predecessor and the
+      known shape [sh]: the output shape from input [sh] or, when [on_backward]
+      is true, the input shape from output [sh]. Unsupported operators fail. *)
+
+  val infer_shape : t -> vertex -> Node.shape -> on_backward:bool -> Node.shape
 end
 
 (** {1 Pretty printers} *)
-- 
GitLab