From b26b9986022fe2ec228ae5dc41218666b35e3820 Mon Sep 17 00:00:00 2001
From: Julien Girard <julien.girard2@cea.fr>
Date: Thu, 22 Sep 2022 10:51:44 +0200
Subject: [PATCH] [NIER] Added utils to NIER: infer matrix size in both forward
 and backward pass.

---
 lib/ir/nier_cfg.ml  | 98 ++++++++++++++++++++++++++++++++++++++++++++-
 lib/ir/nier_cfg.mli |  7 ++++
 2 files changed, 104 insertions(+), 1 deletion(-)

diff --git a/lib/ir/nier_cfg.ml b/lib/ir/nier_cfg.ml
index 0cbde6b..5402666 100644
--- a/lib/ir/nier_cfg.ml
+++ b/lib/ir/nier_cfg.ml
@@ -101,7 +101,8 @@ module Tensor = struct
           ~init:0 (Array.to_list idx) factors
       with
       | List.Or_unequal_lengths.Ok i -> i
-      | List.Or_unequal_lengths.Unequal_lengths -> failwith "Unequal lengths"
+      | List.Or_unequal_lengths.Unequal_lengths ->
+        failwith "Unequal lengths in get_flatnd_idx"
     in
     List.nth_exn flt coord_in_data
 
@@ -347,6 +348,101 @@ module NierCFG (I : VInput) = struct
 
   let data_node_of n g =
     fold_pred (fun v _ -> if Node.is_data_node v then Some v else None) g n None
+
+  (** [infer_shape g n in_shape ~on_backward] infers a shape for node [n] of
+      [g] from [in_shape] and, for [Add] and [Matmul], the shape of the data
+      node of [n]. [ReLu] keeps [in_shape]; [Add] yields the data node shape.
+      [Matmul] computes [in_shape x data] when [on_backward] is [false]; when
+      it is [true], [in_shape] is the product shape and the left operand
+      shape is recovered.
+      @raise Failure on unsupported operator, missing data node or bad shape. *)
+  let infer_shape g n in_shape ~on_backward =
+    match Node.get_op n with
+    | Node.Add -> (
+      match data_node_of n g with
+      | Some d_n -> Node.get_shape d_n
+      | None -> failwith "Error, Add operator lacks a data node")
+    | Node.ReLu -> in_shape
+    | Node.Matmul ->
+      (* A 1-D left operand is read as a row vector, a 1-D right operand as a
+         column vector; shapes of rank >= 2 are left untouched. *)
+      let pad_left = function
+        | [] -> failwith "Impossible to pad empty shape"
+        | [ a ] -> [ 1; a ]
+        | x -> x
+      in
+      let pad_right = function
+        | [] -> failwith "Impossible to pad empty shape"
+        | [ a ] -> [ a; 1 ]
+        | x -> x
+      in
+      (* Prepends [i] ones to [l]; no-op when [i <= 0]. *)
+      let rec one_padding l i =
+        if i <= 0 then l else one_padding (1 :: l) (i - 1)
+      in
+      (* Brings both shapes to the same rank by prepending ones. *)
+      let same_rank s1 s2 =
+        let r1 = List.length s1 and r2 = List.length s2 in
+        (one_padding s1 (r2 - r1), one_padding s2 (r1 - r2))
+      in
+      (* Shape of the data operand, read from the data node of [n]. *)
+      let dn_shape =
+        match data_node_of n g with
+        | Some dn -> Array.to_list (Node.get_shape dn)
+        | None -> failwith "Error, MatMul operator lacks a data node"
+      in
+      let sh = Array.to_list in_shape in
+      (* Semantic: OUT = LHS x RHS, with LHS of shape [..; n; m], RHS of shape
+         [..; m; p] and OUT of shape [..; n; p]. RHS is always the data node;
+         leading (batch) dimensions are broadcast, a 1 matching anything. *)
+      (* Forward pass: shape of OUT from the shapes of LHS and RHS. *)
+      let forward_shape ~lhs ~rhs =
+        let lhs, rhs = same_rank (pad_left lhs) (pad_right rhs) in
+        let rec go acc l r =
+          match (l, r) with
+          | [ n; m ], [ mm; p ] ->
+            if mm = m
+            then List.rev_append acc [ n; p ]
+            else failwith "size of matrices not adequate"
+          | a :: la, b :: lb ->
+            if a = b
+            then go (a :: acc) la lb
+            else if a = 1
+            then go (b :: acc) la lb
+            else if b = 1
+            then go (a :: acc) la lb
+            else failwith "Checking matmul_size failed: one discordance"
+          | _, _ -> failwith "Checking matmul_size failed"
+        in
+        go [] lhs rhs
+      in
+      (* Backward pass: shape of LHS from the shapes of RHS and OUT. *)
+      let backward_shape ~rhs ~out =
+        (* RHS is padded as in the forward pass so that both passes agree. *)
+        let rhs, out = same_rank (pad_right rhs) (pad_right out) in
+        let rec go acc r o =
+          match (r, o) with
+          | [ m; p ], [ n; pp ] ->
+            if pp = p
+            then List.rev_append acc [ n; m ]
+            else failwith "size of matrices not adequate"
+          | b :: lb, c :: lc ->
+            (* LHS carries the batch dimension of OUT: when RHS has a 1 there,
+               broadcasting LHS against it has to give back [c]. A 1 in OUT
+               is kept as is. *)
+            if b = c || b = 1 || c = 1
+            then go (c :: acc) lb lc
+            else failwith "Checking matmul_size failed: one discordance"
+          | _, _ -> failwith "Checking matmul_size failed"
+        in
+        go [] rhs out
+      in
+      (* [in_shape] is LHS on the forward pass and OUT on the backward pass. *)
+      Array.of_list
+        (if on_backward
+        then backward_shape ~rhs:dn_shape ~out:sh
+        else forward_shape ~lhs:sh ~rhs:dn_shape)
+    | a -> failwith (Printf.sprintf "operator %s not supported" (Node.str_op a))
 end
 
 module NierCFGInt = NierCFG (struct
diff --git a/lib/ir/nier_cfg.mli b/lib/ir/nier_cfg.mli
index cfec49c..6f02318 100644
--- a/lib/ir/nier_cfg.mli
+++ b/lib/ir/nier_cfg.mli
@@ -252,6 +252,13 @@ module NierCFGFloat : sig
       predecessors of [n]*)
 
   val data_node_of : vertex -> t -> vertex option
+
+  (** [infer_shape g n sh ~on_backward] returns the shape inferred for node
+      [n] of NIER [g] from shape [sh], using the operator of [n] and the shape
+      of its data node. For [Matmul], [sh] is the input shape when
+      [on_backward] is [false] and the output shape when it is [true].
+      @raise Failure on unsupported operators or incompatible shapes. *)
+  val infer_shape : t -> vertex -> Node.shape -> on_backward:bool -> Node.shape
 end
 
 (** {1 Pretty printers} *)
-- 
GitLab