# train.py — authored by Aymeric Varasse
import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.onnx
import torch.optim as optim
import torchvision
from loguru import logger
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import prune
from torch.utils.data import DataLoader
from torchvision import transforms
# Side length (pixels) of the square images fed to FNN / FNN_bin.
img_size = 28
# Global RNG seed; it is also embedded in every artifact file name below.
SEED = 42

# Training hyper-parameters shared by every training loop in this script.
num_epoch = 2
batch_size = 4

# Artifact paths for the plain, binary ("bin0") and pruned FNNs; names encode resolution and seed.
STATE_PATH = f"FNN_{img_size}x{img_size}_s{SEED}.pth"
ONNX_PATH = f"FNN_{img_size}x{img_size}_s{SEED}.onnx"
BIN_STATE_PATH = f"FNN_{img_size}x{img_size}_bin0_s{SEED}.pth"
BIN_ONNX_PATH = f"FNN_{img_size}x{img_size}_bin0_s{SEED}.onnx"
PRUNED_MODEL_PATH = f"FNN_{img_size}x{img_size}_pruned_s{SEED}.pkl"
PRUNED_ONNX_PATH = f"FNN_{img_size}x{img_size}_pruned_s{SEED}.onnx"

# Artifact paths for the split FNN_pre / FNN_post pair; those nets are hard-wired to 28x28 input.
FNN_PRE_PATH = f"FNN_28x28_pre_s{SEED}.pth"
FNN_PRE_ONNX_PATH = f"FNN_28x28_pre_s{SEED}.onnx"
FNN_POST_PATH = f"FNN_28x28_post_s{SEED}.pth"
FNN_POST_ONNX_PATH = f"FNN_28x28_post_s{SEED}.onnx"

# Seed before any model is constructed so weight initialisation is reproducible.
torch.manual_seed(SEED)

# Prefer the first CUDA GPU when present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
logger.info(f"Using device {device}")
class FNN_pre(nn.Module):
    """First half of a split classifier: 28x28 input -> 512 ReLU features.

    Its 512-wide output is what ``FNN_post`` consumes.
    """

    def __init__(self):
        super().__init__()
        # Creation order (flatten, fc1, fc2) fixes both state_dict keys and RNG-driven init.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 512)

    def forward(self, x):
        """Return ``relu(fc2(relu(fc1(flatten(x)))))``."""
        hidden = F.relu(self.fc1(self.flatten(x)))
        return F.relu(self.fc2(hidden))
class FNN_post(nn.Module):
    """Second half of a split classifier: 512 features -> 10 logits (no softmax)."""

    def __init__(self):
        super().__init__()
        # Keep creation order: it determines state_dict keys and seeded initialisation.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Return ``fc2(relu(fc1(flatten(x))))``."""
        features = self.flatten(x)
        return self.fc2(F.relu(self.fc1(features)))
class FNN(nn.Module):
    """Small fully connected 10-class classifier over ``img_size x img_size`` inputs.

    Pipeline: flatten -> fc1 (img_size**2 -> 10) -> ReLU -> fc2 (10 -> 10) -> ReLU
    -> fc3 (10 -> 10). Returns raw logits.
    """

    def __init__(self):
        super().__init__()
        width = 10
        # Layers are created in the same order as their use so seeded init is reproducible.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(img_size * img_size, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 10)

    def forward(self, x):
        """Map a batch of images to 10 logits per sample."""
        h = F.relu(self.fc1(self.flatten(x)))
        h = F.relu(self.fc2(h))
        return self.fc3(h)
class FNN_bin(nn.Module):
    """Binary variant of ``FNN``: same trunk, but the last layer emits 2 logits.

    In this script it is trained/evaluated on "digit is 0" (label 1) vs. anything else (label 0).
    """

    def __init__(self):
        super().__init__()
        width = 10
        # Same creation order as FNN so the first layers initialise identically under a seed.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(img_size * img_size, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 2)

    def forward(self, x):
        """Map a batch of images to 2 logits per sample."""
        h = F.relu(self.fc1(self.flatten(x)))
        h = F.relu(self.fc2(h))
        return self.fc3(h)
# Preprocessing for the FNN / FNN_bin datasets: convert each sample to a tensor,
# then resize it to img_size x img_size (a no-op when img_size is 28, MNIST's native size).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize([img_size, img_size]),
    ]
)
def train(state_dict, is_bin=False):
    """Train an ``FNN`` (or ``FNN_bin``) on the MNIST training split and save its weights.

    Args:
        state_dict: Destination file path for the trained model's ``state_dict``
            (despite the name, this is a path, not a dict).
        is_bin: If True, train the 2-class ``FNN_bin`` on the "is digit 0" task
            (label 1 for digit 0, label 0 otherwise); otherwise the 10-class ``FNN``.
    """
    trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    model = (FNN_bin() if is_bin else FNN()).to(device)
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    for epoch in range(num_epoch):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)
            if is_bin:
                # Vectorised relabelling: 1 for digit 0, 0 otherwise (int64, as CrossEntropyLoss expects).
                labels = (labels == 0).long()
            optimizer.zero_grad()
            # One forward pass per batch; a second call would only rebuild an unused autograd graph.
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                logger.info(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
                running_loss = 0.0
    logger.info("Finished training")
    torch.save(model.state_dict(), state_dict)
def train_2_nets(fnn_pre_path, fnn_post_path):
    """Jointly train the chained ``FNN_pre -> FNN_post`` pair on MNIST and save both state dicts.

    Args:
        fnn_pre_path: Destination path for ``FNN_pre``'s ``state_dict``.
        fnn_post_path: Destination path for ``FNN_post``'s ``state_dict``.
    """
    trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    fnn_pre = FNN_pre().to(device)
    fnn_post = FNN_post().to(device)
    criterion = nn.CrossEntropyLoss()
    # Optimise both halves: the loss gradient flows back through fnn_pre too, and leaving
    # it out of the optimiser would keep the first network at its random initialisation.
    params = list(fnn_pre.parameters()) + list(fnn_post.parameters())
    optimizer = optim.SGD(params, lr=0.001, momentum=0.9)
    for epoch in range(num_epoch):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs_post = fnn_post(fnn_pre(inputs))
            loss = criterion(outputs_post, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                logger.info(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
                running_loss = 0.0
    logger.info("Finished training")
    torch.save(fnn_pre.state_dict(), fnn_pre_path)
    torch.save(fnn_post.state_dict(), fnn_post_path)
def test(model_path, is_state_dict=True, is_bin=False):
    """Evaluate a saved FNN on the MNIST test split; log overall and per-class accuracy.

    Args:
        model_path: Path to a ``state_dict`` file (``is_state_dict=True``) or to a whole
            pickled model saved with ``torch.save(model)`` (``is_state_dict=False``,
            used here for the pruned network).
        is_state_dict: Whether ``model_path`` holds a state dict for ``FNN``/``FNN_bin``.
        is_bin: Evaluate ``FNN_bin`` on the "is digit 0" task instead of the 10-class ``FNN``.

    Raises:
        ValueError: If ``is_bin`` is set without ``is_state_dict``.
    """
    # Validate arguments before touching the dataset (avoids a pointless download).
    if is_bin and (not is_state_dict):
        raise ValueError("Binary model required to be loaded and exported via state dicts.")
    classes = (0, 1) if is_bin else tuple(range(10))
    testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
    if is_bin:
        net = FNN_bin().to(device)
        net.load_state_dict(torch.load(model_path, map_location=device))
    elif is_state_dict:
        net = FNN().to(device)
        net.load_state_dict(torch.load(model_path, map_location=device))
    else:
        # Unpickles a whole module: only use with trusted files.
        # NOTE(review): recent PyTorch releases default torch.load to weights_only=True, which
        # may refuse a full pickled module — confirm against the installed version.
        net = torch.load(model_path, map_location=device)
    net.eval()

    # Single pass over the test set collects both per-class and overall counts.
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            if is_bin:
                # Vectorised relabelling: 1 for digit 0, 0 otherwise.
                labels = (labels == 0).long()
            outputs = net(images)
            _, predictions = torch.max(outputs, 1)
            for label, prediction in zip(labels.tolist(), predictions.tolist()):
                total_pred[classes[label]] += 1
                if label == prediction:
                    correct_pred[classes[label]] += 1
    correct = sum(correct_pred.values())
    total = sum(total_pred.values())

    if is_bin:
        logger.info(f"Accuracy of the binary network on the 10000 test images: {100 * correct // total} %")
    elif is_state_dict:
        logger.info(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
    else:
        logger.info(f"Accuracy of the pruned network on the 10000 test images: {100 * correct // total} %")
    for classname, correct_count in correct_pred.items():
        seen = total_pred[classname]
        # A class absent from the test set has no defined accuracy; report NaN instead of crashing.
        accuracy = 100 * float(correct_count) / seen if seen else float("nan")
        logger.info(f"Accuracy for class: {classname} is {accuracy:.1f} %")
def test_2_nets(fnn_pre_path, fnn_post_path):
    """Evaluate the chained ``FNN_pre -> FNN_post`` pipeline on MNIST; log overall and per-class accuracy.

    Args:
        fnn_pre_path: Path to ``FNN_pre``'s saved ``state_dict``.
        fnn_post_path: Path to ``FNN_post``'s saved ``state_dict``.
    """
    classes = tuple(range(10))
    testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transforms.ToTensor())
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
    fnn_pre = FNN_pre().to(device)
    fnn_pre.load_state_dict(torch.load(fnn_pre_path, map_location=device))
    fnn_post = FNN_post().to(device)
    fnn_post.load_state_dict(torch.load(fnn_post_path, map_location=device))
    fnn_pre.eval()
    fnn_post.eval()

    # Single pass over the test set collects both per-class and overall counts.
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs_post = fnn_post(fnn_pre(images))
            _, predictions = torch.max(outputs_post, 1)
            for label, prediction in zip(labels.tolist(), predictions.tolist()):
                total_pred[classes[label]] += 1
                if label == prediction:
                    correct_pred[classes[label]] += 1
    correct = sum(correct_pred.values())
    total = sum(total_pred.values())

    logger.info(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
    for classname, correct_count in correct_pred.items():
        seen = total_pred[classname]
        # A class absent from the test set has no defined accuracy; report NaN instead of crashing.
        accuracy = 100 * float(correct_count) / seen if seen else float("nan")
        logger.info(f"Accuracy for class: {classname} is {accuracy:.1f} %")
def export_model(model_path, onnx_path, is_state_dict=True, is_bin=False):
    """Export a saved FNN to ONNX, then cross-check the export with ``test_onnx``.

    Args:
        model_path: State-dict path (``is_state_dict=True``) or whole pickled model path.
        onnx_path: Destination ``.onnx`` file.
        is_state_dict: Whether ``model_path`` holds a state dict for ``FNN``/``FNN_bin``.
        is_bin: Export ``FNN_bin`` instead of ``FNN``.

    Raises:
        ValueError: If ``is_bin`` is set without ``is_state_dict``.
    """
    if is_bin and not is_state_dict:
        raise ValueError("Binary model required to be loaded and exported via state dicts.")

    # After the guard above, is_bin implies is_state_dict.
    if not is_state_dict:
        model = torch.load(model_path)
        message = "Pruned model exported successfully"
    else:
        architecture = FNN_bin if is_bin else FNN
        model = architecture().to(device)
        model.load_state_dict(torch.load(model_path))
        message = "Binary model exported successfully" if is_bin else "Model exported successfully"

    # Random dummy input used only to trace the graph; shape (1, img_size, img_size).
    dummy_input = torch.rand(1, img_size, img_size, device=device)
    torch.onnx.export(model=model, args=dummy_input, f=onnx_path, export_params=True)
    logger.info(message)

    test_onnx(model_path, onnx_path, is_state_dict, is_bin)
def export_2_models(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx):
    """Export the ``FNN_pre`` / ``FNN_post`` pair to two ONNX files and cross-check both.

    Args:
        fnn_pre_path: Path to ``FNN_pre``'s ``state_dict``.
        fnn_pre_onnx: Destination ``.onnx`` file for ``FNN_pre``.
        fnn_post_path: Path to ``FNN_post``'s ``state_dict``.
        fnn_post_onnx: Destination ``.onnx`` file for ``FNN_post``.
    """
    # First half: traced on a random (1, 28, 28) input.
    pre_net = FNN_pre().to(device)
    pre_net.load_state_dict(torch.load(fnn_pre_path))
    image_sample = torch.rand(1, 28, 28, device=device)
    torch.onnx.export(model=pre_net, args=image_sample, f=fnn_pre_onnx, export_params=True)
    logger.info("First model exported successfully")

    # Second half: traced on the first half's output so its input shape is exactly what it will receive.
    post_net = FNN_post().to(device)
    post_net.load_state_dict(torch.load(fnn_post_path))
    feature_sample = pre_net(image_sample).to(device)
    torch.onnx.export(model=post_net, args=feature_sample, f=fnn_post_onnx, export_params=True)
    logger.info("Second model exported successfully")

    test_2_onnx(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx)
def test_onnx(model_path, onnx_path, is_state_dict=True, is_bin=False, num_checks=10000):
    """Validate an exported ONNX FNN and compare it with the PyTorch model on random inputs.

    Args:
        model_path: State-dict path (``is_state_dict=True``) or whole pickled model path.
        onnx_path: The ``.onnx`` file to check.
        is_state_dict: Whether ``model_path`` holds a state dict for ``FNN``/``FNN_bin``.
        is_bin: Compare against ``FNN_bin`` instead of ``FNN``.
        num_checks: Number of random inputs compared (default 10000, as before).

    Raises:
        ValueError: If ``is_bin`` is set without ``is_state_dict``.
        AssertionError: If ONNX Runtime and PyTorch outputs differ beyond rtol=1e-3 / atol=1e-5.
    """
    if is_bin and (not is_state_dict):
        raise ValueError("Binary model required to be loaded and exported via state dicts.")
    if is_bin:
        model = FNN_bin().to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
    elif is_state_dict:
        model = FNN().to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        # Unpickles a whole module: only use with trusted files.
        model = torch.load(model_path, map_location=device)
    model.eval()

    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    # Loop-invariant: the graph's single input name.
    input_name = ort_session.get_inputs()[0].name

    # no_grad: comparisons only, so skip building autograd graphs for every forward pass.
    with torch.no_grad():
        for _ in range(num_checks):
            x = torch.rand(1, img_size, img_size, device=device)
            torch_out = model(x).cpu().numpy()
            ort_out = ort_session.run(None, {input_name: x.cpu().numpy()})[0]
            np.testing.assert_allclose(torch_out, ort_out, rtol=1e-03, atol=1e-05)

    if is_bin:
        logger.info("Exported binary model has been tested with ONNXRuntime, and the result looks good!")
    elif is_state_dict:
        logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")
    else:
        logger.info("Exported pruned model has been tested with ONNXRuntime, and the result looks good!")
def test_2_onnx(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx, num_checks=10000):
    """Validate both exported halves and compare each with its PyTorch module on random inputs.

    ``FNN_pre`` is checked on random (1, 28, 28) inputs and ``FNN_post`` independently on
    random (1, 512) inputs (not on ``FNN_pre`` outputs).

    Args:
        fnn_pre_path: Path to ``FNN_pre``'s ``state_dict``.
        fnn_pre_onnx: ``FNN_pre``'s ``.onnx`` file.
        fnn_post_path: Path to ``FNN_post``'s ``state_dict``.
        fnn_post_onnx: ``FNN_post``'s ``.onnx`` file.
        num_checks: Random inputs compared per network (default 10000, as before).

    Raises:
        AssertionError: If ONNX Runtime and PyTorch outputs differ beyond rtol=1e-3 / atol=1e-5.
    """
    fnn_pre = FNN_pre().to(device)
    fnn_pre.load_state_dict(torch.load(fnn_pre_path, map_location=device))
    fnn_post = FNN_post().to(device)
    fnn_post.load_state_dict(torch.load(fnn_post_path, map_location=device))
    fnn_pre.eval()
    fnn_post.eval()

    onnx_model_pre = onnx.load(fnn_pre_onnx)
    onnx.checker.check_model(onnx_model_pre)
    onnx_model_post = onnx.load(fnn_post_onnx)
    onnx.checker.check_model(onnx_model_post)
    ort_session_pre = ort.InferenceSession(fnn_pre_onnx, providers=["CPUExecutionProvider"])
    ort_session_post = ort.InferenceSession(fnn_post_onnx, providers=["CPUExecutionProvider"])
    # Loop-invariant: each graph's single input name.
    pre_input_name = ort_session_pre.get_inputs()[0].name
    post_input_name = ort_session_post.get_inputs()[0].name

    # no_grad: comparisons only, so skip building autograd graphs for every forward pass.
    with torch.no_grad():
        for _ in range(num_checks):
            x = torch.rand(1, 28, 28, device=device)
            torch_out_pre = fnn_pre(x).cpu().numpy()
            ort_out_pre = ort_session_pre.run(None, {pre_input_name: x.cpu().numpy()})[0]
            np.testing.assert_allclose(torch_out_pre, ort_out_pre, rtol=1e-03, atol=1e-05)
    logger.info("First exported model has been tested with ONNXRuntime, and the result looks good!")

    with torch.no_grad():
        for _ in range(num_checks):
            x = torch.rand(1, 512, device=device)
            torch_out_post = fnn_post(x).cpu().numpy()
            ort_out_post = ort_session_post.run(None, {post_input_name: x.cpu().numpy()})[0]
            np.testing.assert_allclose(torch_out_post, ort_out_post, rtol=1e-03, atol=1e-05)
    logger.info("Second exported model has been tested with ONNXRuntime, and the result looks good!")
def build_pruned_model(state_dict, pruned_model_path):
    """Globally L1-prune 20% of a trained FNN's weights and pickle the whole pruned model.

    Args:
        state_dict: Path to the trained ``FNN`` state dict to start from.
        pruned_model_path: Destination path; the full module is saved with ``torch.save(model)``.
    """
    model = FNN().to(device)
    model.load_state_dict(torch.load(state_dict, map_location=device))
    layers = (("fc1", model.fc1), ("fc2", model.fc2), ("fc3", model.fc3))
    parameters_to_prune = tuple((layer, "weight") for _, layer in layers)
    # Zero the 20% smallest-magnitude weights across all three layers jointly.
    prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)
    # Make pruning permanent: fold weight_orig * weight_mask into a plain ``weight`` parameter and
    # drop the forward pre-hook, so the pickled model and its ONNX export hold the sparse weights directly.
    for layer, param_name in parameters_to_prune:
        prune.remove(layer, param_name)

    zeros_total = 0
    numel_total = 0
    for name, layer in layers:
        zeros = int(torch.sum(layer.weight == 0))
        numel = layer.weight.nelement()
        zeros_total += zeros
        numel_total += numel
        logger.info("Sparsity in {}.weight: {:.2f}%".format(name, 100.0 * zeros / numel))
    logger.info("Global sparsity: {:.2f}%".format(100.0 * zeros_total / numel_total))
    logger.info("Finished pruning")
    torch.save(model, pruned_model_path)
if __name__ == "__main__":
    # TODO: Add an argument parser?

    # 1) Plain 10-class FNN, then a globally pruned copy of it.
    logger.info("Building, training and exporting to ONNX a simple FNN and a pruned FNN.")
    train(state_dict=STATE_PATH)
    test(model_path=STATE_PATH)
    export_model(model_path=STATE_PATH, onnx_path=ONNX_PATH)
    build_pruned_model(state_dict=STATE_PATH, pruned_model_path=PRUNED_MODEL_PATH)
    test(model_path=PRUNED_MODEL_PATH, is_state_dict=False)
    export_model(model_path=PRUNED_MODEL_PATH, onnx_path=PRUNED_ONNX_PATH, is_state_dict=False)

    # 2) Split pipeline: FNN_pre feeding FNN_post.
    logger.info("Building, training and exporting to ONNX two FNNs that are to be used successively.")
    train_2_nets(fnn_pre_path=FNN_PRE_PATH, fnn_post_path=FNN_POST_PATH)
    test_2_nets(fnn_pre_path=FNN_PRE_PATH, fnn_post_path=FNN_POST_PATH)
    export_2_models(
        fnn_pre_path=FNN_PRE_PATH,
        fnn_pre_onnx=FNN_PRE_ONNX_PATH,
        fnn_post_path=FNN_POST_PATH,
        fnn_post_onnx=FNN_POST_ONNX_PATH,
    )

    # 3) Binary "is digit 0" classifier.
    logger.info("Building, training and exporting to ONNX a binary FNN.")
    train(state_dict=BIN_STATE_PATH, is_bin=True)
    test(model_path=BIN_STATE_PATH, is_bin=True)
    export_model(model_path=BIN_STATE_PATH, onnx_path=BIN_ONNX_PATH, is_bin=True)