Skip to content
Snippets Groups Projects
Commit 3f09cc60 authored by Aymeric Varasse's avatar Aymeric Varasse :innocent:
Browse files

[exps] Add new models and training script

parent f623a6bf
No related branches found
No related tags found
No related merge requests found
File added
File added
loguru
numpy
onnx
onnxruntime
torch
torchaudio
torchvision
import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.onnx
import torch.optim as optim
import torchvision
from loguru import logger
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import prune
from torch.utils.data import DataLoader
from torchvision import transforms
img_size = 14
SEED = 42
STATE_PATH = f"{img_size}x{img_size}_FNN_s{SEED}.pth"
ONNX_PATH = f"{img_size}x{img_size}_FNN_s{SEED}.onnx"
PRUNED_MODEL_PATH = f"pruned_{img_size}x{img_size}_FNN_s{SEED}.pkl"
PRUNED_ONNX_PATH = f"pruned_{img_size}x{img_size}_FNN_s{SEED}.onnx"
FNN_PRE_PATH = f"fnn_pre_s{SEED}.pth"
FNN_PRE_ONNX_PATH = f"fnn_pre_s{SEED}.onnx"
FNN_POST_PATH = f"fnn_post_s{SEED}.pth"
FNN_POST_ONNX_PATH = f"fnn_post_s{SEED}.onnx"
torch.manual_seed(SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device {device}")
num_epoch = 2
batch_size = 4
class FNN_pre(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(28 * 28, 512)
self.fc2 = nn.Linear(512, 512)
def forward(self, x):
x = self.flatten(x)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
return x
class FNN_post(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(512, 256)
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = self.flatten(x)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
return x
class FNN(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(img_size * img_size, 10)
self.fc2 = nn.Linear(10, 10)
self.fc3 = nn.Linear(10, 10)
def forward(self, x):
x = self.flatten(x)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
return x
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([img_size, img_size])])
def train(state_dict):
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
model = FNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
for epoch in range(num_epoch):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data[0].to(device), data[1].to(device)
optimizer.zero_grad()
outputs = model(inputs).to(device)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
logger.info(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
running_loss = 0.0
logger.info("Finished training")
torch.save(model.state_dict(), state_dict)
def train_2_nets(cnn_path, fnn_path):
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
fnn_pre = FNN_pre().to(device)
fnn_post = FNN_post().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(fnn_post.parameters(), lr=0.001, momentum=0.9)
for epoch in range(num_epoch):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data[0].to(device), data[1].to(device)
optimizer.zero_grad()
outputs_pre = fnn_pre(inputs).to(device)
outputs_post = fnn_post(outputs_pre).to(device)
loss = criterion(outputs_post, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2000 == 1999:
logger.info(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
running_loss = 0.0
logger.info("Finished training")
torch.save(fnn_pre.state_dict(), cnn_path)
torch.save(fnn_post.state_dict(), fnn_path)
def test(model_path, is_state_dict=True):
classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
if is_state_dict:
net = FNN().to(device)
net.load_state_dict(torch.load(model_path))
else:
net = torch.load(model_path)
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data[0].to(device), data[1].to(device)
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
if is_state_dict:
logger.info(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
else:
logger.info(f"Accuracy of the pruned network on the 10000 test images: {100 * correct // total} %")
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}
with torch.no_grad():
for data in testloader:
images, labels = data[0].to(device), data[1].to(device)
outputs = net(images)
_, predictions = torch.max(outputs, 1)
for label, prediction in zip(labels, predictions):
if label == prediction:
correct_pred[classes[label]] += 1
total_pred[classes[label]] += 1
for classname, correct_count in correct_pred.items():
accuracy = 100 * float(correct_count) / total_pred[classname]
logger.info(f"Accuracy for class: {classname} is {accuracy:.1f} %")
def test_2_nets(fnn_pre_path, fnn_post_path):
classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transforms.ToTensor())
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
fnn_pre = FNN_pre().to(device)
fnn_pre.load_state_dict(torch.load(fnn_pre_path))
fnn_post = FNN_post().to(device)
fnn_post.load_state_dict(torch.load(fnn_post_path))
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data[0].to(device), data[1].to(device)
outputs_pre = fnn_pre(images).to(device)
outputs_post = fnn_post(outputs_pre).to(device)
_, predicted = torch.max(outputs_post.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
logger.info(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}
with torch.no_grad():
for data in testloader:
images, labels = data[0].to(device), data[1].to(device)
outputs_pre = fnn_pre(images).to(device)
outputs_post = fnn_post(outputs_pre).to(device)
_, predictions = torch.max(outputs_post, 1)
for label, prediction in zip(labels, predictions):
if label == prediction:
correct_pred[classes[label]] += 1
total_pred[classes[label]] += 1
for classname, correct_count in correct_pred.items():
accuracy = 100 * float(correct_count) / total_pred[classname]
logger.info(f"Accuracy for class: {classname} is {accuracy:.1f} %")
def export_model(model_path, onnx_path, is_state_dict=True):
if is_state_dict:
model = FNN().to(device)
model.load_state_dict(torch.load(model_path))
else:
model = torch.load(model_path)
x = torch.rand(1, img_size, img_size, device=device)
torch.onnx.export(model=model, args=x, f=onnx_path, export_params=True)
if is_state_dict:
logger.info("Model exported successfully")
else:
logger.info("Pruned model exported successfully")
test_onnx(model_path, onnx_path, is_state_dict)
def export_2_models(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx):
fnn_pre = FNN_pre().to(device)
fnn_pre.load_state_dict(torch.load(fnn_pre_path))
x = torch.rand(1, 28, 28, device=device)
torch.onnx.export(model=fnn_pre, args=x, f=fnn_pre_onnx, export_params=True)
logger.info("First model exported successfully")
fnn_post = FNN_post().to(device)
fnn_post.load_state_dict(torch.load(fnn_post_path))
y = fnn_pre(x).to(device)
torch.onnx.export(model=fnn_post, args=y, f=fnn_post_onnx, export_params=True)
logger.info("Second model exported successfully")
test_2_onnx(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx)
def test_onnx(model_path, onnx_path, is_state_dict=True):
if is_state_dict:
model = FNN().to(device)
model.load_state_dict(torch.load(model_path))
else:
model = torch.load(model_path)
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
for _ in range(10000):
x = torch.rand(1, img_size, img_size, device=device)
torch_out = model(x)
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
if is_state_dict:
logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")
else:
logger.info("Exported pruned model has been tested with ONNXRuntime, and the result looks good!")
def test_2_onnx(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx):
fnn_pre = FNN_pre().to(device)
fnn_pre.load_state_dict(torch.load(fnn_pre_path))
fnn_post = FNN_post().to(device)
fnn_post.load_state_dict(torch.load(fnn_post_path))
onnx_model_pre = onnx.load(fnn_pre_onnx)
onnx.checker.check_model(onnx_model_pre)
onnx_model_post = onnx.load(fnn_post_onnx)
onnx.checker.check_model(onnx_model_post)
ort_session_pre = ort.InferenceSession(fnn_pre_onnx, providers=["CPUExecutionProvider"])
ort_session_post = ort.InferenceSession(fnn_post_onnx, providers=["CPUExecutionProvider"])
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
for _ in range(10000):
x = torch.rand(1, 28, 28, device=device)
torch_out_pre = fnn_pre(x)
ort_inputs_pre = {ort_session_pre.get_inputs()[0].name: to_numpy(x)}
ort_outs_pre = ort_session_pre.run(None, ort_inputs_pre)
np.testing.assert_allclose(to_numpy(torch_out_pre), ort_outs_pre[0], rtol=1e-03, atol=1e-05)
logger.info("First exported model has been tested with ONNXRuntime, and the result looks good!")
for _ in range(10000):
x = torch.rand(1, 512, device=device)
torch_out_post = fnn_post(x)
ort_inputs_post = {ort_session_post.get_inputs()[0].name: to_numpy(x)}
ort_outs_post = ort_session_post.run(None, ort_inputs_post)
np.testing.assert_allclose(to_numpy(torch_out_post), ort_outs_post[0], rtol=1e-03, atol=1e-05)
logger.info("Second exported model has been tested with ONNXRuntime, and the result looks good!")
def build_pruned_model(state_dict, pruned_model_path):
model = FNN().to(device)
model.load_state_dict(torch.load(state_dict))
parameters_to_prune = ((model.fc1, "weight"), (model.fc2, "weight"), (model.fc3, "weight"))
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)
logger.info(
"Sparsity in fc1.weight: {:.2f}%".format(
100.0 * float(torch.sum(model.fc1.weight == 0)) / float(model.fc1.weight.nelement())
)
)
logger.info(
"Sparsity in fc2.weight: {:.2f}%".format(
100.0 * float(torch.sum(model.fc2.weight == 0)) / float(model.fc2.weight.nelement())
)
)
logger.info(
"Sparsity in fc3.weight: {:.2f}%".format(
100.0 * float(torch.sum(model.fc3.weight == 0)) / float(model.fc3.weight.nelement())
)
)
logger.info(
"Global sparsity: {:.2f}%".format(
100.0
* float(
torch.sum(model.fc1.weight == 0) + torch.sum(model.fc2.weight == 0) + torch.sum(model.fc3.weight == 0)
)
/ float(model.fc1.weight.nelement() + model.fc2.weight.nelement() + model.fc3.weight.nelement())
)
)
logger.info("Finished pruning")
torch.save(model, pruned_model_path)
if __name__ == "__main__":
logger.info("Building, training and exporting to ONNX a simple FNN and a pruned FNN.")
train(STATE_PATH)
test(STATE_PATH)
export_model(STATE_PATH, ONNX_PATH)
build_pruned_model(STATE_PATH, PRUNED_MODEL_PATH)
test(PRUNED_MODEL_PATH, is_state_dict=False)
export_model(PRUNED_MODEL_PATH, PRUNED_ONNX_PATH, is_state_dict=False)
logger.info("Building, training and exporting to ONNX two FNNs that are to be used successively.")
train_2_nets(FNN_PRE_PATH, FNN_POST_PATH)
test_2_nets(FNN_PRE_PATH, FNN_POST_PATH)
export_2_models(FNN_PRE_PATH, FNN_PRE_ONNX_PATH, FNN_POST_PATH, FNN_POST_ONNX_PATH)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment