Skip to content
Snippets Groups Projects
Commit 1feb8ad7 authored by Aymeric Varasse's avatar Aymeric Varasse :innocent:
Browse files

[exps] Add onnx and training script for arithmetic

parent 57128258
No related branches found
No related tags found
No related merge requests found
File added
import os
import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.onnx
import torch.optim as optim
from loguru import logger
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
SEED = 42
STATE_PATH = f"FNN_s{SEED}.pth"
ONNX_PATH = f"FNN_s{SEED}.onnx"
INPUT_ARRAY = "data.npy"
TEST_INPUT_ARRAY = "test_data.npy"
TARGET_ARRAY = "target.npy"
TEST_TARGET_ARRAY = "test_target.npy"
torch.manual_seed(SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device {device}")
num_epoch = 2
batch_size = 4
class FNN(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(3, 128)
self.fc2 = nn.Linear(128, 128)
self.fc3 = nn.Linear(128, 1)
def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
return x
class ArithmeticDataset(Dataset):
def __init__(self, input_array, target_array, root_dir):
self.input_array = np.load(input_array).astype(np.float32)
self.target_array = np.load(target_array).astype(np.float32)
self.root_dir = root_dir
def __len__(self):
return len(self.input_array)
def __getitem__(self, idx):
return [self.input_array[idx], self.target_array[idx]]
def train(state_dict):
trainset = ArithmeticDataset(
input_array=INPUT_ARRAY, target_array=TARGET_ARRAY, root_dir=os.path.dirname(INPUT_ARRAY)
)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
model = FNN().to(device)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
for epoch in range(num_epoch):
running_loss = 0.0
for i, data in enumerate(trainloader):
inputs, labels = data[0].to(device), data[1].to(device)
optimizer.zero_grad()
outputs = model(inputs).squeeze().to(device)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2000 == 1999:
logger.info(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
running_loss = 0.0
logger.info("Finished training")
torch.save(model.state_dict(), state_dict)
def test(model_path):
testset = ArithmeticDataset(
input_array=TEST_INPUT_ARRAY, target_array=TEST_TARGET_ARRAY, root_dir=os.path.dirname(TEST_INPUT_ARRAY)
)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
net = FNN().to(device)
net.load_state_dict(torch.load(model_path))
error = 0
total = len(testloader) * batch_size
with torch.no_grad():
for data in testloader:
inputs, labels = data[0].to(device), data[1].to(device)
outputs = net(inputs).squeeze().to(device)
error += ((outputs - labels) * (outputs - labels)).sum().data.cpu()
logger.info(f"Average MSE of the network on the 10000 test inputs: {np.sqrt(error / total):.3f}")
def export_model(model_path, onnx_path):
model = FNN().to(device)
model.load_state_dict(torch.load(model_path))
x = torch.rand(1, 3, device=device)
torch.onnx.export(model=model, args=x, f=onnx_path, export_params=True)
logger.info("Model exported successfully")
test_onnx(model_path, onnx_path)
def test_onnx(model_path, onnx_path):
model = FNN().to(device)
model.load_state_dict(torch.load(model_path))
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
for _ in range(10000):
x = torch.rand(1, 3, device=device)
torch_out = model(x)
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")
if __name__ == "__main__":
train(STATE_PATH)
test(STATE_PATH)
export_model(STATE_PATH, ONNX_PATH)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment