diff --git a/examples/mnist/nets/binary/FNN_14x14_bin0_s42.onnx b/examples/mnist/nets/binary/FNN_14x14_bin0_s42.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..a8bdf9d58a5ae4b51d570bb56a876e59787214ee
Binary files /dev/null and b/examples/mnist/nets/binary/FNN_14x14_bin0_s42.onnx differ
diff --git a/examples/mnist/nets/binary/FNN_28x28_bin0_s42.onnx b/examples/mnist/nets/binary/FNN_28x28_bin0_s42.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..05840a584c7a64534a6f02d1346d33d6eb2e27d6
Binary files /dev/null and b/examples/mnist/nets/binary/FNN_28x28_bin0_s42.onnx differ
diff --git a/examples/mnist/nets/dummy_nn/train.py b/examples/mnist/nets/dummy_nn/train.py
index a612bcf4489231889e2854ce13bb8252a1587ab4..73b7e9bf3d52e749f88cbca64e3eef3253d5904e 100644
--- a/examples/mnist/nets/dummy_nn/train.py
+++ b/examples/mnist/nets/dummy_nn/train.py
@@ -12,11 +12,13 @@ from torch.nn.utils import prune
 from torch.utils.data import DataLoader
 from torchvision import transforms
 
-img_size = 14
+img_size = 28
 
 SEED = 42
 STATE_PATH = f"FNN_{img_size}x{img_size}_s{SEED}.pth"
 ONNX_PATH = f"FNN_{img_size}x{img_size}_s{SEED}.onnx"
+BIN_STATE_PATH = f"FNN_{img_size}x{img_size}_bin0_s{SEED}.pth"
+BIN_ONNX_PATH = f"FNN_{img_size}x{img_size}_bin0_s{SEED}.onnx"
 PRUNED_MODEL_PATH = f"FNN_{img_size}x{img_size}_pruned_s{SEED}.pkl"
 PRUNED_ONNX_PATH = f"FNN_{img_size}x{img_size}_pruned_s{SEED}.onnx"
 FNN_PRE_PATH = f"FNN_28x28_pre_s{SEED}.pth"
@@ -81,14 +83,35 @@ class FNN(nn.Module):
         return x
 
 
+class FNN_bin(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.flatten = nn.Flatten()
+        self.fc1 = nn.Linear(img_size * img_size, 10)
+        self.fc2 = nn.Linear(10, 10)
+        self.fc3 = nn.Linear(10, 2)
+
+    def forward(self, x):
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.fc2(x)
+        x = F.relu(x)
+        x = self.fc3(x)
+        return x
+
+
 transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([img_size, img_size])])
 
 
-def train(state_dict):
+def train(state_dict, is_bin=False):
     trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
     trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
 
-    model = FNN().to(device)
+    if is_bin:
+        model = FNN_bin().to(device)
+    else:
+        model = FNN().to(device)
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
@@ -96,8 +119,11 @@ def train(state_dict):
         running_loss = 0.0
         for i, data in enumerate(trainloader, 0):
             inputs, labels = data[0].to(device), data[1].to(device)
+            if is_bin:
+                labels = torch.Tensor([1 if item == 0 else 0 for item in labels]).long().to(device)
             optimizer.zero_grad()
             outputs = model(inputs).to(device)
+            outputs = model(inputs).to(device)
             loss = criterion(outputs, labels)
             loss.backward()
             optimizer.step()
@@ -112,7 +138,7 @@ def train(state_dict):
     torch.save(model.state_dict(), state_dict)
 
 
-def train_2_nets(cnn_path, fnn_path):
+def train_2_nets(fnn_pre_path, fnn_post_path):
     trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
     trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
 
@@ -139,16 +165,25 @@ def train_2_nets(cnn_path, fnn_path):
 
     logger.info("Finished training")
 
-    torch.save(fnn_pre.state_dict(), cnn_path)
-    torch.save(fnn_post.state_dict(), fnn_path)
+    torch.save(fnn_pre.state_dict(), fnn_pre_path)
+    torch.save(fnn_post.state_dict(), fnn_post_path)
 
 
-def test(model_path, is_state_dict=True):
-    classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
+def test(model_path, is_state_dict=True, is_bin=False):
+    if is_bin:
+        classes = (0, 1)
+    else:
+        classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
     testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
     testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
 
-    if is_state_dict:
+    if is_bin and (not is_state_dict):
+        raise ValueError("Binary model required to be loaded and exported via state dicts.")
+
+    if is_bin:
+        net = FNN_bin().to(device)
+        net.load_state_dict(torch.load(model_path))
+    elif is_state_dict:
         net = FNN().to(device)
         net.load_state_dict(torch.load(model_path))
     else:
@@ -160,12 +195,16 @@ def test(model_path, is_state_dict=True):
     with torch.no_grad():
         for data in testloader:
             images, labels = data[0].to(device), data[1].to(device)
+            if is_bin:
+                labels = torch.Tensor([1 if item == 0 else 0 for item in labels]).long().to(device)
             outputs = net(images)
             _, predicted = torch.max(outputs.data, 1)
             total += labels.size(0)
             correct += (predicted == labels).sum().item()
 
-    if is_state_dict:
+    if is_bin:
+        logger.info(f"Accuracy of the binary network on the 10000 test images: {100 * correct // total} %")
+    elif is_state_dict:
         logger.info(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
     else:
         logger.info(f"Accuracy of the pruned network on the 10000 test images: {100 * correct // total} %")
@@ -176,6 +215,8 @@ def test(model_path, is_state_dict=True):
     with torch.no_grad():
         for data in testloader:
             images, labels = data[0].to(device), data[1].to(device)
+            if is_bin:
+                labels = torch.Tensor([1 if item == 0 else 0 for item in labels]).long().to(device)
             outputs = net(images)
             _, predictions = torch.max(outputs, 1)
             for label, prediction in zip(labels, predictions):
@@ -231,8 +272,14 @@ def test_2_nets(fnn_pre_path, fnn_post_path):
         logger.info(f"Accuracy for class: {classname} is {accuracy:.1f} %")
 
 
-def export_model(model_path, onnx_path, is_state_dict=True):
-    if is_state_dict:
+def export_model(model_path, onnx_path, is_state_dict=True, is_bin=False):
+    if is_bin and (not is_state_dict):
+        raise ValueError("Binary model required to be loaded and exported via state dicts.")
+
+    if is_bin:
+        model = FNN_bin().to(device)
+        model.load_state_dict(torch.load(model_path))
+    elif is_state_dict:
         model = FNN().to(device)
         model.load_state_dict(torch.load(model_path))
     else:
@@ -240,12 +287,14 @@ def export_model(model_path, onnx_path, is_state_dict=True):
     x = torch.rand(1, img_size, img_size, device=device)
 
     torch.onnx.export(model=model, args=x, f=onnx_path, export_params=True)
-    if is_state_dict:
+    if is_bin:
+        logger.info("Binary model exported successfully")
+    elif is_state_dict:
         logger.info("Model exported successfully")
     else:
         logger.info("Pruned model exported successfully")
 
-    test_onnx(model_path, onnx_path, is_state_dict)
+    test_onnx(model_path, onnx_path, is_state_dict, is_bin)
 
 
 def export_2_models(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx):
@@ -266,8 +315,14 @@ def export_2_models(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx):
     test_2_onnx(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx)
 
 
-def test_onnx(model_path, onnx_path, is_state_dict=True):
-    if is_state_dict:
+def test_onnx(model_path, onnx_path, is_state_dict=True, is_bin=False):
+    if is_bin and (not is_state_dict):
+        raise ValueError("Binary model required to be loaded and exported via state dicts.")
+
+    if is_bin:
+        model = FNN_bin().to(device)
+        model.load_state_dict(torch.load(model_path))
+    elif is_state_dict:
         model = FNN().to(device)
         model.load_state_dict(torch.load(model_path))
     else:
@@ -290,7 +345,9 @@ def test_onnx(model_path, onnx_path, is_state_dict=True):
 
         np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
 
-    if is_state_dict:
+    if is_bin:
+        logger.info("Exported binary model has been tested with ONNXRuntime, and the result looks good!")
+    elif is_state_dict:
         logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")
     else:
         logger.info("Exported pruned model has been tested with ONNXRuntime, and the result looks good!")
@@ -374,8 +431,9 @@ def build_pruned_model(state_dict, pruned_model_path):
 
 
 if __name__ == "__main__":
-    logger.info("Building, training and exporting to ONNX a simple FNN and a pruned FNN.")
+    # TODO: Add an argument parser?
 
+    logger.info("Building, training and exporting to ONNX a simple FNN and a pruned FNN.")
     train(STATE_PATH)
     test(STATE_PATH)
     export_model(STATE_PATH, ONNX_PATH)
@@ -384,7 +442,11 @@ if __name__ == "__main__":
     export_model(PRUNED_MODEL_PATH, PRUNED_ONNX_PATH, is_state_dict=False)
 
     logger.info("Building, training and exporting to ONNX two FNNs that are to be used successively.")
-
     train_2_nets(FNN_PRE_PATH, FNN_POST_PATH)
     test_2_nets(FNN_PRE_PATH, FNN_POST_PATH)
     export_2_models(FNN_PRE_PATH, FNN_PRE_ONNX_PATH, FNN_POST_PATH, FNN_POST_ONNX_PATH)
+
+    logger.info("Building, training and exporting to ONNX a binary FNN.")
+    train(BIN_STATE_PATH, is_bin=True)
+    test(BIN_STATE_PATH, is_bin=True)
+    export_model(BIN_STATE_PATH, BIN_ONNX_PATH, is_bin=True)