diff --git a/examples/arithmetic/FNN_s42.onnx b/examples/arithmetic/FNN_s42.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..94e1c0b8837b0de4e0369d54291549fb039cc3f1
Binary files /dev/null and b/examples/arithmetic/FNN_s42.onnx differ
diff --git a/examples/arithmetic/train.py b/examples/arithmetic/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..2661fc21f5e592cae803d29bee9a62e72c5e3b0e
--- /dev/null
+++ b/examples/arithmetic/train.py
@@ -0,0 +1,148 @@
+import os
+
+import numpy as np
+import onnx
+import onnxruntime as ort
+import torch
+import torch.onnx
+import torch.optim as optim
+from loguru import logger
+from torch import nn
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+SEED = 42
+STATE_PATH = f"FNN_s{SEED}.pth"
+ONNX_PATH = f"FNN_s{SEED}.onnx"
+INPUT_ARRAY = "data.npy"
+TEST_INPUT_ARRAY = "test_data.npy"
+TARGET_ARRAY = "target.npy"
+TEST_TARGET_ARRAY = "test_target.npy"
+
+torch.manual_seed(SEED)
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+logger.info(f"Using device {device}")
+num_epoch = 2
+batch_size = 4
+
+
+class FNN(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(3, 128)
+        self.fc2 = nn.Linear(128, 128)
+        self.fc3 = nn.Linear(128, 1)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.fc2(x)
+        x = F.relu(x)
+        x = self.fc3(x)
+        return x
+
+
+class ArithmeticDataset(Dataset):
+    def __init__(self, input_array, target_array, root_dir):
+        self.input_array = np.load(input_array).astype(np.float32)
+        self.target_array = np.load(target_array).astype(np.float32)
+        self.root_dir = root_dir
+
+    def __len__(self):
+        return len(self.input_array)
+
+    def __getitem__(self, idx):
+        return [self.input_array[idx], self.target_array[idx]]
+
+
+def train(state_path):
+    trainset = ArithmeticDataset(
+        input_array=INPUT_ARRAY, target_array=TARGET_ARRAY, root_dir=os.path.dirname(INPUT_ARRAY)
+    )
+    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
+
+    model = FNN().to(device)
+    criterion = nn.MSELoss()
+    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
+
+    for epoch in range(num_epoch):
+        running_loss = 0.0
+        for i, data in enumerate(trainloader):
+            inputs, labels = data[0].to(device), data[1].to(device)
+            optimizer.zero_grad()
+            outputs = model(inputs).squeeze()
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+            running_loss += loss.item()
+            if i % 2000 == 1999:
+                logger.info(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
+                running_loss = 0.0
+
+    logger.info("Finished training")
+
+    torch.save(model.state_dict(), state_path)
+
+
+def test(model_path):
+    testset = ArithmeticDataset(
+        input_array=TEST_INPUT_ARRAY, target_array=TEST_TARGET_ARRAY, root_dir=os.path.dirname(TEST_INPUT_ARRAY)
+    )
+    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
+
+    net = FNN().to(device)
+    net.load_state_dict(torch.load(model_path))
+
+    error = 0.0
+    total = len(testset)  # count samples directly so a partial final batch is not over-counted
+
+    with torch.no_grad():
+        for data in testloader:
+            inputs, labels = data[0].to(device), data[1].to(device)
+            outputs = net(inputs).squeeze()
+            error += ((outputs - labels) * (outputs - labels)).sum().item()
+
+    logger.info(f"RMSE of the network on the {total} test inputs: {np.sqrt(error / total):.3f}")
+
+
+def export_model(model_path, onnx_path):
+    model = FNN().to(device)
+    model.load_state_dict(torch.load(model_path))
+    x = torch.rand(1, 3, device=device)
+
+    torch.onnx.export(model=model, args=x, f=onnx_path, export_params=True)
+    logger.info("Model exported successfully")
+
+    test_onnx(model_path, onnx_path)
+
+
+def test_onnx(model_path, onnx_path):
+    model = FNN().to(device)
+    model.load_state_dict(torch.load(model_path))
+
+    onnx_model = onnx.load(onnx_path)
+    onnx.checker.check_model(onnx_model)
+
+    ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
+
+    def to_numpy(tensor):
+        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+    for _ in range(10000):
+        x = torch.rand(1, 3, device=device)
+        torch_out = model(x)
+
+        ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
+        ort_outs = ort_session.run(None, ort_inputs)
+
+        np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
+
+    logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")
+
+
+if __name__ == "__main__":
+    train(STATE_PATH)
+    test(STATE_PATH)
+    export_model(STATE_PATH, ONNX_PATH)
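
Note: train.py expects data.npy, target.npy, test_data.npy, and test_target.npy to sit next to the script, but this patch does not add a generator for them. The sketch below shows one way such arrays could be produced; the file name generate_data.py, the split sizes, and the choice of target (the sum of the three inputs) are assumptions for illustration, not necessarily how the committed arrays were built.

    # generate_data.py (hypothetical helper, not part of this patch)
    import numpy as np

    rng = np.random.default_rng(42)

    def make_split(n, input_path, target_path):
        # Assumption: the regression target is the sum of the three inputs.
        inputs = rng.uniform(0.0, 1.0, size=(n, 3)).astype(np.float32)
        targets = inputs.sum(axis=1).astype(np.float32)  # shape (n,), matches the squeezed model output
        np.save(input_path, inputs)
        np.save(target_path, targets)

    make_split(100_000, "data.npy", "target.npy")
    make_split(10_000, "test_data.npy", "test_target.npy")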