diff --git a/examples/arithmetic/FNN_s42.onnx b/examples/arithmetic/FNN_s42.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..94e1c0b8837b0de4e0369d54291549fb039cc3f1
Binary files /dev/null and b/examples/arithmetic/FNN_s42.onnx differ
diff --git a/examples/arithmetic/train.py b/examples/arithmetic/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..2661fc21f5e592cae803d29bee9a62e72c5e3b0e
--- /dev/null
+++ b/examples/arithmetic/train.py
@@ -0,0 +1,157 @@
+import os
+
+import numpy as np
+import onnx
+import onnxruntime as ort
+import torch
+import torch.onnx
+import torch.optim as optim
+from loguru import logger
+from torch import nn
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+SEED = 42
+STATE_PATH = f"FNN_s{SEED}.pth"
+ONNX_PATH = f"FNN_s{SEED}.onnx"
+INPUT_ARRAY = "data.npy"
+TEST_INPUT_ARRAY = "test_data.npy"
+TARGET_ARRAY = "target.npy"
+TEST_TARGET_ARRAY = "test_target.npy"
+
+torch.manual_seed(SEED)
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+logger.info(f"Using device {device}")
+num_epoch = 2
+batch_size = 4
+
+
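+# Simple feed-forward network: 3 input features -> two hidden layers of 128 units -> 1 output.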
+class FNN(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(3, 128)
+        self.fc2 = nn.Linear(128, 128)
+        self.fc3 = nn.Linear(128, 1)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.fc2(x)
+        x = F.relu(x)
+        x = self.fc3(x)
+        return x
+
+
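+# Dataset that loads pre-generated inputs and targets from .npy files on disk.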
+class ArithmeticDataset(Dataset):
+    def __init__(self, input_array, target_array, root_dir):
+        self.input_array = np.load(input_array).astype(np.float32)
+        self.target_array = np.load(target_array).astype(np.float32)
+        self.root_dir = root_dir
+
+    def __len__(self):
+        return len(self.input_array)
+
+    def __getitem__(self, idx):
+        return [self.input_array[idx], self.target_array[idx]]
+
+
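+# Train the FNN with SGD and MSE loss on the arithmetic dataset, then save its state dict.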
+def train(state_dict):
+    trainset = ArithmeticDataset(
+        input_array=INPUT_ARRAY, target_array=TARGET_ARRAY, root_dir=os.path.dirname(INPUT_ARRAY)
+    )
+    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
+
+    model = FNN().to(device)
+    criterion = nn.MSELoss()
+    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
+
+    for epoch in range(num_epoch):
+        running_loss = 0.0
+        for i, data in enumerate(trainloader):
+            inputs, labels = data[0].to(device), data[1].to(device)
+            optimizer.zero_grad()
+            outputs = model(inputs).squeeze(-1)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+            running_loss += loss.item()
+            if i % 2000 == 1999:
+                logger.info(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
+                running_loss = 0.0
+
+    logger.info("Finished training")
+
+    torch.save(model.state_dict(), state_dict)
+
+
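+# Evaluate the trained network on the held-out test set and report the RMSE.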
+def test(model_path):
+    testset = ArithmeticDataset(
+        input_array=TEST_INPUT_ARRAY, target_array=TEST_TARGET_ARRAY, root_dir=os.path.dirname(TEST_INPUT_ARRAY)
+    )
+    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
+
+    net = FNN().to(device)
+    net.load_state_dict(torch.load(model_path, map_location=device))
+    net.eval()
+
+    error = 0
+    total = len(testset)
+
+    with torch.no_grad():
+        for data in testloader:
+            inputs, labels = data[0].to(device), data[1].to(device)
+            outputs = net(inputs).squeeze(-1)
+            error += ((outputs - labels) ** 2).sum().item()
+
+    logger.info(f"Average MSE of the network on the 10000 test inputs: {np.sqrt(error / total):.3f}")
+
+
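+# Export the trained model to ONNX and verify the exported graph against the PyTorch model.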
+def export_model(model_path, onnx_path):
+    model = FNN().to(device)
+    model.load_state_dict(torch.load(model_path, map_location=device))
+    model.eval()
+    x = torch.rand(1, 3, device=device)
+
+    torch.onnx.export(model=model, args=x, f=onnx_path, export_params=True)
+    logger.info("Model exported successfully")
+
+    test_onnx(model_path, onnx_path)
+
+
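+# Check that the ONNX Runtime outputs match the PyTorch outputs on random inputs.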
+def test_onnx(model_path, onnx_path):
+    model = FNN().to(device)
+    model.load_state_dict(torch.load(model_path, map_location=device))
+    model.eval()
+
+    onnx_model = onnx.load(onnx_path)
+    onnx.checker.check_model(onnx_model)
+
+    ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
+
+    def to_numpy(tensor):
+        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+    for _ in range(10000):
+        x = torch.rand(1, 3, device=device)
+        torch_out = model(x)
+
+        ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
+        ort_outs = ort_session.run(None, ort_inputs)
+
+        np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
+
+    logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")
+
+
+if __name__ == "__main__":
+    train(STATE_PATH)
+    test(STATE_PATH)
+    export_model(STATE_PATH, ONNX_PATH)