Skip to content
Snippets Groups Projects
Commit b26b9986 authored by Julien Girard-Satabin's avatar Julien Girard-Satabin
Browse files

[NIER] Added utils to NIER: infer matrix size in both forward and backward pass.

parent 5947b05e
No related branches found
No related tags found
No related merge requests found
...@@ -101,7 +101,8 @@ module Tensor = struct ...@@ -101,7 +101,8 @@ module Tensor = struct
~init:0 (Array.to_list idx) factors ~init:0 (Array.to_list idx) factors
with with
| List.Or_unequal_lengths.Ok i -> i | List.Or_unequal_lengths.Ok i -> i
| List.Or_unequal_lengths.Unequal_lengths -> failwith "Unequal lengths" | List.Or_unequal_lengths.Unequal_lengths ->
failwith "Unequal lengths in get_flatnd_idx"
in in
List.nth_exn flt coord_in_data List.nth_exn flt coord_in_data
...@@ -347,6 +348,101 @@ module NierCFG (I : VInput) = struct ...@@ -347,6 +348,101 @@ module NierCFG (I : VInput) = struct
let data_node_of n g = let data_node_of n g =
fold_pred (fun v _ -> if Node.is_data_node v then Some v else None) g n None fold_pred (fun v _ -> if Node.is_data_node v then Some v else None) g n None
let infer_shape g n in_shape ~on_backward =
let op = Node.get_op n in
match op with
| Node.Add -> (
match data_node_of n g with
| Some d_n -> Node.get_shape d_n
| None -> failwith "Error, Add operator lacks a data node")
| Node.ReLu -> in_shape
| Node.Matmul ->
let pad_left = function
| [] -> failwith "Impossible to pad empty shape"
| [ a ] -> [ 1; a ]
| x -> x
in
let pad_right = function
| [] -> failwith "Impossible to pad empty shape"
| [ a ] -> [ a; 1 ]
| x -> x
in
let rec one_padding l i =
if i <= 0 then l else one_padding (1 :: l) (i - 1)
in
let dn_shape =
match data_node_of n g with
| Some dn -> Node.get_shape dn
| None -> failwith "Error, MatMul operator lacks a data node"
in
(* Expected semantic:
* Matrix multiplication C = AB
* A (shape [n;m]); B (shape [m;p]); C (shape [n;p])
* shape of b: b_sh
* shape of a: a_sh
* shape of c: c_sh
* It is expected here that B is the shape of the node
* yielding the data tensor in the NIER
*)
let check_matmul_size_ba ~b_sh ~a_sh =
let bdim2 = pad_left b_sh in
let adim2 = pad_right a_sh in
let bdim = one_padding bdim2 (List.length adim2 - List.length bdim2) in
let adim = one_padding adim2 (List.length bdim2 - List.length adim2) in
let rec infer_csize acc ad bd =
match (ad, bd) with
| [ m; n ], [ nn; p ] ->
if nn = n
then (n, List.append (List.rev acc) [ m; p ])
else failwith "size of matrices not adequate"
| a :: la, b :: lb ->
if a = b
then infer_csize (a :: acc) la lb
else if a = 1
then infer_csize (b :: acc) la lb
else if b = 1
then infer_csize (a :: acc) la lb
else failwith "Checking matmul_size failed: one discordance"
| _, _ -> failwith "Checking matmul_size failed"
in
infer_csize [] bdim adim
in
let check_matmul_size_bc ~b_sh ~c_sh =
let bdim2 = pad_left b_sh in
let cdim2 = pad_right c_sh in
let bdim = one_padding bdim2 (List.length cdim2 - List.length bdim2) in
let cdim = one_padding cdim2 (List.length bdim2 - List.length cdim2) in
let rec infer_asize acc bd cd =
match (bd, cd) with
| [ m; p ], [ n; pp ] ->
if pp = p
then (n, List.append (List.rev acc) [ n; m ])
else failwith "size of matrices not adequate"
| b :: lb, c :: lc ->
if b = c
then infer_asize (b :: acc) lb lc
else if b = 1
then infer_asize (b :: acc) lb lc
else if c = 1
then infer_asize (c :: acc) lb lc
else failwith "Checking matmul_size failed: one discordance"
| _, _ -> failwith "Checking matmul_size failed"
in
infer_asize [] bdim cdim
in
if on_backward
then
Array.of_list
@@ snd
(check_matmul_size_bc ~b_sh:(Array.to_list dn_shape)
~c_sh:(Array.to_list in_shape))
else
Array.of_list
@@ snd
(check_matmul_size_ba ~b_sh:(Array.to_list in_shape)
~a_sh:(Array.to_list dn_shape))
| a -> failwith (Printf.sprintf "operator %s not supported" (Node.str_op a))
end end
module NierCFGInt = NierCFG (struct module NierCFGInt = NierCFG (struct
......
...@@ -252,6 +252,13 @@ module NierCFGFloat : sig ...@@ -252,6 +252,13 @@ module NierCFGFloat : sig
predecessors of [n]*) predecessors of [n]*)
val data_node_of : vertex -> t -> vertex option val data_node_of : vertex -> t -> vertex option
(** [infer_shape g n sh o_b] returns the inferred shape of the output of node
[n] in NIER [g] with input shape [sh]. Shape inference is made using the
node operator and its predecessors shapes. [o_b] is true when performing
backward propagation, to choose which matrix size to consider. *)
val infer_shape : t -> vertex -> Node.shape -> on_backward:bool -> Node.shape
end end
(** {1 Pretty printers} *) (** {1 Pretty printers} *)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment