diff --git a/examples/mnist/nets/dummy_nn/FNN_s42.onnx b/examples/mnist/nets/dummy_nn/FNN_s42.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..e8505c08400e5465e0fdc56e11a0d66ed5c9327f
Binary files /dev/null and b/examples/mnist/nets/dummy_nn/FNN_s42.onnx differ
diff --git a/examples/mnist/nets/dummy_nn/pruned_FNN_s42.onnx b/examples/mnist/nets/dummy_nn/pruned_FNN_s42.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..a32cd072b7606a731da5e6dd186ad9c014abe541
Binary files /dev/null and b/examples/mnist/nets/dummy_nn/pruned_FNN_s42.onnx differ
diff --git a/examples/mnist/nets/dummy_nn/requirements.txt b/examples/mnist/nets/dummy_nn/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8fb1bde4ebee5bdcf8c2ccd1b4e1aa07974b6cfb
--- /dev/null
+++ b/examples/mnist/nets/dummy_nn/requirements.txt
@@ -0,0 +1,7 @@
+loguru
+numpy
+onnx
+onnxruntime
+torch
+torchaudio
+torchvision
diff --git a/examples/mnist/nets/dummy_nn/train.py b/examples/mnist/nets/dummy_nn/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d8505c9d8ebd7d7fa60767ecef1267938c617ab
--- /dev/null
+++ b/examples/mnist/nets/dummy_nn/train.py
@@ -0,0 +1,390 @@
+import numpy as np
+import onnx
+import onnxruntime as ort
+import torch
+import torch.onnx
+import torch.optim as optim
+import torchvision
+from loguru import logger
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils import prune
+from torch.utils.data import DataLoader
+from torchvision import transforms
+
+img_size = 14
+
+SEED = 42
+STATE_PATH = f"{img_size}x{img_size}_FNN_s{SEED}.pth"
+ONNX_PATH = f"{img_size}x{img_size}_FNN_s{SEED}.onnx"
+PRUNED_MODEL_PATH = f"pruned_{img_size}x{img_size}_FNN_s{SEED}.pkl"
+PRUNED_ONNX_PATH = f"pruned_{img_size}x{img_size}_FNN_s{SEED}.onnx"
+FNN_PRE_PATH = f"fnn_pre_s{SEED}.pth"
+FNN_PRE_ONNX_PATH = f"fnn_pre_s{SEED}.onnx"
+FNN_POST_PATH = f"fnn_post_s{SEED}.pth"
+FNN_POST_ONNX_PATH = f"fnn_post_s{SEED}.onnx"
+
+torch.manual_seed(SEED)
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+logger.info(f"Using device {device}")
+num_epoch = 2
+batch_size = 4
+
+
+# First half of a network split in two: flattens a 28x28 image into 512 ReLU features.
+class FNN_pre(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.flatten = nn.Flatten()
+        self.fc1 = nn.Linear(28 * 28, 512)
+        self.fc2 = nn.Linear(512, 512)
+
+    def forward(self, x):
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.fc2(x)
+        x = F.relu(x)
+        return x
+
+
+# Second half of the split network: maps the 512 features to 10 class logits.
+class FNN_post(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.flatten = nn.Flatten()
+        self.fc1 = nn.Linear(512, 256)
+        self.fc2 = nn.Linear(256, 10)
+
+    def forward(self, x):
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.fc2(x)
+        return x
+
+
+# Small standalone classifier working on img_size x img_size inputs.
+class FNN(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.flatten = nn.Flatten()
+        self.fc1 = nn.Linear(img_size * img_size, 10)
+        self.fc2 = nn.Linear(10, 10)
+        self.fc3 = nn.Linear(10, 10)
+
+    def forward(self, x):
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.fc2(x)
+        x = F.relu(x)
+        x = self.fc3(x)
+        return x
+
+
+transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([img_size, img_size])])
+
+
+def train(state_path):
+    trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
+    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
+
+    model = FNN().to(device)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
+
+    for epoch in range(num_epoch):
+        running_loss = 0.0
+        for i, data in enumerate(trainloader, 0):
+            inputs, labels = data[0].to(device), data[1].to(device)
+            optimizer.zero_grad()
+            outputs = model(inputs)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+            running_loss += loss.item()
+            if i % 2000 == 1999:  # print every 2000 mini-batches
+                logger.info(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
+                running_loss = 0.0
+
+    logger.info("Finished training")
+
+    torch.save(model.state_dict(), state_path)
+
+
+def train_2_nets(fnn_pre_path, fnn_post_path):
+    trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
+    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
+
+    fnn_pre = FNN_pre().to(device)
+    fnn_post = FNN_post().to(device)
+    criterion = nn.CrossEntropyLoss()
+    # Optimize the parameters of both networks; with fnn_post.parameters() alone,
+    # fnn_pre would keep its random initialization even though its weights get saved.
+    optimizer = optim.SGD(list(fnn_pre.parameters()) + list(fnn_post.parameters()), lr=0.001, momentum=0.9)
+
+    for epoch in range(num_epoch):
+        running_loss = 0.0
+        for i, data in enumerate(trainloader, 0):
+            inputs, labels = data[0].to(device), data[1].to(device)
+            optimizer.zero_grad()
+            outputs_pre = fnn_pre(inputs)
+            outputs_post = fnn_post(outputs_pre)
+            loss = criterion(outputs_post, labels)
+            loss.backward()
+            optimizer.step()
+
+            running_loss += loss.item()
+            if i % 2000 == 1999:  # print every 2000 mini-batches
+                logger.info(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
+                running_loss = 0.0
+
+    logger.info("Finished training")
+
+    torch.save(fnn_pre.state_dict(), fnn_pre_path)
+    torch.save(fnn_post.state_dict(), fnn_post_path)
+
+
+def test(model_path, is_state_dict=True):
+    classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
+    testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
+    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
+
+    if is_state_dict:
+        net = FNN().to(device)
+        net.load_state_dict(torch.load(model_path))
+    else:
+        net = torch.load(model_path)
+
+    # Overall accuracy.
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        for data in testloader:
+            images, labels = data[0].to(device), data[1].to(device)
+            outputs = net(images)
+            _, predicted = torch.max(outputs.data, 1)
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+
+    if is_state_dict:
+        logger.info(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
+    else:
+        logger.info(f"Accuracy of the pruned network on the 10000 test images: {100 * correct // total} %")
+
+    # Per-class accuracy.
+    correct_pred = {classname: 0 for classname in classes}
+    total_pred = {classname: 0 for classname in classes}
+
+    with torch.no_grad():
+        for data in testloader:
+            images, labels = data[0].to(device), data[1].to(device)
+            outputs = net(images)
+            _, predictions = torch.max(outputs, 1)
+            for label, prediction in zip(labels, predictions):
+                if label == prediction:
+                    correct_pred[classes[label]] += 1
+                total_pred[classes[label]] += 1
+
+    for classname, correct_count in correct_pred.items():
+        accuracy = 100 * float(correct_count) / total_pred[classname]
+        logger.info(f"Accuracy for class: {classname} is {accuracy:.1f} %")
+
+
+def test_2_nets(fnn_pre_path, fnn_post_path):
+    classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
+    testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transforms.ToTensor())
+    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
+
+    fnn_pre = FNN_pre().to(device)
+    fnn_pre.load_state_dict(torch.load(fnn_pre_path))
+    fnn_post = FNN_post().to(device)
+    fnn_post.load_state_dict(torch.load(fnn_post_path))
+
+    # Overall accuracy of the two networks chained together.
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        for data in testloader:
+            images, labels = data[0].to(device), data[1].to(device)
+            outputs_pre = fnn_pre(images)
+            outputs_post = fnn_post(outputs_pre)
+            _, predicted = torch.max(outputs_post.data, 1)
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+
+    logger.info(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
+
+    # Per-class accuracy.
+    correct_pred = {classname: 0 for classname in classes}
+    total_pred = {classname: 0 for classname in classes}
+
+    with torch.no_grad():
+        for data in testloader:
+            images, labels = data[0].to(device), data[1].to(device)
+            outputs_pre = fnn_pre(images)
+            outputs_post = fnn_post(outputs_pre)
+            _, predictions = torch.max(outputs_post, 1)
+            for label, prediction in zip(labels, predictions):
+                if label == prediction:
+                    correct_pred[classes[label]] += 1
+                total_pred[classes[label]] += 1
+
+    for classname, correct_count in correct_pred.items():
+        accuracy = 100 * float(correct_count) / total_pred[classname]
+        logger.info(f"Accuracy for class: {classname} is {accuracy:.1f} %")
+
+
+def export_model(model_path, onnx_path, is_state_dict=True):
+    if is_state_dict:
+        model = FNN().to(device)
+        model.load_state_dict(torch.load(model_path))
+    else:
+        model = torch.load(model_path)
+    x = torch.rand(1, img_size, img_size, device=device)
+
+    torch.onnx.export(model=model, args=x, f=onnx_path, export_params=True)
+    if is_state_dict:
+        logger.info("Model exported successfully")
+    else:
+        logger.info("Pruned model exported successfully")
+
+    test_onnx(model_path, onnx_path, is_state_dict)
+
+
+def export_2_models(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx):
+    fnn_pre = FNN_pre().to(device)
+    fnn_pre.load_state_dict(torch.load(fnn_pre_path))
+    x = torch.rand(1, 28, 28, device=device)
+
+    torch.onnx.export(model=fnn_pre, args=x, f=fnn_pre_onnx, export_params=True)
+    logger.info("First model exported successfully")
+
+    fnn_post = FNN_post().to(device)
+    fnn_post.load_state_dict(torch.load(fnn_post_path))
+    # fnn_pre's output serves as the sample input, so the two exported graphs compose.
+    y = fnn_pre(x)
+
+    torch.onnx.export(model=fnn_post, args=y, f=fnn_post_onnx, export_params=True)
+    logger.info("Second model exported successfully")
+
+    test_2_onnx(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx)
+
+
+def test_onnx(model_path, onnx_path, is_state_dict=True):
+    if is_state_dict:
+        model = FNN().to(device)
+        model.load_state_dict(torch.load(model_path))
+    else:
+        model = torch.load(model_path)
+
+    onnx_model = onnx.load(onnx_path)
+    onnx.checker.check_model(onnx_model)
+
+    ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
+
+    def to_numpy(tensor):
+        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+    # Compare PyTorch and ONNX Runtime outputs on random inputs.
+    for _ in range(10000):
+        x = torch.rand(1, img_size, img_size, device=device)
+        torch_out = model(x)
+
+        ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
+        ort_outs = ort_session.run(None, ort_inputs)
+
+        np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
+
+    if is_state_dict:
+        logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")
+    else:
+        logger.info("Exported pruned model has been tested with ONNXRuntime, and the result looks good!")
+
+
+def test_2_onnx(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx):
+    fnn_pre = FNN_pre().to(device)
+    fnn_pre.load_state_dict(torch.load(fnn_pre_path))
+    fnn_post = FNN_post().to(device)
+    fnn_post.load_state_dict(torch.load(fnn_post_path))
+
+    onnx_model_pre = onnx.load(fnn_pre_onnx)
+    onnx.checker.check_model(onnx_model_pre)
+    onnx_model_post = onnx.load(fnn_post_onnx)
+    onnx.checker.check_model(onnx_model_post)
+
+    ort_session_pre = ort.InferenceSession(fnn_pre_onnx, providers=["CPUExecutionProvider"])
+    ort_session_post = ort.InferenceSession(fnn_post_onnx, providers=["CPUExecutionProvider"])
+
+    def to_numpy(tensor):
+        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+    # Check each exported half against its PyTorch counterpart on random inputs.
+    for _ in range(10000):
+        x = torch.rand(1, 28, 28, device=device)
+        torch_out_pre = fnn_pre(x)
+
+        ort_inputs_pre = {ort_session_pre.get_inputs()[0].name: to_numpy(x)}
+        ort_outs_pre = ort_session_pre.run(None, ort_inputs_pre)
+
+        np.testing.assert_allclose(to_numpy(torch_out_pre), ort_outs_pre[0], rtol=1e-03, atol=1e-05)
+
+    logger.info("First exported model has been tested with ONNXRuntime, and the result looks good!")
+
+    for _ in range(10000):
+        x = torch.rand(1, 512, device=device)
+        torch_out_post = fnn_post(x)
+
+        ort_inputs_post = {ort_session_post.get_inputs()[0].name: to_numpy(x)}
+        ort_outs_post = ort_session_post.run(None, ort_inputs_post)
+
+        np.testing.assert_allclose(to_numpy(torch_out_post), ort_outs_post[0], rtol=1e-03, atol=1e-05)
+
+    logger.info("Second exported model has been tested with ONNXRuntime, and the result looks good!")
+
+
+def build_pruned_model(state_path, pruned_model_path):
+    model = FNN().to(device)
+    model.load_state_dict(torch.load(state_path))
+
+    parameters_to_prune = ((model.fc1, "weight"), (model.fc2, "weight"), (model.fc3, "weight"))
+    prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)
+
+    logger.info(
+        "Sparsity in fc1.weight: {:.2f}%".format(
+            100.0 * float(torch.sum(model.fc1.weight == 0)) / float(model.fc1.weight.nelement())
+        )
+    )
+    logger.info(
+        "Sparsity in fc2.weight: {:.2f}%".format(
+            100.0 * float(torch.sum(model.fc2.weight == 0)) / float(model.fc2.weight.nelement())
+        )
+    )
+    logger.info(
+        "Sparsity in fc3.weight: {:.2f}%".format(
+            100.0 * float(torch.sum(model.fc3.weight == 0)) / float(model.fc3.weight.nelement())
+        )
+    )
+    logger.info(
+        "Global sparsity: {:.2f}%".format(
+            100.0
+            * float(
+                torch.sum(model.fc1.weight == 0) + torch.sum(model.fc2.weight == 0) + torch.sum(model.fc3.weight == 0)
+            )
+            / float(model.fc1.weight.nelement() + model.fc2.weight.nelement() + model.fc3.weight.nelement())
+        )
+    )
+
+    # Make the pruning permanent: fold weight_orig * weight_mask back into a plain
+    # .weight so the saved module carries no pruning reparametrization or hooks.
+    for module, name in parameters_to_prune:
+        prune.remove(module, name)
+
+    logger.info("Finished pruning")
+
+    torch.save(model, pruned_model_path)
+
+
+if __name__ == "__main__":
+    logger.info("Building, training and exporting to ONNX a simple FNN and a pruned FNN.")
+
+    train(STATE_PATH)
+    test(STATE_PATH)
+    export_model(STATE_PATH, ONNX_PATH)
+    build_pruned_model(STATE_PATH, PRUNED_MODEL_PATH)
+    test(PRUNED_MODEL_PATH, is_state_dict=False)
+    export_model(PRUNED_MODEL_PATH, PRUNED_ONNX_PATH, is_state_dict=False)
+
+    logger.info("Building, training and exporting to ONNX two FNNs that are meant to be used in sequence.")
+
+    train_2_nets(FNN_PRE_PATH, FNN_POST_PATH)
+    test_2_nets(FNN_PRE_PATH, FNN_POST_PATH)
+    export_2_models(FNN_PRE_PATH, FNN_PRE_ONNX_PATH, FNN_POST_PATH, FNN_POST_ONNX_PATH)
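
For reference, a minimal sketch of how the two exported halves can be chained at inference time with onnxruntime alone, without PyTorch. The file names assume the SEED = 42 defaults in train.py above; the snippet is illustrative and not part of the patch:

    import numpy as np
    import onnxruntime as ort

    # Load the two halves produced by export_2_models.
    pre = ort.InferenceSession("fnn_pre_s42.onnx", providers=["CPUExecutionProvider"])
    post = ort.InferenceSession("fnn_post_s42.onnx", providers=["CPUExecutionProvider"])

    # A dummy 28x28 input in [0, 1), matching ToTensor()'s scaling.
    x = np.random.rand(1, 28, 28).astype(np.float32)
    features = pre.run(None, {pre.get_inputs()[0].name: x})[0]         # shape (1, 512)
    logits = post.run(None, {post.get_inputs()[0].name: features})[0]  # shape (1, 10)
    print("predicted class:", int(np.argmax(logits, axis=1)[0]))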