Skip to content
Snippets Groups Projects
Commit e52bc49b authored by Aymeric Varasse's avatar Aymeric Varasse :innocent:
Browse files

Merge branch 'varasse/experiments' into 'master'

Add more models, specifications and tests

See merge request laiser/caisar!134
parents cc52e9d2 61502514
No related branches found
No related tags found
No related merge requests found
Showing
with 208 additions and 0 deletions
......@@ -110,6 +110,18 @@ command = "%e -mp %{nnet-onnx} -pp %f --timeout %t --domain zono --split --score
driver = "%{config}/drivers/pyrat.drv"
use_at_auto_level = 1
[ATP pyrat-arithmetic]
name = "PyRAT"
alternative = "arithmetic"
exec = "pyrat.py"
exec = "pyrat"
version_switch = "--version"
version_regexp = "PyRAT \\([0-9.]+\\)"
version_ok = "1.1"
command = "%e -mp %{nnet-onnx} -pp %f --timeout %t --domain poly --domain zono --split --scorer coef --initial --allow_smaller_size --booster always"
driver = "%{config}/drivers/pyrat.drv"
use_at_auto_level = 1
[ATP nnenum]
name = "nnenum"
exec = "nnenum.sh"
......
File added
theory Arithmethic
use ieee_float.Float64
use int.Int
use caisar.types.Vector
use caisar.model.Model
type input = vector t
let constant eps : t = 0.5
val constant model_filename: string
let constant nn : model = read_model model_filename
predicate valid_input (i: input) =
(-5.0:t) .<= i[0] .<= (5.0:t)
/\ (-5.0:t) .<= i[1] .<= (5.0:t)
/\ (-5.0:t) .<= i[2] .<= (5.0:t)
let runP1 (i: input) : t
requires { has_length i 3 }
requires { valid_input i }
ensures { result .- i[0] .+ i[1] .+ i[2] .<= eps } =
(nn @@ i)[0]
end
File added
import numpy as np
# Given a np array row, return a linear combination of the output
def f(x: np.ndarray) -> np.ndarray:
return np.array((x[0] - x[1] - x[2]))
if __name__ == "__main__":
arr1 = np.random.normal(size=(10000, 3))
arr2 = np.apply_along_axis(f, 1, arr1)
arr3 = np.random.normal(size=(10000, 3))
arr4 = np.apply_along_axis(f, 1, arr1)
np.save(file="data.npy", arr=arr1)
np.save(file="target.npy", arr=arr2)
np.save(file="test_data.npy", arr=arr3)
np.save(file="test_target.npy", arr=arr4)
File added
File added
File added
import os
import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.onnx
import torch.optim as optim
from loguru import logger
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
SEED = 42
STATE_PATH = f"FNN_s{SEED}.pth"
ONNX_PATH = f"FNN_s{SEED}.onnx"
INPUT_ARRAY = "data.npy"
TEST_INPUT_ARRAY = "test_data.npy"
TARGET_ARRAY = "target.npy"
TEST_TARGET_ARRAY = "test_target.npy"
torch.manual_seed(SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device {device}")
num_epoch = 2
batch_size = 4
class FNN(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(3, 128)
self.fc2 = nn.Linear(128, 128)
self.fc3 = nn.Linear(128, 1)
def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
return x
class ArithmeticDataset(Dataset):
def __init__(self, input_array, target_array, root_dir):
self.input_array = np.load(input_array).astype(np.float32)
self.target_array = np.load(target_array).astype(np.float32)
self.root_dir = root_dir
def __len__(self):
return len(self.input_array)
def __getitem__(self, idx):
return [self.input_array[idx], self.target_array[idx]]
def train(state_dict):
trainset = ArithmeticDataset(
input_array=INPUT_ARRAY, target_array=TARGET_ARRAY, root_dir=os.path.dirname(INPUT_ARRAY)
)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
model = FNN().to(device)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
for epoch in range(num_epoch):
running_loss = 0.0
for i, data in enumerate(trainloader):
inputs, labels = data[0].to(device), data[1].to(device)
optimizer.zero_grad()
outputs = model(inputs).squeeze().to(device)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2000 == 1999:
logger.info(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
running_loss = 0.0
logger.info("Finished training")
torch.save(model.state_dict(), state_dict)
def test(model_path):
testset = ArithmeticDataset(
input_array=TEST_INPUT_ARRAY, target_array=TEST_TARGET_ARRAY, root_dir=os.path.dirname(TEST_INPUT_ARRAY)
)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
net = FNN().to(device)
net.load_state_dict(torch.load(model_path))
error = 0
total = len(testloader) * batch_size
with torch.no_grad():
for data in testloader:
inputs, labels = data[0].to(device), data[1].to(device)
outputs = net(inputs).squeeze().to(device)
error += ((outputs - labels) * (outputs - labels)).sum().data.cpu()
logger.info(f"Average MSE of the network on the 10000 test inputs: {np.sqrt(error / total):.3f}")
def export_model(model_path, onnx_path):
model = FNN().to(device)
model.load_state_dict(torch.load(model_path))
x = torch.rand(1, 3, device=device)
torch.onnx.export(model=model, args=x, f=onnx_path, export_params=True)
logger.info("Model exported successfully")
test_onnx(model_path, onnx_path)
def test_onnx(model_path, onnx_path):
model = FNN().to(device)
model.load_state_dict(torch.load(model_path))
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
for _ in range(10000):
x = torch.rand(1, 3, device=device)
torch_out = model(x)
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")
if __name__ == "__main__":
train(STATE_PATH)
test(STATE_PATH)
export_model(STATE_PATH, ONNX_PATH)
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
loguru
numpy
onnx
onnxruntime
torch
torchaudio
torchvision
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment