Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
caisar
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pub
caisar
Commits
b26b9986
Commit
b26b9986
authored
2 years ago
by
Julien Girard-Satabin
Browse files
Options
Downloads
Patches
Plain Diff
[NIER] Added utils to NIER: infer matrix size in both forward and backward pass.
parent
5947b05e
No related branches found
Branches containing commit
No related tags found
Tags containing commit
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
lib/ir/nier_cfg.ml
+97
-1
97 additions, 1 deletion
lib/ir/nier_cfg.ml
lib/ir/nier_cfg.mli
+7
-0
7 additions, 0 deletions
lib/ir/nier_cfg.mli
with
104 additions
and
1 deletion
lib/ir/nier_cfg.ml
+
97
−
1
View file @
b26b9986
...
@@ -101,7 +101,8 @@ module Tensor = struct
...
@@ -101,7 +101,8 @@ module Tensor = struct
~
init
:
0
(
Array
.
to_list
idx
)
factors
~
init
:
0
(
Array
.
to_list
idx
)
factors
with
with
|
List
.
Or_unequal_lengths
.
Ok
i
->
i
|
List
.
Or_unequal_lengths
.
Ok
i
->
i
|
List
.
Or_unequal_lengths
.
Unequal_lengths
->
failwith
"Unequal lengths"
|
List
.
Or_unequal_lengths
.
Unequal_lengths
->
failwith
"Unequal lengths in get_flatnd_idx"
in
in
List
.
nth_exn
flt
coord_in_data
List
.
nth_exn
flt
coord_in_data
...
@@ -347,6 +348,101 @@ module NierCFG (I : VInput) = struct
...
@@ -347,6 +348,101 @@ module NierCFG (I : VInput) = struct
let
data_node_of
n
g
=
let
data_node_of
n
g
=
fold_pred
(
fun
v
_
->
if
Node
.
is_data_node
v
then
Some
v
else
None
)
g
n
None
fold_pred
(
fun
v
_
->
if
Node
.
is_data_node
v
then
Some
v
else
None
)
g
n
None
let
infer_shape
g
n
in_shape
~
on_backward
=
let
op
=
Node
.
get_op
n
in
match
op
with
|
Node
.
Add
->
(
match
data_node_of
n
g
with
|
Some
d_n
->
Node
.
get_shape
d_n
|
None
->
failwith
"Error, Add operator lacks a data node"
)
|
Node
.
ReLu
->
in_shape
|
Node
.
Matmul
->
let
pad_left
=
function
|
[]
->
failwith
"Impossible to pad empty shape"
|
[
a
]
->
[
1
;
a
]
|
x
->
x
in
let
pad_right
=
function
|
[]
->
failwith
"Impossible to pad empty shape"
|
[
a
]
->
[
a
;
1
]
|
x
->
x
in
let
rec
one_padding
l
i
=
if
i
<=
0
then
l
else
one_padding
(
1
::
l
)
(
i
-
1
)
in
let
dn_shape
=
match
data_node_of
n
g
with
|
Some
dn
->
Node
.
get_shape
dn
|
None
->
failwith
"Error, MatMul operator lacks a data node"
in
(* Expected semantic:
* Matrix multiplication C = AB
* A (shape [n;m]); B (shape [m;p]); C (shape [n;p])
* shape of b: b_sh
* shape of a: a_sh
* shape of c: c_sh
* It is expected here that B is the shape of the node
* yielding the data tensor in the NIER
*)
let
check_matmul_size_ba
~
b_sh
~
a_sh
=
let
bdim2
=
pad_left
b_sh
in
let
adim2
=
pad_right
a_sh
in
let
bdim
=
one_padding
bdim2
(
List
.
length
adim2
-
List
.
length
bdim2
)
in
let
adim
=
one_padding
adim2
(
List
.
length
bdim2
-
List
.
length
adim2
)
in
let
rec
infer_csize
acc
ad
bd
=
match
(
ad
,
bd
)
with
|
[
m
;
n
]
,
[
nn
;
p
]
->
if
nn
=
n
then
(
n
,
List
.
append
(
List
.
rev
acc
)
[
m
;
p
])
else
failwith
"size of matrices not adequate"
|
a
::
la
,
b
::
lb
->
if
a
=
b
then
infer_csize
(
a
::
acc
)
la
lb
else
if
a
=
1
then
infer_csize
(
b
::
acc
)
la
lb
else
if
b
=
1
then
infer_csize
(
a
::
acc
)
la
lb
else
failwith
"Checking matmul_size failed: one discordance"
|
_
,
_
->
failwith
"Checking matmul_size failed"
in
infer_csize
[]
bdim
adim
in
let
check_matmul_size_bc
~
b_sh
~
c_sh
=
let
bdim2
=
pad_left
b_sh
in
let
cdim2
=
pad_right
c_sh
in
let
bdim
=
one_padding
bdim2
(
List
.
length
cdim2
-
List
.
length
bdim2
)
in
let
cdim
=
one_padding
cdim2
(
List
.
length
bdim2
-
List
.
length
cdim2
)
in
let
rec
infer_asize
acc
bd
cd
=
match
(
bd
,
cd
)
with
|
[
m
;
p
]
,
[
n
;
pp
]
->
if
pp
=
p
then
(
n
,
List
.
append
(
List
.
rev
acc
)
[
n
;
m
])
else
failwith
"size of matrices not adequate"
|
b
::
lb
,
c
::
lc
->
if
b
=
c
then
infer_asize
(
b
::
acc
)
lb
lc
else
if
b
=
1
then
infer_asize
(
b
::
acc
)
lb
lc
else
if
c
=
1
then
infer_asize
(
c
::
acc
)
lb
lc
else
failwith
"Checking matmul_size failed: one discordance"
|
_
,
_
->
failwith
"Checking matmul_size failed"
in
infer_asize
[]
bdim
cdim
in
if
on_backward
then
Array
.
of_list
@@
snd
(
check_matmul_size_bc
~
b_sh
:
(
Array
.
to_list
dn_shape
)
~
c_sh
:
(
Array
.
to_list
in_shape
))
else
Array
.
of_list
@@
snd
(
check_matmul_size_ba
~
b_sh
:
(
Array
.
to_list
in_shape
)
~
a_sh
:
(
Array
.
to_list
dn_shape
))
|
a
->
failwith
(
Printf
.
sprintf
"operator %s not supported"
(
Node
.
str_op
a
))
end
end
module
NierCFGInt
=
NierCFG
(
struct
module
NierCFGInt
=
NierCFG
(
struct
...
...
This diff is collapsed.
Click to expand it.
lib/ir/nier_cfg.mli
+
7
−
0
View file @
b26b9986
...
@@ -252,6 +252,13 @@ module NierCFGFloat : sig
...
@@ -252,6 +252,13 @@ module NierCFGFloat : sig
predecessors of [n]*)
predecessors of [n]*)
val
data_node_of
:
vertex
->
t
->
vertex
option
val
data_node_of
:
vertex
->
t
->
vertex
option
(** [infer_shape g n sh o_b] returns the inferred shape of the output of node
[n] in NIER [g] with input shape [sh]. Shape inference is made using the
node operator and its predecessors shapes. [o_b] is true when performing
backward propagation, to choose which matrix size to consider. *)
val
infer_shape
:
t
->
vertex
->
Node
.
shape
->
on_backward
:
bool
->
Node
.
shape
end
end
(** {1 Pretty printers} *)
(** {1 Pretty printers} *)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment