diff --git a/config/caisar-detection-data.conf b/config/caisar-detection-data.conf index 775ab20a93d389875b3339fefaa7163b0d35a633..daa570a3a03f218d29be38bb16cea554f4c406f1 100644 --- a/config/caisar-detection-data.conf +++ b/config/caisar-detection-data.conf @@ -110,6 +110,18 @@ command = "%e -mp %{nnet-onnx} -pp %f --timeout %t --domain zono --split --score driver = "%{config}/drivers/pyrat.drv" use_at_auto_level = 1 +[ATP pyrat-arithmetic] +name = "PyRAT" +alternative = "arithmetic" +exec = "pyrat.py" +exec = "pyrat" +version_switch = "--version" +version_regexp = "PyRAT \\([0-9.]+\\)" +version_ok = "1.1" +command = "%e -mp %{nnet-onnx} -pp %f --timeout %t --domain poly --domain zono --split --scorer coef --initial --allow_smaller_size --booster always" +driver = "%{config}/drivers/pyrat.drv" +use_at_auto_level = 1 + [ATP nnenum] name = "nnenum" exec = "nnenum.sh" diff --git a/examples/arithmetic/FNN_s42.onnx b/examples/arithmetic/FNN_s42.onnx new file mode 100644 index 0000000000000000000000000000000000000000..94e1c0b8837b0de4e0369d54291549fb039cc3f1 Binary files /dev/null and b/examples/arithmetic/FNN_s42.onnx differ diff --git a/examples/arithmetic/arithmetic.why b/examples/arithmetic/arithmetic.why new file mode 100644 index 0000000000000000000000000000000000000000..f4fab6e48617fb09d75c7f52314fcc2998ca06e1 --- /dev/null +++ b/examples/arithmetic/arithmetic.why @@ -0,0 +1,24 @@ +theory Arithmethic + use ieee_float.Float64 + use int.Int + use caisar.types.Vector + use caisar.model.Model + + type input = vector t + + let constant eps : t = 0.5 + val constant model_filename: string + let constant nn : model = read_model model_filename + + predicate valid_input (i: input) = + (-5.0:t) .<= i[0] .<= (5.0:t) + /\ (-5.0:t) .<= i[1] .<= (5.0:t) + /\ (-5.0:t) .<= i[2] .<= (5.0:t) + + let runP1 (i: input) : t + requires { has_length i 3 } + requires { valid_input i } + ensures { result .- i[0] .+ i[1] .+ i[2] .<= eps } = + (nn @@ i)[0] + +end diff --git 
import numpy as np


def f(x: np.ndarray) -> np.ndarray:
    """Return the linear combination ``x[0] - x[1] - x[2]`` of one input row.

    Args:
        x: A one-dimensional array with at least three entries.

    Returns:
        A 0-dimensional array holding ``x[0] - x[1] - x[2]``.
    """
    return np.array(x[0] - x[1] - x[2])


def make_split(size: int) -> tuple[np.ndarray, np.ndarray]:
    """Draw ``size`` random 3-feature rows and compute their targets.

    Inputs are sampled from a standard normal distribution; the targets are
    obtained by applying :func:`f` to every row of *those same* inputs.

    Args:
        size: Number of rows to generate.

    Returns:
        A pair ``(inputs, targets)`` of shapes ``(size, 3)`` and ``(size,)``.
    """
    inputs = np.random.normal(size=(size, 3))
    targets = np.apply_along_axis(f, 1, inputs)
    return inputs, targets


def main() -> None:
    """Generate the training and test splits and save them as ``.npy`` files."""
    train_inputs, train_targets = make_split(10000)
    # The test targets must be derived from the test inputs. The previous
    # version recomputed them from the training inputs, so ``test_target.npy``
    # did not correspond to ``test_data.npy``.
    test_inputs, test_targets = make_split(10000)
    np.save(file="data.npy", arr=train_inputs)
    np.save(file="target.npy", arr=train_targets)
    np.save(file="test_data.npy", arr=test_inputs)
    np.save(file="test_target.npy", arr=test_targets)


if __name__ == "__main__":
    main()
class ArithmeticDataset(Dataset):
    """Map-style dataset pairing input rows with their regression targets.

    Both arrays are read from ``.npy`` files and converted to ``float32``.

    Args:
        input_array: Path of the ``.npy`` file holding the input rows.
        target_array: Path of the ``.npy`` file holding one target per row.
        root_dir: Stored on the instance as ``root_dir``; this class does not
            read it afterwards.

    Raises:
        ValueError: If the two arrays do not have the same number of rows.
    """

    def __init__(self, input_array, target_array, root_dir):
        self.input_array = np.load(input_array).astype(np.float32)
        self.target_array = np.load(target_array).astype(np.float32)
        # Fail early: __len__ only looks at the inputs, so a length mismatch
        # would otherwise surface as an IndexError in the middle of an epoch
        # (targets too short) or be silently ignored (targets too long).
        if len(self.input_array) != len(self.target_array):
            raise ValueError(
                f"Input and target arrays differ in length: "
                f"{len(self.input_array)} != {len(self.target_array)}"
            )
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.input_array)

    def __getitem__(self, idx):
        """Return ``[input_row, target]`` for position ``idx``."""
        return [self.input_array[idx], self.target_array[idx]]
def test(model_path):
    """Evaluate a saved ``FNN`` on the held-out arithmetic dataset.

    Loads the state dict stored at ``model_path``, runs the network over the
    test split and logs the root-mean-square error of its predictions.

    Args:
        model_path: Path of the state dict written by ``train``.
    """
    testset = ArithmeticDataset(
        input_array=TEST_INPUT_ARRAY, target_array=TEST_TARGET_ARRAY, root_dir=os.path.dirname(TEST_INPUT_ARRAY)
    )
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    net = FNN().to(device)
    # map_location lets a checkpoint saved on a GPU be evaluated on a CPU-only
    # machine (and vice versa).
    net.load_state_dict(torch.load(model_path, map_location=device))

    # Count actual samples: len(testloader) * batch_size over-counts whenever
    # the last batch is not full.
    total = len(testset)
    if total == 0:
        logger.warning("Test set is empty, nothing to evaluate")
        return

    squared_error = 0.0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            # Drop only the trailing feature axis so a batch of size one keeps
            # its batch dimension.
            outputs = net(inputs).squeeze(-1)
            squared_error += ((outputs - labels) ** 2).sum().item()

    # The reported value is the square root of the mean squared error, so it
    # is labelled RMSE (the previous message called it "MSE").
    rmse = (squared_error / total) ** 0.5
    logger.info(f"RMSE of the network on the {total} test inputs: {rmse:.3f}")
ort_session.run(None, ort_inputs) + + np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) + + logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!") + + +if __name__ == "__main__": + train(STATE_PATH) + test(STATE_PATH) + export_model(STATE_PATH, ONNX_PATH) diff --git a/examples/mnist/nets/binary/FNN_14x14_bin0_s42.onnx b/examples/mnist/nets/binary/FNN_14x14_bin0_s42.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a8bdf9d58a5ae4b51d570bb56a876e59787214ee Binary files /dev/null and b/examples/mnist/nets/binary/FNN_14x14_bin0_s42.onnx differ diff --git a/examples/mnist/nets/binary/FNN_28x28_bin0_s42.onnx b/examples/mnist/nets/binary/FNN_28x28_bin0_s42.onnx new file mode 100644 index 0000000000000000000000000000000000000000..05840a584c7a64534a6f02d1346d33d6eb2e27d6 Binary files /dev/null and b/examples/mnist/nets/binary/FNN_28x28_bin0_s42.onnx differ diff --git a/examples/mnist/nets/pruned/FNN_14x14_pruned_s42.onnx b/examples/mnist/nets/pruned/FNN_14x14_pruned_s42.onnx new file mode 100644 index 0000000000000000000000000000000000000000..69d394e349bba8c8b8e54a9169e2a23c01182aa9 Binary files /dev/null and b/examples/mnist/nets/pruned/FNN_14x14_pruned_s42.onnx differ diff --git a/examples/mnist/nets/pruned/FNN_14x14_s42.onnx b/examples/mnist/nets/pruned/FNN_14x14_s42.onnx new file mode 100644 index 0000000000000000000000000000000000000000..21aeef870bb75a874282cc6769258687737548bc Binary files /dev/null and b/examples/mnist/nets/pruned/FNN_14x14_s42.onnx differ diff --git a/examples/mnist/nets/pruned/FNN_28x28_pruned_s42.onnx b/examples/mnist/nets/pruned/FNN_28x28_pruned_s42.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a32cd072b7606a731da5e6dd186ad9c014abe541 Binary files /dev/null and b/examples/mnist/nets/pruned/FNN_28x28_pruned_s42.onnx differ diff --git a/examples/mnist/nets/pruned/FNN_28x28_s42.onnx 
b/examples/mnist/nets/pruned/FNN_28x28_s42.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e8505c08400e5465e0fdc56e11a0d66ed5c9327f Binary files /dev/null and b/examples/mnist/nets/pruned/FNN_28x28_s42.onnx differ diff --git a/examples/mnist/nets/splitted/FNN_28x28_post_s42.onnx b/examples/mnist/nets/splitted/FNN_28x28_post_s42.onnx new file mode 100644 index 0000000000000000000000000000000000000000..13bae9a89f7b8167b259c8572625c7ea4b24e378 Binary files /dev/null and b/examples/mnist/nets/splitted/FNN_28x28_post_s42.onnx differ diff --git a/examples/mnist/nets/splitted/FNN_28x28_pre_s42.onnx b/examples/mnist/nets/splitted/FNN_28x28_pre_s42.onnx new file mode 100644 index 0000000000000000000000000000000000000000..2635e41b61b33c501cfbbeace5f5e2b27616bf27 Binary files /dev/null and b/examples/mnist/nets/splitted/FNN_28x28_pre_s42.onnx differ diff --git a/examples/mnist/nets/training_nns/FNN.onnx b/examples/mnist/nets/training_nns/FNN.onnx new file mode 100644 index 0000000000000000000000000000000000000000..204c999b50e8f3bb98c6f068cffd31cf4cc9c290 Binary files /dev/null and b/examples/mnist/nets/training_nns/FNN.onnx differ diff --git a/examples/mnist/nets/training_nns/pruned_FNN.onnx b/examples/mnist/nets/training_nns/pruned_FNN.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aa7b7f3d7348861d2573edc20e3873e53051cc57 Binary files /dev/null and b/examples/mnist/nets/training_nns/pruned_FNN.onnx differ diff --git a/examples/mnist/nets/training_nns/requirements.txt b/examples/mnist/nets/training_nns/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..8fb1bde4ebee5bdcf8c2ccd1b4e1aa07974b6cfb --- /dev/null +++ b/examples/mnist/nets/training_nns/requirements.txt @@ -0,0 +1,7 @@ +loguru +numpy +onnx +onnxruntime +torch +torchaudio +torchvision diff --git a/examples/mnist/nets/training_nns/train.py b/examples/mnist/nets/training_nns/train.py new file mode 100644 index 
0000000000000000000000000000000000000000..73b7e9bf3d52e749f88cbca64e3eef3253d5904e --- /dev/null +++ b/examples/mnist/nets/training_nns/train.py @@ -0,0 +1,452 @@ +import numpy as np +import onnx +import onnxruntime as ort +import torch +import torch.onnx +import torch.optim as optim +import torchvision +from loguru import logger +from torch import nn +from torch.nn import functional as F +from torch.nn.utils import prune +from torch.utils.data import DataLoader +from torchvision import transforms + +img_size = 28 + +SEED = 42 +STATE_PATH = f"FNN_{img_size}x{img_size}_s{SEED}.pth" +ONNX_PATH = f"FNN_{img_size}x{img_size}_s{SEED}.onnx" +BIN_STATE_PATH = f"FNN_{img_size}x{img_size}_bin0_s{SEED}.pth" +BIN_ONNX_PATH = f"FNN_{img_size}x{img_size}_bin0_s{SEED}.onnx" +PRUNED_MODEL_PATH = f"FNN_{img_size}x{img_size}_pruned_s{SEED}.pkl" +PRUNED_ONNX_PATH = f"FNN_{img_size}x{img_size}_pruned_s{SEED}.onnx" +FNN_PRE_PATH = f"FNN_28x28_pre_s{SEED}.pth" +FNN_PRE_ONNX_PATH = f"FNN_28x28_pre_s{SEED}.onnx" +FNN_POST_PATH = f"FNN_28x28_post_s{SEED}.pth" +FNN_POST_ONNX_PATH = f"FNN_28x28_post_s{SEED}.onnx" + +torch.manual_seed(SEED) + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +logger.info(f"Using device {device}") +num_epoch = 2 +batch_size = 4 + + +class FNN_pre(nn.Module): + def __init__(self): + super().__init__() + self.flatten = nn.Flatten() + self.fc1 = nn.Linear(28 * 28, 512) + self.fc2 = nn.Linear(512, 512) + + def forward(self, x): + x = self.flatten(x) + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + x = F.relu(x) + return x + + +class FNN_post(nn.Module): + def __init__(self): + super().__init__() + self.flatten = nn.Flatten() + self.fc1 = nn.Linear(512, 256) + self.fc2 = nn.Linear(256, 10) + + def forward(self, x): + x = self.flatten(x) + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + return x + + +class FNN(nn.Module): + def __init__(self): + super().__init__() + self.flatten = nn.Flatten() + self.fc1 = nn.Linear(img_size * 
def train(state_dict, is_bin=False):
    """Train an MNIST classifier and save its state dict.

    Trains ``FNN`` (ten classes) or, when ``is_bin`` is true, ``FNN_bin`` (two
    classes: "digit is 0" vs. "any other digit") for ``num_epoch`` epochs with
    SGD and cross-entropy loss, logging the running loss every 2000
    mini-batches.

    Args:
        state_dict: Path where the trained model's state dict is written.
        is_bin: Train the binary network instead of the ten-class one.
    """
    trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

    model = (FNN_bin() if is_bin else FNN()).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    for epoch in range(num_epoch):
        running_loss = 0.0
        for i, data in enumerate(trainloader):
            inputs, labels = data[0].to(device), data[1].to(device)
            if is_bin:
                # Label 1 for the digit 0, label 0 for every other digit;
                # computed on-device instead of through a Python list.
                labels = (labels == 0).long()
            optimizer.zero_grad()
            # Single forward pass per step (it used to be computed twice, the
            # first result being discarded).
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                logger.info(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
                running_loss = 0.0

    logger.info("Finished training")

    torch.save(model.state_dict(), state_dict)
def test(model_path, is_state_dict=True, is_bin=False):
    """Evaluate a saved MNIST network and log overall and per-class accuracy.

    Args:
        model_path: Path of the saved model (a state dict, or a whole pickled
            module when ``is_state_dict`` is false).
        is_state_dict: Whether ``model_path`` holds a state dict.
        is_bin: Evaluate the binary network ``FNN_bin`` ("digit is 0" vs.
            "other") instead of the ten-class ``FNN``.

    Raises:
        ValueError: If ``is_bin`` is set while ``is_state_dict`` is false.
    """
    # Validate the flags before touching the dataset.
    if is_bin and (not is_state_dict):
        raise ValueError("Binary model required to be loaded and exported via state dicts.")

    classes = (0, 1) if is_bin else (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
    testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    if is_bin:
        net = FNN_bin().to(device)
        net.load_state_dict(torch.load(model_path, map_location=device))
    elif is_state_dict:
        net = FNN().to(device)
        net.load_state_dict(torch.load(model_path, map_location=device))
    else:
        # The pruned model is stored as a whole pickled module. Recent PyTorch
        # releases default to weights_only=True, which rejects such files, so
        # opt out explicitly. NOTE: unpickling executes arbitrary code; only
        # load files produced by this script.
        net = torch.load(model_path, map_location=device, weights_only=False)

    correct = 0
    total = 0
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}

    # One pass over the test set gathers both the global and the per-class
    # statistics (the network used to be evaluated twice).
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            if is_bin:
                labels = (labels == 0).long()
            outputs = net(images)
            _, predictions = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predictions == labels).sum().item()
            for label, prediction in zip(labels.tolist(), predictions.tolist()):
                if label == prediction:
                    correct_pred[classes[label]] += 1
                total_pred[classes[label]] += 1

    if is_bin:
        logger.info(f"Accuracy of the binary network on the {total} test images: {100 * correct // total} %")
    elif is_state_dict:
        logger.info(f"Accuracy of the network on the {total} test images: {100 * correct // total} %")
    else:
        logger.info(f"Accuracy of the pruned network on the {total} test images: {100 * correct // total} %")

    for classname, correct_count in correct_pred.items():
        if total_pred[classname] == 0:
            # No sample of this class was seen; skip it rather than divide by zero.
            continue
        accuracy = 100 * float(correct_count) / total_pred[classname]
        logger.info(f"Accuracy for class: {classname} is {accuracy:.1f} %")
def export_model(model_path, onnx_path, is_state_dict=True, is_bin=False):
    """Export a saved MNIST network to ONNX and validate the exported file.

    Args:
        model_path: Path of the saved model (a state dict, or a whole pickled
            module when ``is_state_dict`` is false).
        onnx_path: Destination of the ONNX file.
        is_state_dict: Whether ``model_path`` holds a state dict.
        is_bin: Export the binary network ``FNN_bin`` instead of ``FNN``.

    Raises:
        ValueError: If ``is_bin`` is set while ``is_state_dict`` is false.
    """
    if is_bin and (not is_state_dict):
        raise ValueError("Binary model required to be loaded and exported via state dicts.")

    # map_location keeps the loaded weights on the same device as the dummy
    # input below, whatever device the checkpoint was saved from.
    if is_bin:
        model = FNN_bin().to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
    elif is_state_dict:
        model = FNN().to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        # Whole pickled module: recent PyTorch releases default to
        # weights_only=True and refuse such files. NOTE: unpickling executes
        # arbitrary code; only load files produced by this script.
        model = torch.load(model_path, map_location=device, weights_only=False)

    # Dummy input used to trace the network during export.
    x = torch.rand(1, img_size, img_size, device=device)

    torch.onnx.export(model=model, args=x, f=onnx_path, export_params=True)
    if is_bin:
        logger.info("Binary model exported successfully")
    elif is_state_dict:
        logger.info("Model exported successfully")
    else:
        logger.info("Pruned model exported successfully")

    test_onnx(model_path, onnx_path, is_state_dict, is_bin)
def test_onnx(model_path, onnx_path, is_state_dict=True, is_bin=False):
    """Check that an exported ONNX file agrees with its PyTorch model.

    Runs the ONNX checker on ``onnx_path``, then compares the PyTorch and
    ONNX Runtime outputs on 10000 random inputs.

    Args:
        model_path: Path of the saved PyTorch model.
        onnx_path: Path of the exported ONNX file.
        is_state_dict: Whether ``model_path`` holds a state dict.
        is_bin: Whether the model is the binary network ``FNN_bin``.

    Raises:
        ValueError: If ``is_bin`` is set while ``is_state_dict`` is false.
        AssertionError: If the two runtimes disagree beyond the tolerances.
    """
    if is_bin and (not is_state_dict):
        raise ValueError("Binary model required to be loaded and exported via state dicts.")

    if is_bin:
        model = FNN_bin().to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
    elif is_state_dict:
        model = FNN().to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        # Whole pickled module: recent PyTorch releases default to
        # weights_only=True and refuse such files. NOTE: unpickling executes
        # arbitrary code; only load files produced by this script.
        model = torch.load(model_path, map_location=device, weights_only=False)

    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    input_name = ort_session.get_inputs()[0].name

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    # Pure inference: no autograd graph is needed for the comparison.
    with torch.no_grad():
        for _ in range(10000):
            x = torch.rand(1, img_size, img_size, device=device)
            torch_out = model(x)

            ort_outs = ort_session.run(None, {input_name: to_numpy(x)})

            np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

    if is_bin:
        logger.info("Exported binary model has been tested with ONNXRuntime, and the result looks good!")
    elif is_state_dict:
        logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")
    else:
        logger.info("Exported pruned model has been tested with ONNXRuntime, and the result looks good!")
def build_pruned_model(state_dict, pruned_model_path):
    """Prune a trained ``FNN`` and save the pruned module.

    Loads the weights stored at ``state_dict``, applies global unstructured
    L1 pruning with ``amount=0.2`` across the weights of ``fc1``, ``fc2`` and
    ``fc3``, logs the resulting per-layer and global sparsity, and saves the
    whole module (not just a state dict) to ``pruned_model_path``.

    Args:
        state_dict: Path of the state dict of the trained, unpruned model.
        pruned_model_path: Destination of the pickled pruned module.
    """
    model = FNN().to(device)
    # map_location lets a checkpoint saved on another device be loaded here.
    model.load_state_dict(torch.load(state_dict, map_location=device))

    layers = {"fc1": model.fc1, "fc2": model.fc2, "fc3": model.fc3}
    parameters_to_prune = tuple((layer, "weight") for layer in layers.values())
    prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)

    zeros_total = 0
    weights_total = 0
    for name, layer in layers.items():
        zeros = int(torch.sum(layer.weight == 0))
        count = layer.weight.nelement()
        zeros_total += zeros
        weights_total += count
        logger.info("Sparsity in {}.weight: {:.2f}%".format(name, 100.0 * float(zeros) / float(count)))

    logger.info("Global sparsity: {:.2f}%".format(100.0 * float(zeros_total) / float(weights_total)))

    logger.info("Finished pruning")

    torch.save(model, pruned_model_path)
(1.0:t) } + + goal comparison: + let nn_1 = read_model model_filename_1 in + let nn_2 = read_model model_filename_2 in + let dataset = read_dataset dataset_filename in + let eps = (0.0100000000000000002081668171172168513294309377670288085937500000:t) in + let delta = (0.0100000000000000002081668171172168513294309377670288085937500000:t) in + CSV.forall_ dataset (fun _ e -> + forall perturbed_e. + has_length perturbed_e (length e) -> + FeatureVector.valid feature_bounds perturbed_e -> + let perturbation = perturbed_e - e in + ClassRobustVector.bounded_by_epsilon perturbation eps -> + let out_1 = nn_1@@perturbed_e in + let out_2 = nn_2@@perturbed_e in + .- delta .<= out_1[0] .- out_2[0] .<= delta + ) + +end diff --git a/examples/onnx_rewrite/sequencing.mlw b/examples/onnx_rewrite/sequencing.mlw new file mode 100644 index 0000000000000000000000000000000000000000..a3a03cf7ccd0e3bb3747e345f7a08de78bdea117 --- /dev/null +++ b/examples/onnx_rewrite/sequencing.mlw @@ -0,0 +1,36 @@ +theory SEQUENCING + + use ieee_float.Float64 + use caisar.types.Float64WithBounds as Feature + use caisar.types.IntWithBounds as Label + use caisar.types.Vector + use caisar.model.Model + use caisar.dataset.CSV + use caisar.robust.ClassRobustCSV + use caisar.robust.ClassRobustVector + + constant model_filename_1: string + constant model_filename_2: string + constant dataset_filename: string + + constant label_bounds: Label.bounds = Label.{ lower = 0; upper = 9 } + constant feature_bounds: Feature.bounds = Feature.{ lower = (0.0:t); upper = (1.0:t) } + + goal sequencing: + let nn_1 = read_model model_filename_1 in + let nn_2 = read_model model_filename_2 in + let dataset = read_dataset dataset_filename in + let eps = (0.0100000000000000002081668171172168513294309377670288085937500000:t) in + CSV.forall_ dataset (fun l e -> + forall perturbed_e. 
+ has_length perturbed_e (length e) -> + FeatureVector.valid feature_bounds perturbed_e -> + let perturbation = perturbed_e - e in + ClassRobustVector.bounded_by_epsilon perturbation eps -> + let out_1 = nn_1@@perturbed_e in + let out_2 = nn_2@@out_1 in + forall j. Label.valid label_bounds j -> j <> l -> + out_2[l] .>= out_2[j] + ) + +end diff --git a/lib/nir/node.ml b/lib/nir/node.ml index 94bb663f3cdb4dca095ad62dce6b50c357ba1345..07ed4251662545616a01d8010e393e68de09d16c 100644 --- a/lib/nir/node.ml +++ b/lib/nir/node.ml @@ -55,8 +55,8 @@ type descr = inputC : t option; alpha : float; beta : float; - transA : bool; - transB : bool; + transA : int; + transB : int; } | LogSoftmax | ReLu of { input : t } @@ -183,7 +183,7 @@ and compute_shape_descr = function d1 := !d1 * Shape.get shape i done; for i = axis to Shape.rank shape - 1 do - d2 := !d1 * Shape.get shape i + d2 := !d2 * Shape.get shape i done; Shape.of_list [ !d1; !d2 ] | Input { shape } -> shape @@ -290,7 +290,7 @@ and compute_shape_descr = function | [| k; n |] -> (k, n) | _ -> failwith "Gemm input must be of size 2" in - let tr trans (k, n) = if trans then (n, k) else (k, n) in + let tr trans (k, n) = if trans = 1 then (n, k) else (k, n) in let a1, a2 = tr transA @@ rank2 inputA in let b1, b2 = tr transB @@ rank2 inputB in if not (Int.equal a2 b1) diff --git a/lib/nir/node.mli b/lib/nir/node.mli index f2ff5c6f8e4b2b6f7db3e76c54717c3ef4a961e2..ff0ba8c52f30f2ae98229825902fd8006bddf8b3 100644 --- a/lib/nir/node.mli +++ b/lib/nir/node.mli @@ -72,8 +72,8 @@ type descr = inputC : t option; alpha : float; beta : float; - transA : bool; - transB : bool; + transA : int; + transB : int; } | LogSoftmax | ReLu of { input : t } diff --git a/lib/onnx/reader.ml b/lib/onnx/reader.ml index ed5c87f18032e1fe3ca78e0a9b88da0d81096076..e898280aff552fbb50d272356f799009fe32a07a 100644 --- a/lib/onnx/reader.ml +++ b/lib/onnx/reader.ml @@ -189,37 +189,43 @@ end = struct (module String) (List.map ~f:(fun a -> (Option.value_exn 
a.name, a)) n.attribute) in - let get_float name : float = - match Hashtbl.find_exn attrs name with - | { type' = Some AttributeProto.AttributeType.FLOAT; f = Some f; _ } - -> - f - | _ -> failwith "Attribute wrongly typed" + let get_attr ?default name m = + match Hashtbl.find attrs name with + | Some v -> m v + | None -> ( + match default with + | Some v -> v + | None -> Fmt.failwith "Required attribute %s missing" name) in - let get_int name : int = - match Hashtbl.find_exn attrs name with - | { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ } -> - Int64.to_int_exn i - | _ -> failwith "Attribute wrongly typed" + let get_float ?default name : float = + get_attr ?default name (function + | { type' = Some AttributeProto.AttributeType.FLOAT; f = Some f; _ } + -> + f + | _ -> failwith "Attribute wrongly typed") in - let get_ints name : int list = - match Hashtbl.find_exn attrs name with - | { type' = Some AttributeProto.AttributeType.INTS; ints = l; _ } -> - List.map ~f:Int64.to_int_exn l - | _ -> failwith "Attribute wrongly typed" + let get_int ?default name : int = + get_attr ?default name (function + | { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ } + -> + Int64.to_int_exn i + | _ -> failwith "Attribute wrongly typed") in - let get_bool name : bool = - match Hashtbl.find_exn attrs name with - | { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ } -> - not (Int64.equal i 0L) - | _ -> failwith "Attribute wrongly typed" + let get_ints ?default name : int list = + get_attr ?default name (function + | { type' = Some AttributeProto.AttributeType.INTS; ints = l; _ } -> + List.map ~f:Int64.to_int_exn l + | _ -> failwith "Attribute wrongly typed") in - let get_tensor name : Nir.Gentensor.t = - match Hashtbl.find_exn attrs name with - | { type' = Some AttributeProto.AttributeType.TENSOR; t = Some t; _ } - -> - convert_tensor t - | _ -> failwith "Attribute wrongly typed" + let get_tensor ?default name : Nir.Gentensor.t = + get_attr 
?default name (function + | { + type' = Some AttributeProto.AttributeType.TENSOR; + t = Some t; + _; + } -> + convert_tensor t + | _ -> failwith "Attribute wrongly typed") in let n' = match n.op_type with @@ -257,10 +263,10 @@ end = struct inputA = convert inputA; inputB = convert inputB; inputC = Option.map ~f:convert inputC; - alpha = get_float "alpha"; - beta = get_float "beta"; - transA = get_bool "transA"; - transB = get_bool "transB"; + alpha = get_float ~default:1.0 "alpha"; + beta = get_float ~default:1.0 "beta"; + transA = get_int ~default:0 "transA"; + transB = get_int ~default:0 "transB"; } | "LogSoftmax" -> Nir.Node.LogSoftmax | "Transpose" -> diff --git a/lib/onnx/writer.ml b/lib/onnx/writer.ml index 0b1eb3f3e66f1b758e6f000ef5d3dc7b90e5e045..b6d3c4650abe0bb1b23945d09ad65ea28bcd716d 100644 --- a/lib/onnx/writer.ml +++ b/lib/onnx/writer.ml @@ -86,8 +86,8 @@ let nir_to_onnx_protoc (nir : Nir.Ngraph.t) = let mk_float name f = AttributeProto.make ~name ~type':FLOAT ~f () in let mk_tensor name t = AttributeProto.make ~name ~type':TENSOR ~t () in match v.descr with - | Gemm _ | LogSoftmax | Transpose _ | Squeeze _ | MaxPool | Conv - | Identity _ | RW_Linearized_ReLu | GatherND _ | ReduceSum _ -> + | LogSoftmax | Transpose _ | Squeeze _ | MaxPool | Conv | Identity _ + | RW_Linearized_ReLu | GatherND _ | ReduceSum _ -> Caisar_logging.Logging.not_implemented_yet (fun m -> m "Operator %a not implemented yet." 
Nir.Node.pp_descr v.descr) | Reshape _ -> make "Reshape" [] @@ -115,6 +115,14 @@ let nir_to_onnx_protoc (nir : Nir.Ngraph.t) = mk_float "seed" seed; mk_ints "shape" (Array.to_list shape); ] + | Gemm { alpha; beta; transA; transB; _ } -> + make "Gemm" + [ + mk_float "alpha" alpha; + mk_float "beta" beta; + mk_int "transA" transA; + mk_int "transB" transB; + ] in Nir.Ngraph.iter_vertex vertex_to_protoc nir; (Queue.to_list acc, Option.value_exn !g_input) diff --git a/src/interpretation/interpreter_theory.ml b/src/interpretation/interpreter_theory.ml index 6f079ff1c9579000613adab0951dad4d0e780e98..29f27b7bde4b21810e41fb9b40def245457083e9 100644 --- a/src/interpretation/interpreter_theory.ml +++ b/src/interpretation/interpreter_theory.ml @@ -128,7 +128,11 @@ module Vector = struct interpreter_op) | None -> IRE.reconstruct_term ()) | [ Term _t1; Term _t2 ] -> IRE.reconstruct_term () - | _ -> fail_on_unexpected_argument ls + | _ -> Logging.code_error ~src:Logging.src_interpret_goal (fun m -> + m "Unexpected argument(s) for '%a': %a" Why3.Pretty.print_ls ls + (Fmt.list ~sep:Fmt.comma IRE.pp_value) vl + ) + let length : _ IRE.builtin = fun engine ls vl _ty -> diff --git a/tests/arithmetic.t b/tests/arithmetic.t new file mode 100644 index 0000000000000000000000000000000000000000..4d8ba2b1f217604f4681b71dfad42c87848371c1 --- /dev/null +++ b/tests/arithmetic.t @@ -0,0 +1,27 @@ + $ . ./setup_env.sh + + $ caisar verify --prover PyRAT --ltag=ProverSpec --ltag=StackTrace --ltag=InterpretGoal --goal :runP1\'vc --define model_filename:FNN_s42.onnx ../examples/arithmetic/arithmetic.why + [DEBUG]{InterpretGoal} Interpreted formula for goal 'runP1'vc': + forall x:t, x1:t, x2:t. 
+ (le ((- 5.0):t) x /\ le x (5.0:t)) /\ + (le ((- 5.0):t) x1 /\ le x1 (5.0:t)) /\ le ((- 5.0):t) x2 /\ le x2 (5.0:t) -> + le (add RNE (add RNE (sub RNE (nn_onnx @@ vector x x1 x2)[0] x) x1) x2) + (0.5:t) + vector, 3 + nn_onnx, + (Interpreter_types.Model + (Interpreter_types.ONNX, { Language.nn_nb_inputs = 3; nn_nb_outputs = 1; + nn_ty_elt = t; + nn_filename = + "../examples/arithmetic/FNN_s42.onnx"; + nn_format = <nir> })) + [DEBUG]{ProverSpec} Prover-tailored specification: + -5.0 <= x0 + x0 <= 5.0 + -5.0 <= x1 + x1 <= 5.0 + -5.0 <= x2 + x2 <= 5.0 + y0 <= 0.5 + + Goal runP1'vc: Unknown () diff --git a/tests/autodetect.t b/tests/autodetect.t index aa34864e8f786e1b755212f587977d2422c589ea..eadc0747b19b3b99dddc8d77b99fa7e5d8d58936 100644 --- a/tests/autodetect.t +++ b/tests/autodetect.t @@ -43,6 +43,7 @@ Test autodetect PyRAT 1.1 (ACAS) PyRAT 1.1 (ACASd) PyRAT 1.1 (VNNLIB) + PyRAT 1.1 (arithmetic) SAVer v1.0 alpha-beta-CROWN dummy-version alpha-beta-CROWN dummy-version (ACAS) diff --git a/tests/comparison.t b/tests/comparison.t new file mode 100644 index 0000000000000000000000000000000000000000..a901ae8bbbb9a7e9d8dfa9445d0b74c22541048a --- /dev/null +++ b/tests/comparison.t @@ -0,0 +1,10 @@ + $ . 
./setup_env.sh + + $ ls ../examples/ + acasxu + mnist + onnx_rewrite + + $ caisar verify --prover PyRAT --define model_filename_1:../mnist/nets/pruned/FNN_28x28_s42.onnx --define model_filename_2:../mnist/nets/pruned/FNN_28x28_pruned_s42.onnx --define dataset_filename:../mnist/csv/single_image.csv ../examples/onnx_rewrite/comparison.mlw -v + [INFO] Verification results for theory 'COMPARISON' + Goal comparison: Unknown () diff --git a/tests/dune b/tests/dune index d2e97626573761ff466af87eddfdb73324d13198..57ad5a8b44642c58d7b71467d5250f22bfb4ae90 100644 --- a/tests/dune +++ b/tests/dune @@ -1,6 +1,6 @@ (cram (alias local) - (applies_to * \ nir_to_onnx acasxu_ci) + (applies_to * \ nir_to_onnx acasxu_ci arithmetic comparison sequencing) (deps (package caisar) setup_env.sh @@ -10,9 +10,54 @@ (glob_files bin/*) filter_tmpdir.sh ../lib/xgboost/example/california.csv - ../lib/xgboost/example/california.json) + ../lib/xgboost/example/california.json + ) + (package caisar)) + + (cram + (alias local) + (applies_to arithmetic) + (deps + (package caisar) + setup_env.sh + (glob_files bin/*) + filter_tmpdir.sh + ../examples/arithmetic/arithmetic.why + ../examples/arithmetic/FNN_s42.onnx + ) + (package caisar)) + + (cram + (alias local) + (applies_to comparison) + (deps + (package caisar) + setup_env.sh + (glob_files bin/*) + filter_tmpdir.sh + ../examples/onnx_rewrite/comparison.mlw + ../examples/mnist/nets/pruned/FNN_28x28_s42.onnx + ../examples/mnist/nets/pruned/FNN_28x28_pruned_s42.onnx + ../examples/mnist/csv/single_image.csv + ) (package caisar)) + (cram + (alias local) + (applies_to sequencing) + (deps + (package caisar) + setup_env.sh + (glob_files bin/*) + filter_tmpdir.sh + ../examples/onnx_rewrite/sequencing.mlw + ../examples/mnist/nets/splitted/FNN_28x28_pre_s42.onnx + ../examples/mnist/nets/splitted/FNN_28x28_post_s42.onnx + ../examples/mnist/csv/single_image.csv + ) + (package caisar)) + + (cram (alias ci) (deps diff --git a/tests/sequencing.t b/tests/sequencing.t 
new file mode 100644 index 0000000000000000000000000000000000000000..90676de048b7b3b4230fc3983428cfa4de113361 --- /dev/null +++ b/tests/sequencing.t @@ -0,0 +1,10 @@ + $ . ./setup_env.sh + + $ ls ../examples/ + acasxu + mnist + onnx_rewrite + + $ caisar verify --prover PyRAT --define model_filename_1:../mnist/nets/splitted/FNN_28x28_pre_s42.onnx --define model_filename_2:../mnist/nets/splitted/FNN_28x28_post_s42.onnx --define dataset_filename:../mnist/csv/single_image.csv ../examples/onnx_rewrite/sequencing.mlw -v + [INFO] Verification results for theory 'SEQUENCING' + Goal sequencing: Unknown ()