diff --git a/examples/mnist/nets/binary/FNN_14x14_bin0_s42.onnx b/examples/mnist/nets/binary/FNN_14x14_bin0_s42.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a8bdf9d58a5ae4b51d570bb56a876e59787214ee Binary files /dev/null and b/examples/mnist/nets/binary/FNN_14x14_bin0_s42.onnx differ diff --git a/examples/mnist/nets/binary/FNN_28x28_bin0_s42.onnx b/examples/mnist/nets/binary/FNN_28x28_bin0_s42.onnx new file mode 100644 index 0000000000000000000000000000000000000000..05840a584c7a64534a6f02d1346d33d6eb2e27d6 Binary files /dev/null and b/examples/mnist/nets/binary/FNN_28x28_bin0_s42.onnx differ diff --git a/examples/mnist/nets/dummy_nn/train.py b/examples/mnist/nets/dummy_nn/train.py index a612bcf4489231889e2854ce13bb8252a1587ab4..73b7e9bf3d52e749f88cbca64e3eef3253d5904e 100644 --- a/examples/mnist/nets/dummy_nn/train.py +++ b/examples/mnist/nets/dummy_nn/train.py @@ -12,11 +12,13 @@ from torch.nn.utils import prune from torch.utils.data import DataLoader from torchvision import transforms -img_size = 14 +img_size = 28 SEED = 42 STATE_PATH = f"FNN_{img_size}x{img_size}_s{SEED}.pth" ONNX_PATH = f"FNN_{img_size}x{img_size}_s{SEED}.onnx" +BIN_STATE_PATH = f"FNN_{img_size}x{img_size}_bin0_s{SEED}.pth" +BIN_ONNX_PATH = f"FNN_{img_size}x{img_size}_bin0_s{SEED}.onnx" PRUNED_MODEL_PATH = f"FNN_{img_size}x{img_size}_pruned_s{SEED}.pkl" PRUNED_ONNX_PATH = f"FNN_{img_size}x{img_size}_pruned_s{SEED}.onnx" FNN_PRE_PATH = f"FNN_28x28_pre_s{SEED}.pth" @@ -81,14 +83,35 @@ class FNN(nn.Module): return x +class FNN_bin(nn.Module): + def __init__(self): + super().__init__() + self.flatten = nn.Flatten() + self.fc1 = nn.Linear(img_size * img_size, 10) + self.fc2 = nn.Linear(10, 10) + self.fc3 = nn.Linear(10, 2) + + def forward(self, x): + x = self.flatten(x) + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + x = F.relu(x) + x = self.fc3(x) + return x + + transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([img_size, img_size])]) -def train(state_dict): +def train(state_dict, is_bin=False): trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform) trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2) - model = FNN().to(device) + if is_bin: + model = FNN_bin().to(device) + else: + model = FNN().to(device) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) @@ -96,8 +119,11 @@ def train(state_dict): running_loss = 0.0 for i, data in enumerate(trainloader, 0): inputs, labels = data[0].to(device), data[1].to(device) + if is_bin: + labels = torch.Tensor([1 if item == 0 else 0 for item in labels]).long().to(device) optimizer.zero_grad() outputs = model(inputs).to(device) + outputs = model(inputs).to(device) loss = criterion(outputs, labels) loss.backward() optimizer.step() @@ -112,7 +138,7 @@ def train(state_dict): torch.save(model.state_dict(), state_dict) -def train_2_nets(cnn_path, fnn_path): +def train_2_nets(fnn_pre_path, fnn_post_path): trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor()) trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2) @@ -139,16 +165,25 @@ def train_2_nets(cnn_path, fnn_path): logger.info("Finished training") - torch.save(fnn_pre.state_dict(), cnn_path) - torch.save(fnn_post.state_dict(), fnn_path) + torch.save(fnn_pre.state_dict(), fnn_pre_path) + torch.save(fnn_post.state_dict(), fnn_post_path) -def test(model_path, is_state_dict=True): - classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9) +def test(model_path, is_state_dict=True, is_bin=False): + if is_bin: + classes = (0, 1) + else: + classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9) testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform) testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2) - if is_state_dict: + if is_bin and (not is_state_dict): + raise ValueError("Binary model required to be loaded and exported via state dicts.") + + if is_bin: + net = FNN_bin().to(device) + net.load_state_dict(torch.load(model_path)) + elif is_state_dict: net = FNN().to(device) net.load_state_dict(torch.load(model_path)) else: @@ -160,12 +195,16 @@ def test(model_path, is_state_dict=True): with torch.no_grad(): for data in testloader: images, labels = data[0].to(device), data[1].to(device) + if is_bin: + labels = torch.Tensor([1 if item == 0 else 0 for item in labels]).long().to(device) outputs = net(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() - if is_state_dict: + if is_bin: + logger.info(f"Accuracy of the binary network on the 10000 test images: {100 * correct // total} %") + elif is_state_dict: logger.info(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %") else: logger.info(f"Accuracy of the pruned network on the 10000 test images: {100 * correct // total} %") @@ -176,6 +215,8 @@ def test(model_path, is_state_dict=True): with torch.no_grad(): for data in testloader: images, labels = data[0].to(device), data[1].to(device) + if is_bin: + labels = torch.Tensor([1 if item == 0 else 0 for item in labels]).long().to(device) outputs = net(images) _, predictions = torch.max(outputs, 1) for label, prediction in zip(labels, predictions): @@ -231,8 +272,14 @@ def test_2_nets(fnn_pre_path, fnn_post_path): logger.info(f"Accuracy for class: {classname} is {accuracy:.1f} %") -def export_model(model_path, onnx_path, is_state_dict=True): - if is_state_dict: +def export_model(model_path, onnx_path, is_state_dict=True, is_bin=False): + if is_bin and (not is_state_dict): + raise ValueError("Binary model required to be loaded and exported via state dicts.") + + if is_bin: + model = FNN_bin().to(device) + model.load_state_dict(torch.load(model_path)) + elif is_state_dict: model = FNN().to(device) model.load_state_dict(torch.load(model_path)) else: @@ -240,12 +287,14 @@ def export_model(model_path, onnx_path, is_state_dict=True): x = torch.rand(1, img_size, img_size, device=device) torch.onnx.export(model=model, args=x, f=onnx_path, export_params=True) - if is_state_dict: + if is_bin: + logger.info("Binary model exported successfully") + elif is_state_dict: logger.info("Model exported successfully") else: logger.info("Pruned model exported successfully") - test_onnx(model_path, onnx_path, is_state_dict) + test_onnx(model_path, onnx_path, is_state_dict, is_bin) def export_2_models(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx): @@ -266,8 +315,14 @@ def export_2_models(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx): test_2_onnx(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx) -def test_onnx(model_path, onnx_path, is_state_dict=True): - if is_state_dict: +def test_onnx(model_path, onnx_path, is_state_dict=True, is_bin=False): + if is_bin and (not is_state_dict): + raise ValueError("Binary model required to be loaded and exported via state dicts.") + + if is_bin: + model = FNN_bin().to(device) + model.load_state_dict(torch.load(model_path)) + elif is_state_dict: model = FNN().to(device) model.load_state_dict(torch.load(model_path)) else: @@ -290,7 +345,9 @@ def test_onnx(model_path, onnx_path, is_state_dict=True): np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) - if is_state_dict: + if is_bin: + logger.info("Exported binary model has been tested with ONNXRuntime, and the result looks good!") + elif is_state_dict: logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!") else: logger.info("Exported pruned model has been tested with ONNXRuntime, and the result looks good!") @@ -374,8 +431,9 @@ def build_pruned_model(state_dict, pruned_model_path): if __name__ == "__main__": - logger.info("Building, training and exporting to ONNX a simple FNN and a pruned FNN.") + # TODO: Add an argument parser? + logger.info("Building, training and exporting to ONNX a simple FNN and a pruned FNN.") train(STATE_PATH) test(STATE_PATH) export_model(STATE_PATH, ONNX_PATH) @@ -384,7 +442,11 @@ if __name__ == "__main__": export_model(PRUNED_MODEL_PATH, PRUNED_ONNX_PATH, is_state_dict=False) logger.info("Building, training and exporting to ONNX two FNNs that are to be used successively.") - train_2_nets(FNN_PRE_PATH, FNN_POST_PATH) test_2_nets(FNN_PRE_PATH, FNN_POST_PATH) export_2_models(FNN_PRE_PATH, FNN_PRE_ONNX_PATH, FNN_POST_PATH, FNN_POST_ONNX_PATH) + + logger.info("Building, training and exporting to ONNX a binary FNN.") + train(BIN_STATE_PATH, is_bin=True) + test(BIN_STATE_PATH, is_bin=True) + export_model(BIN_STATE_PATH, BIN_ONNX_PATH, is_bin=True)