Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
caisar
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pub
caisar
Commits
1feb8ad7
Commit
1feb8ad7
authored
11 months ago
by
Aymeric Varasse
Browse files
Options
Downloads
Patches
Plain Diff
[exps] Add onnx and training script for arithmetic
parent
57128258
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
examples/arithmetic/FNN_s42.onnx
+0
-0
0 additions, 0 deletions
examples/arithmetic/FNN_s42.onnx
examples/arithmetic/train.py
+148
-0
148 additions, 0 deletions
examples/arithmetic/train.py
with
148 additions
and
0 deletions
examples/arithmetic/FNN_s42.onnx
0 → 100644
+
0
−
0
View file @
1feb8ad7
File added
This diff is collapsed.
Click to expand it.
examples/arithmetic/train.py
0 → 100644
+
148
−
0
View file @
1feb8ad7
import
os
import
numpy
as
np
import
onnx
import
onnxruntime
as
ort
import
torch
import
torch.onnx
import
torch.optim
as
optim
from
loguru
import
logger
from
torch
import
nn
from
torch.nn
import
functional
as
F
from
torch.utils.data
import
DataLoader
,
Dataset
# Seed drives both the torch RNG and the artifact file names, so a run is
# identified by its seed.
SEED = 42
STATE_PATH = f"FNN_s{SEED}.pth"
ONNX_PATH = f"FNN_s{SEED}.onnx"
# Paths of the .npy arrays holding train/test inputs and targets; resolved
# relative to the current working directory.
INPUT_ARRAY = "data.npy"
TEST_INPUT_ARRAY = "test_data.npy"
TARGET_ARRAY = "target.npy"
TEST_TARGET_ARRAY = "test_target.npy"
torch.manual_seed(SEED)
# Prefer the first CUDA device when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device {device}")
# Training hyper-parameters.
num_epoch = 2
batch_size = 4
class FNN(nn.Module):
    """Small fully-connected regression network: 3 inputs -> 1 output.

    Two hidden layers of width 128 with ReLU activations; the output layer
    is linear (no activation), as expected for regression.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 1)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
class ArithmeticDataset(Dataset):
    """Map-style dataset backed by a pair of .npy files (inputs, targets).

    Both arrays are loaded eagerly at construction time and cast to float32
    so they convert cheaply to torch tensors.
    """

    def __init__(self, input_array, target_array, root_dir):
        self.input_array = self._load_f32(input_array)
        self.target_array = self._load_f32(target_array)
        # NOTE(review): root_dir is stored but never used to resolve the
        # array paths — presumably kept for future use; confirm.
        self.root_dir = root_dir

    @staticmethod
    def _load_f32(path):
        # Loading helper: read the array and force float32.
        return np.load(path).astype(np.float32)

    def __len__(self):
        return self.input_array.shape[0]

    def __getitem__(self, idx):
        return [self.input_array[idx], self.target_array[idx]]
def train(state_dict):
    """Train the FNN on the arithmetic dataset and save its weights.

    Args:
        state_dict: path where the trained model's state dict is written.
    """
    trainset = ArithmeticDataset(
        input_array=INPUT_ARRAY,
        target_array=TARGET_ARRAY,
        root_dir=os.path.dirname(INPUT_ARRAY),
    )
    trainloader = DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=2
    )
    model = FNN().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    for epoch in range(num_epoch):
        running_loss = 0.0
        for i, data in enumerate(trainloader):
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            # Squeeze only the trailing feature dim: a bare .squeeze() would
            # collapse a final batch of size 1 to a 0-d tensor, which then
            # silently broadcasts against the (1,)-shaped labels in MSELoss.
            # (The model output is already on `device`; no extra .to needed.)
            outputs = model(inputs).squeeze(1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # Report the running average every 2000 mini-batches.
            if i % 2000 == 1999:
                logger.info(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
                running_loss = 0.0
    logger.info("Finished training")
    torch.save(model.state_dict(), state_dict)
def test(model_path):
    """Evaluate a saved model on the held-out test set and log the RMSE.

    Args:
        model_path: path to a state dict previously saved by `train`.
    """
    testset = ArithmeticDataset(
        input_array=TEST_INPUT_ARRAY,
        target_array=TEST_TARGET_ARRAY,
        root_dir=os.path.dirname(TEST_INPUT_ARRAY),
    )
    testloader = DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=2
    )
    net = FNN().to(device)
    net.load_state_dict(torch.load(model_path))
    error = 0.0
    # Count actual samples: len(testloader) * batch_size over-counts
    # whenever the last batch is partial.
    total = len(testset)
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            # squeeze(1) (not bare squeeze) so a final batch of size 1 keeps
            # its batch dimension and matches the labels' shape.
            outputs = net(inputs).squeeze(1)
            error += ((outputs - labels) ** 2).sum().item()
    # sqrt(sum of squared errors / count) is the RMSE; the previous message
    # called it "MSE" and hard-coded "10000 test inputs".
    logger.info(
        f"Average RMSE of the network on the {total} test inputs: {np.sqrt(error / total):.3f}"
    )
def export_model(model_path, onnx_path):
    """Export a trained model to ONNX, then sanity-check the export.

    Args:
        model_path: path to the torch state dict to load.
        onnx_path: destination path for the .onnx file.
    """
    model = FNN().to(device)
    model.load_state_dict(torch.load(model_path))
    # Any (1, 3) tensor serves as the tracing example input.
    dummy_input = torch.rand(1, 3, device=device)
    torch.onnx.export(model=model, args=dummy_input, f=onnx_path, export_params=True)
    logger.info("Model exported successfully")
    # Cross-check the exported graph against the torch model.
    test_onnx(model_path, onnx_path)
def test_onnx(model_path, onnx_path):
    """Verify the ONNX export matches the torch model on random inputs.

    Args:
        model_path: path to the torch state dict to load.
        onnx_path: path to the exported .onnx file to validate.
    """
    model = FNN().to(device)
    model.load_state_dict(torch.load(model_path))

    # Structural validation of the exported graph.
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    ort_session = ort.InferenceSession(
        onnx_path, providers=["CPUExecutionProvider"]
    )

    def to_numpy(tensor):
        # Detach first when the tensor tracks gradients, then move to CPU.
        if tensor.requires_grad:
            tensor = tensor.detach()
        return tensor.cpu().numpy()

    input_name = ort_session.get_inputs()[0].name
    for _ in range(10000):
        x = torch.rand(1, 3, device=device)
        torch_out = model(x)
        ort_outs = ort_session.run(None, {input_name: to_numpy(x)})
        # Tolerances account for float32 rounding between the two runtimes.
        np.testing.assert_allclose(
            to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05
        )
    logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")
if __name__ == "__main__":
    # Full pipeline: train and save weights, evaluate on the test set, then
    # export to ONNX (export_model also cross-checks the exported graph
    # against the torch model).
    train(STATE_PATH)
    test(STATE_PATH)
    export_model(STATE_PATH, ONNX_PATH)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment