Skip to content
Snippets Groups Projects
Commit e19ddcbb authored by Aymeric Varasse's avatar Aymeric Varasse :innocent:
Browse files

[exps] Add binary models (sort 0 from others)

parent f1e893ff
No related branches found
No related tags found
No related merge requests found
File added
File added
...@@ -12,11 +12,13 @@ from torch.nn.utils import prune ...@@ -12,11 +12,13 @@ from torch.nn.utils import prune
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import transforms from torchvision import transforms
img_size = 14 img_size = 28
SEED = 42 SEED = 42
STATE_PATH = f"FNN_{img_size}x{img_size}_s{SEED}.pth" STATE_PATH = f"FNN_{img_size}x{img_size}_s{SEED}.pth"
ONNX_PATH = f"FNN_{img_size}x{img_size}_s{SEED}.onnx" ONNX_PATH = f"FNN_{img_size}x{img_size}_s{SEED}.onnx"
BIN_STATE_PATH = f"FNN_{img_size}x{img_size}_bin0_s{SEED}.pth"
BIN_ONNX_PATH = f"FNN_{img_size}x{img_size}_bin0_s{SEED}.onnx"
PRUNED_MODEL_PATH = f"FNN_{img_size}x{img_size}_pruned_s{SEED}.pkl" PRUNED_MODEL_PATH = f"FNN_{img_size}x{img_size}_pruned_s{SEED}.pkl"
PRUNED_ONNX_PATH = f"FNN_{img_size}x{img_size}_pruned_s{SEED}.onnx" PRUNED_ONNX_PATH = f"FNN_{img_size}x{img_size}_pruned_s{SEED}.onnx"
FNN_PRE_PATH = f"FNN_28x28_pre_s{SEED}.pth" FNN_PRE_PATH = f"FNN_28x28_pre_s{SEED}.pth"
...@@ -81,14 +83,35 @@ class FNN(nn.Module): ...@@ -81,14 +83,35 @@ class FNN(nn.Module):
return x return x
class FNN_bin(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(img_size * img_size, 10)
self.fc2 = nn.Linear(10, 10)
self.fc3 = nn.Linear(10, 2)
def forward(self, x):
x = self.flatten(x)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
return x
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([img_size, img_size])]) transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([img_size, img_size])])
def train(state_dict): def train(state_dict, is_bin=False):
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform) trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2) trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
model = FNN().to(device) if is_bin:
model = FNN_bin().to(device)
else:
model = FNN().to(device)
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
...@@ -96,8 +119,11 @@ def train(state_dict): ...@@ -96,8 +119,11 @@ def train(state_dict):
running_loss = 0.0 running_loss = 0.0
for i, data in enumerate(trainloader, 0): for i, data in enumerate(trainloader, 0):
inputs, labels = data[0].to(device), data[1].to(device) inputs, labels = data[0].to(device), data[1].to(device)
if is_bin:
labels = torch.Tensor([1 if item == 0 else 0 for item in labels]).long().to(device)
optimizer.zero_grad() optimizer.zero_grad()
outputs = model(inputs).to(device) outputs = model(inputs).to(device)
outputs = model(inputs).to(device)
loss = criterion(outputs, labels) loss = criterion(outputs, labels)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
...@@ -112,7 +138,7 @@ def train(state_dict): ...@@ -112,7 +138,7 @@ def train(state_dict):
torch.save(model.state_dict(), state_dict) torch.save(model.state_dict(), state_dict)
def train_2_nets(cnn_path, fnn_path): def train_2_nets(fnn_pre_path, fnn_post_path):
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor()) trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2) trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
...@@ -139,16 +165,25 @@ def train_2_nets(cnn_path, fnn_path): ...@@ -139,16 +165,25 @@ def train_2_nets(cnn_path, fnn_path):
logger.info("Finished training") logger.info("Finished training")
torch.save(fnn_pre.state_dict(), cnn_path) torch.save(fnn_pre.state_dict(), fnn_pre_path)
torch.save(fnn_post.state_dict(), fnn_path) torch.save(fnn_post.state_dict(), fnn_post_path)
def test(model_path, is_state_dict=True): def test(model_path, is_state_dict=True, is_bin=False):
classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9) if is_bin:
classes = (0, 1)
else:
classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform) testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2) testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
if is_state_dict: if is_bin and (not is_state_dict):
raise ValueError("Binary model required to be loaded and exported via state dicts.")
if is_bin:
net = FNN_bin().to(device)
net.load_state_dict(torch.load(model_path))
elif is_state_dict:
net = FNN().to(device) net = FNN().to(device)
net.load_state_dict(torch.load(model_path)) net.load_state_dict(torch.load(model_path))
else: else:
...@@ -160,12 +195,16 @@ def test(model_path, is_state_dict=True): ...@@ -160,12 +195,16 @@ def test(model_path, is_state_dict=True):
with torch.no_grad(): with torch.no_grad():
for data in testloader: for data in testloader:
images, labels = data[0].to(device), data[1].to(device) images, labels = data[0].to(device), data[1].to(device)
if is_bin:
labels = torch.Tensor([1 if item == 0 else 0 for item in labels]).long().to(device)
outputs = net(images) outputs = net(images)
_, predicted = torch.max(outputs.data, 1) _, predicted = torch.max(outputs.data, 1)
total += labels.size(0) total += labels.size(0)
correct += (predicted == labels).sum().item() correct += (predicted == labels).sum().item()
if is_state_dict: if is_bin:
logger.info(f"Accuracy of the binary network on the 10000 test images: {100 * correct // total} %")
elif is_state_dict:
logger.info(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %") logger.info(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
else: else:
logger.info(f"Accuracy of the pruned network on the 10000 test images: {100 * correct // total} %") logger.info(f"Accuracy of the pruned network on the 10000 test images: {100 * correct // total} %")
...@@ -176,6 +215,8 @@ def test(model_path, is_state_dict=True): ...@@ -176,6 +215,8 @@ def test(model_path, is_state_dict=True):
with torch.no_grad(): with torch.no_grad():
for data in testloader: for data in testloader:
images, labels = data[0].to(device), data[1].to(device) images, labels = data[0].to(device), data[1].to(device)
if is_bin:
labels = torch.Tensor([1 if item == 0 else 0 for item in labels]).long().to(device)
outputs = net(images) outputs = net(images)
_, predictions = torch.max(outputs, 1) _, predictions = torch.max(outputs, 1)
for label, prediction in zip(labels, predictions): for label, prediction in zip(labels, predictions):
...@@ -231,8 +272,14 @@ def test_2_nets(fnn_pre_path, fnn_post_path): ...@@ -231,8 +272,14 @@ def test_2_nets(fnn_pre_path, fnn_post_path):
logger.info(f"Accuracy for class: {classname} is {accuracy:.1f} %") logger.info(f"Accuracy for class: {classname} is {accuracy:.1f} %")
def export_model(model_path, onnx_path, is_state_dict=True): def export_model(model_path, onnx_path, is_state_dict=True, is_bin=False):
if is_state_dict: if is_bin and (not is_state_dict):
raise ValueError("Binary model required to be loaded and exported via state dicts.")
if is_bin:
model = FNN_bin().to(device)
model.load_state_dict(torch.load(model_path))
elif is_state_dict:
model = FNN().to(device) model = FNN().to(device)
model.load_state_dict(torch.load(model_path)) model.load_state_dict(torch.load(model_path))
else: else:
...@@ -240,12 +287,14 @@ def export_model(model_path, onnx_path, is_state_dict=True): ...@@ -240,12 +287,14 @@ def export_model(model_path, onnx_path, is_state_dict=True):
x = torch.rand(1, img_size, img_size, device=device) x = torch.rand(1, img_size, img_size, device=device)
torch.onnx.export(model=model, args=x, f=onnx_path, export_params=True) torch.onnx.export(model=model, args=x, f=onnx_path, export_params=True)
if is_state_dict: if is_bin:
logger.info("Binary model exported successfully")
elif is_state_dict:
logger.info("Model exported successfully") logger.info("Model exported successfully")
else: else:
logger.info("Pruned model exported successfully") logger.info("Pruned model exported successfully")
test_onnx(model_path, onnx_path, is_state_dict) test_onnx(model_path, onnx_path, is_state_dict, is_bin)
def export_2_models(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx): def export_2_models(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx):
...@@ -266,8 +315,14 @@ def export_2_models(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx): ...@@ -266,8 +315,14 @@ def export_2_models(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx):
test_2_onnx(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx) test_2_onnx(fnn_pre_path, fnn_pre_onnx, fnn_post_path, fnn_post_onnx)
def test_onnx(model_path, onnx_path, is_state_dict=True): def test_onnx(model_path, onnx_path, is_state_dict=True, is_bin=False):
if is_state_dict: if is_bin and (not is_state_dict):
raise ValueError("Binary model required to be loaded and exported via state dicts.")
if is_bin:
model = FNN_bin().to(device)
model.load_state_dict(torch.load(model_path))
elif is_state_dict:
model = FNN().to(device) model = FNN().to(device)
model.load_state_dict(torch.load(model_path)) model.load_state_dict(torch.load(model_path))
else: else:
...@@ -290,7 +345,9 @@ def test_onnx(model_path, onnx_path, is_state_dict=True): ...@@ -290,7 +345,9 @@ def test_onnx(model_path, onnx_path, is_state_dict=True):
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
if is_state_dict: if is_bin:
logger.info("Exported binary model has been tested with ONNXRuntime, and the result looks good!")
elif is_state_dict:
logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!") logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")
else: else:
logger.info("Exported pruned model has been tested with ONNXRuntime, and the result looks good!") logger.info("Exported pruned model has been tested with ONNXRuntime, and the result looks good!")
...@@ -374,8 +431,9 @@ def build_pruned_model(state_dict, pruned_model_path): ...@@ -374,8 +431,9 @@ def build_pruned_model(state_dict, pruned_model_path):
if __name__ == "__main__": if __name__ == "__main__":
logger.info("Building, training and exporting to ONNX a simple FNN and a pruned FNN.") # TODO: Add an argument parser?
logger.info("Building, training and exporting to ONNX a simple FNN and a pruned FNN.")
train(STATE_PATH) train(STATE_PATH)
test(STATE_PATH) test(STATE_PATH)
export_model(STATE_PATH, ONNX_PATH) export_model(STATE_PATH, ONNX_PATH)
...@@ -384,7 +442,11 @@ if __name__ == "__main__": ...@@ -384,7 +442,11 @@ if __name__ == "__main__":
export_model(PRUNED_MODEL_PATH, PRUNED_ONNX_PATH, is_state_dict=False) export_model(PRUNED_MODEL_PATH, PRUNED_ONNX_PATH, is_state_dict=False)
logger.info("Building, training and exporting to ONNX two FNNs that are to be used successively.") logger.info("Building, training and exporting to ONNX two FNNs that are to be used successively.")
train_2_nets(FNN_PRE_PATH, FNN_POST_PATH) train_2_nets(FNN_PRE_PATH, FNN_POST_PATH)
test_2_nets(FNN_PRE_PATH, FNN_POST_PATH) test_2_nets(FNN_PRE_PATH, FNN_POST_PATH)
export_2_models(FNN_PRE_PATH, FNN_PRE_ONNX_PATH, FNN_POST_PATH, FNN_POST_ONNX_PATH) export_2_models(FNN_PRE_PATH, FNN_PRE_ONNX_PATH, FNN_POST_PATH, FNN_POST_ONNX_PATH)
logger.info("Building, training and exporting to ONNX a binary FNN.")
train(BIN_STATE_PATH, is_bin=True)
test(BIN_STATE_PATH, is_bin=True)
export_model(BIN_STATE_PATH, BIN_ONNX_PATH, is_bin=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment