diff --git a/examples/mnist/nets/dummy_nn/FNN_s42.onnx b/examples/mnist/nets/dummy_nn/FNN_s42.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..e8505c08400e5465e0fdc56e11a0d66ed5c9327f
Binary files /dev/null and b/examples/mnist/nets/dummy_nn/FNN_s42.onnx differ
diff --git a/examples/mnist/nets/dummy_nn/pruned_FNN_s42.onnx b/examples/mnist/nets/dummy_nn/pruned_FNN_s42.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..a32cd072b7606a731da5e6dd186ad9c014abe541
Binary files /dev/null and b/examples/mnist/nets/dummy_nn/pruned_FNN_s42.onnx differ
diff --git a/examples/mnist/nets/dummy_nn/requirements.txt b/examples/mnist/nets/dummy_nn/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8fb1bde4ebee5bdcf8c2ccd1b4e1aa07974b6cfb
--- /dev/null
+++ b/examples/mnist/nets/dummy_nn/requirements.txt
@@ -0,0 +1,7 @@
+loguru
+numpy
+onnx
+onnxruntime
+torch
+torchaudio
+torchvision
diff --git a/examples/mnist/nets/dummy_nn/train.py b/examples/mnist/nets/dummy_nn/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d8505c9d8ebd7d7fa60767ecef1267938c617ab
--- /dev/null
+++ b/examples/mnist/nets/dummy_nn/train.py
@@ -0,0 +1,390 @@
+import numpy as np
+import onnx
+import onnxruntime as ort
+import torch
+import torch.onnx
+import torch.optim as optim
+import torchvision
+from loguru import logger
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils import prune
+from torch.utils.data import DataLoader
+from torchvision import transforms
+
+img_size = 14
+
+SEED = 42
+STATE_PATH = f"{img_size}x{img_size}_FNN_s{SEED}.pth"
+ONNX_PATH = f"{img_size}x{img_size}_FNN_s{SEED}.onnx"
+PRUNED_MODEL_PATH = f"pruned_{img_size}x{img_size}_FNN_s{SEED}.pkl"
+PRUNED_ONNX_PATH = f"pruned_{img_size}x{img_size}_FNN_s{SEED}.onnx"
+FNN_PRE_PATH = f"fnn_pre_s{SEED}.pth"
+FNN_PRE_ONNX_PATH = f"fnn_pre_s{SEED}.onnx"
+FNN_POST_PATH = f"fnn_post_s{SEED}.pth"
+FNN_POST_ONNX_PATH = f"fnn_post_s{SEED}.onnx"
+
+torch.manual_seed(SEED)
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+logger.info(f"Using device {device}")
+num_epoch = 2
+batch_size = 4
+
+
+class FNN_pre(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.flatten = nn.Flatten()
+        self.fc1 = nn.Linear(28 * 28, 512)
+        self.fc2 = nn.Linear(512, 512)
+
+    def forward(self, x):
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.fc2(x)
+        x = F.relu(x)
+        return x
+
+
+class FNN_post(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.flatten = nn.Flatten()
+        self.fc1 = nn.Linear(512, 256)
+        self.fc2 = nn.Linear(256, 10)
+
+    def forward(self, x):
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.fc2(x)
+        return x
+
+
+class FNN(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.flatten = nn.Flatten()
+        self.fc1 = nn.Linear(img_size * img_size, 10)
+        self.fc2 = nn.Linear(10, 10)
+        self.fc3 = nn.Linear(10, 10)
+
+    def forward(self, x):
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.fc2(x)
+        x = F.relu(x)
+        x = self.fc3(x)
+        return x
+
+
+transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([img_size, img_size])])
+
+
+def train(state_dict):
+    trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
+    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
+
+    model = FNN().to(device)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
+
+    for epoch in range(num_epoch):
+        running_loss = 0.0
+        for i, data in enumerate(trainloader, 0):
+            inputs, labels = data[0].to(device), data[1].to(device)
+            optimizer.zero_grad()
+            outputs = model(inputs).to(device)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+            running_loss += loss.item()
+            if i % 2000 == 1999:  # print every 2000 mini-batches
+                logger.info(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
+                running_loss = 0.0
+
+    logger.info("Finished training")
+
+    torch.save(model.state_dict(), state_dict)
+
+
+def train_2_nets(cnn_path, fnn_path):
+    trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
+    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
+
+    fnn_pre = FNN_pre().to(device)
+    fnn_post = FNN_post().to(device)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(fnn_post.parameters(), lr=0.001, momentum=0.9)
+
+    for epoch in range(num_epoch):
+        running_loss = 0.0
+        for i, data in enumerate(trainloader, 0):
+            inputs, labels = data[0].to(device), data[1].to(device)
+            optimizer.zero_grad()
+            outputs_pre = fnn_pre(inputs).to(device)
+            outputs_post = fnn_post(outputs_pre).to(device)
+            loss = criterion(outputs_post, labels)
+            loss.backward()
+            optimizer.step()
+
+            running_loss += loss.item()
+            if i % 2000 == 1999:
+                logger.info(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
+                running_loss = 0.0
+
+    logger.info("Finished training")
+
+    torch.save(fnn_pre.state_dict(), cnn_path)
+    torch.save(fnn_post.state_dict(), fnn_path)
+
+
+def test(model_path, is_state_dict=True):
+    classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
+    testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
+    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
+
+    if is_state_dict:
+        net = FNN().to(device)
+        net.load_state_dict(torch.load(model_path))
+    else:
+        net = torch.load(model_path)
+
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        for data in testloader:
+            images, labels = data[0].to(device), data[1].to(device)
+            outputs = net(images)
+            _, predicted = torch.max(outputs.data, 1)
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+
+    if is_state_dict:
+        logger.info(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
+    else:
+        logger.info(f"Accuracy of the pruned network on the 10000 test images: {100 * correct // total} %")
+
+    correct_pred = {classname: 0 for classname in classes}
+    total_pred = {classname: 0 for classname in classes}
+
+    with torch.no_grad():
+        for data in testloader:
+            images, labels = data[0].to(device), data[1].to(device)
+            outputs = net(images)
+            _, predictions = torch.max(outputs, 1)
+            for label, prediction in zip(labels, predictions):
+                if label == prediction:
+                    correct_pred[classes[label]] += 1
+                total_pred[classes[label]] += 1
+
+    for classname, correct_count in correct_pred.items():
+        accuracy = 100 * float(correct_count) / total_pred[classname]
+        logger.info(f"Accuracy for class: {classname} is {accuracy:.1f} %")
+
+
+def test_2_nets(fnn_pre_path, fnn_post_path):
+    classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
+    testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transforms.ToTensor())
+    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
+
+    fnn_pre = FNN_pre().to(device)
+    fnn_pre.load_state_dict(torch.load(fnn_pre_path))
+    fnn_post = FNN_post().to(device)
+    fnn_post.load_state_dict(torch.load(fnn_post_path))
+
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        for data in testloader:
+            images, labels = data[0].to(device), data[1].to(device)
+            outputs_pre = fnn_pre(images).to(device)
+            outputs_post = fnn_post(outputs_pre).to(device)
+            _, predicted = torch.max(outputs_post.data, 1)
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+
+    logger.info(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
+
+    correct_pred = {classname: 0 for classname in classes}
+    total_pred = {classname: 0 for classname in classes}
+
+    with torch.no_grad():
+        for data in testloader:
+            images, labels = data[0].to(device), data[1].to(device)
+            outputs_pre = fnn_pre(images).to(device)
+            outputs_post = fnn_post(outputs_pre).to(device)
+            _, predictions = torch.max(outputs_post, 1)
+            for label, prediction in zip(labels, predictions):
+                if label == prediction:
+                    correct_pred[classes[label]] += 1
+                total_pred[classes[label]] += 1
+
+    for classname, correct_count in correct_pred.items():
+        accuracy = 100 * float(correct_count) / total_pred[classname]
+        logger.info(f"Accuracy for class: {classname} is {accuracy:.1f} %")
+
+
+def export_model(model_path, onnx_path, is_state_dict=True):
+    if is_state_dict:
+        model = FNN().to(device)
+        model.load_state_dict(torch.load(model_path))
+    else:
+        model = torch.load(model_path)
+    x = torch.rand(1, img_size, img_size, device=device)
+
+    torch.onnx.export(model=model, args=x, f=onnx_path, export_params=True)
+    if is_state_dict:
+        logger.info("Model exported successfully")
+    else:
+        logger.info("Pruned model exported successfully")
+
+    test_onnx(model_path, onnx_path, is_state_dict)
+
+
+def export_2_models(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx):
+    fnn_pre = FNN_pre().to(device)
+    fnn_pre.load_state_dict(torch.load(fnn_pre_path))
+    x = torch.rand(1, 28, 28, device=device)
+
+    torch.onnx.export(model=fnn_pre, args=x, f=fnn_pre_onnx, export_params=True)
+    logger.info("First model exported successfully")
+
+    fnn_post = FNN_post().to(device)
+    fnn_post.load_state_dict(torch.load(fnn_post_path))
+    y = fnn_pre(x).to(device)
+
+    torch.onnx.export(model=fnn_post, args=y, f=fnn_post_onnx, export_params=True)
+    logger.info("Second model exported successfully")
+
+    test_2_onnx(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx)
+
+
+def test_onnx(model_path, onnx_path, is_state_dict=True):
+    if is_state_dict:
+        model = FNN().to(device)
+        model.load_state_dict(torch.load(model_path))
+    else:
+        model = torch.load(model_path)
+
+    onnx_model = onnx.load(onnx_path)
+    onnx.checker.check_model(onnx_model)
+
+    ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
+
+    def to_numpy(tensor):
+        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+    for _ in range(10000):
+        x = torch.rand(1, img_size, img_size, device=device)
+        torch_out = model(x)
+
+        ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
+        ort_outs = ort_session.run(None, ort_inputs)
+
+        np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
+
+    if is_state_dict:
+        logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")
+    else:
+        logger.info("Exported pruned model has been tested with ONNXRuntime, and the result looks good!")
+
+
+def test_2_onnx(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx):
+    fnn_pre = FNN_pre().to(device)
+    fnn_pre.load_state_dict(torch.load(fnn_pre_path))
+    fnn_post = FNN_post().to(device)
+    fnn_post.load_state_dict(torch.load(fnn_post_path))
+
+    onnx_model_pre = onnx.load(fnn_pre_onnx)
+    onnx.checker.check_model(onnx_model_pre)
+    onnx_model_post = onnx.load(fnn_post_onnx)
+    onnx.checker.check_model(onnx_model_post)
+
+    ort_session_pre = ort.InferenceSession(fnn_pre_onnx, providers=["CPUExecutionProvider"])
+    ort_session_post = ort.InferenceSession(fnn_post_onnx, providers=["CPUExecutionProvider"])
+
+    def to_numpy(tensor):
+        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+    for _ in range(10000):
+        x = torch.rand(1, 28, 28, device=device)
+        torch_out_pre = fnn_pre(x)
+
+        ort_inputs_pre = {ort_session_pre.get_inputs()[0].name: to_numpy(x)}
+        ort_outs_pre = ort_session_pre.run(None, ort_inputs_pre)
+
+        np.testing.assert_allclose(to_numpy(torch_out_pre), ort_outs_pre[0], rtol=1e-03, atol=1e-05)
+
+    logger.info("First exported model has been tested with ONNXRuntime, and the result looks good!")
+
+    for _ in range(10000):
+        x = torch.rand(1, 512, device=device)
+        torch_out_post = fnn_post(x)
+
+        ort_inputs_post = {ort_session_post.get_inputs()[0].name: to_numpy(x)}
+        ort_outs_post = ort_session_post.run(None, ort_inputs_post)
+
+        np.testing.assert_allclose(to_numpy(torch_out_post), ort_outs_post[0], rtol=1e-03, atol=1e-05)
+
+    logger.info("Second exported model has been tested with ONNXRuntime, and the result looks good!")
+
+
+def build_pruned_model(state_dict, pruned_model_path):
+    model = FNN().to(device)
+    model.load_state_dict(torch.load(state_dict))
+
+    parameters_to_prune = ((model.fc1, "weight"), (model.fc2, "weight"), (model.fc3, "weight"))
+    prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)
+
+    logger.info(
+        "Sparsity in fc1.weight: {:.2f}%".format(
+            100.0 * float(torch.sum(model.fc1.weight == 0)) / float(model.fc1.weight.nelement())
+        )
+    )
+    logger.info(
+        "Sparsity in fc2.weight: {:.2f}%".format(
+            100.0 * float(torch.sum(model.fc2.weight == 0)) / float(model.fc2.weight.nelement())
+        )
+    )
+    logger.info(
+        "Sparsity in fc3.weight: {:.2f}%".format(
+            100.0 * float(torch.sum(model.fc3.weight == 0)) / float(model.fc3.weight.nelement())
+        )
+    )
+    logger.info(
+        "Global sparsity: {:.2f}%".format(
+            100.0
+            * float(
+                torch.sum(model.fc1.weight == 0) + torch.sum(model.fc2.weight == 0) + torch.sum(model.fc3.weight == 0)
+            )
+            / float(model.fc1.weight.nelement() + model.fc2.weight.nelement() + model.fc3.weight.nelement())
+        )
+    )
+
+    logger.info("Finished pruning")
+
+    torch.save(model, pruned_model_path)
+
+
+if __name__ == "__main__":
+    logger.info("Building, training and exporting to ONNX a simple FNN and a pruned FNN.")
+
+    train(STATE_PATH)
+    test(STATE_PATH)
+    export_model(STATE_PATH, ONNX_PATH)
+    build_pruned_model(STATE_PATH, PRUNED_MODEL_PATH)
+    test(PRUNED_MODEL_PATH, is_state_dict=False)
+    export_model(PRUNED_MODEL_PATH, PRUNED_ONNX_PATH, is_state_dict=False)
+
+    logger.info("Building, training and exporting to ONNX two FNNs that are to be used successively.")
+
+    train_2_nets(FNN_PRE_PATH, FNN_POST_PATH)
+    test_2_nets(FNN_PRE_PATH, FNN_POST_PATH)
+    export_2_models(FNN_PRE_PATH, FNN_PRE_ONNX_PATH, FNN_POST_PATH, FNN_POST_ONNX_PATH)