Skip to content
Snippets Groups Projects
Commit ba4cc441 authored by Julien Girard-Satabin's avatar Julien Girard-Satabin
Browse files

[config][test][ci] Added proper autodetect for CI

parent fda6e3b3
No related branches found
No related tags found
No related merge requests found
...@@ -45,7 +45,7 @@ test: ...@@ -45,7 +45,7 @@ test:
stage: test stage: test
script: script:
- nix --extra-experimental-features "nix-command flakes" build - nix --extra-experimental-features "nix-command flakes" build
- nix --extra-experimental-features "nix-command flakes" flake check - nix --extra-experimental-features "nix-command flakes" flake check -L
## Manual generation of the documentation ## Manual generation of the documentation
......
...@@ -20,60 +20,19 @@ class WrongNetFormat(Exception): ...@@ -20,60 +20,19 @@ class WrongNetFormat(Exception):
import argparse import argparse
from importlib.metadata import version
import os import os
import pathlib import pathlib
import shutil import shutil
import subprocess import subprocess
from importlib.metadata import version
import sys import sys
import tempfile import tempfile
import warnings
warnings.filterwarnings("ignore")
from maraboupy import Marabou from maraboupy import Marabou
from maraboupy import MarabouCore from maraboupy import MarabouCore
sys.path.insert(0, os.path.join(str(pathlib.Path(__file__).parent.absolute()), "../"))
sys.path.insert(0, os.path.join(str(pathlib.Path(__file__).parent.absolute()), "../"))
def main():
args, unknown = arguments().parse_known_args()
marabou_binary = args.marabou_binary
if not os.access(marabou_binary, os.X_OK):
sys.exit('"{}" does not exist or is not executable'.format(marabou_binary))
if args.display_version:
print(f"Maraboupy version {version('maraboupy')}")
else:
query, _network = createQuery(args)
temp = tempfile.NamedTemporaryFile(dir=args.temp_dir, delete=False)
name = temp.name
timeout = args.timeout
MarabouCore.saveQuery(query, name)
print("Running Marabou with the following arguments: ", unknown)
subprocess.run(
[marabou_binary]
+ ["--input-query={}".format(name)]
+ ["--timeout={}".format(timeout)]
+ unknown
)
os.remove(name)
def createQuery(args):
assert args.prop != None
networkPath = args.network
suffix = networkPath.split(".")[-1]
if suffix == "nnet":
network = Marabou.read_nnet(networkPath)
elif suffix == "pb":
network = Marabou.read_tf(networkPath)
elif suffix == "onnx":
network = Marabou.read_onnx(networkPath)
else:
raise WrongNetFormat(networkPath)
query = network.getInputQuery()
MarabouCore.loadProperty(query, args.prop)
return query, network
def arguments(): def arguments():
...@@ -104,14 +63,51 @@ def arguments(): ...@@ -104,14 +63,51 @@ def arguments():
help="The path to Marabou binary", help="The path to Marabou binary",
) )
parser.add_argument( parser.add_argument(
"--display-version", "--version",
action="store_true", action="store_true",
help="Output Maraboupy version and exit", help="Output a version string if Maraboupy is present and exit",
) )
parser.set_defaults(display_version=False) parser.set_defaults(version=False)
return parser return parser
def main():
args, unknown = arguments().parse_known_args()
marabou_binary = args.marabou_binary
if not os.access(marabou_binary, os.X_OK):
sys.exit('"{}" does not exist or is not executable'.format(marabou_binary))
else:
if args.version:
print(f"Maraboupy version {version('maraboupy')}")
else:
assert args.prop != None
networkPath = args.network
suffix = networkPath.split(".")[-1]
if suffix == "nnet":
network = Marabou.read_nnet(networkPath)
elif suffix == "pb":
network = Marabou.read_tf(networkPath)
elif suffix == "onnx":
network = Marabou.read_onnx(networkPath)
else:
raise WrongNetFormat(networkPath)
query = network.getInputQuery()
MarabouCore.loadProperty(query, args.prop)
temp = tempfile.NamedTemporaryFile(dir=args.temp_dir, delete=False)
name = temp.name
timeout = args.timeout
MarabouCore.saveQuery(query, name)
print("Running Marabou with the following arguments: ", unknown)
subprocess.run(
[marabou_binary]
+ ["--input-query={}".format(name)]
+ ["--timeout={}".format(timeout)]
+ unknown
)
os.remove(name)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -56,10 +56,10 @@ use_at_auto_level = 1 ...@@ -56,10 +56,10 @@ use_at_auto_level = 1
[ATP maraboupy] [ATP maraboupy]
name = "Maraboupy" name = "Maraboupy"
exec = "marabou_eval.py" exec = "marabou_eval.py"
version_switch = "--display-version" version_switch = "--version"
version_regexp = "Maraboupy version \\([0-9.+]+\\)" version_regexp = "Maraboupy version \\([0-9.]+\\)"
version_ok = "2.0.0" version_ok = "2.0.0"
command = "%e %{nnet-onnx} -q %f " command = "%e %{nnet-onnx} %f "
driver = "%{config}/drivers/marabou.drv" driver = "%{config}/drivers/marabou.drv"
use_at_auto_level = 1 use_at_auto_level = 1
......
...@@ -30,7 +30,7 @@ Test autodetect ...@@ -30,7 +30,7 @@ Test autodetect
$ bin/abcrown.sh --version $ bin/abcrown.sh --version
dummy-version dummy-version
$ bin/marabou_eval.py --display-version 2>/dev/null $ bin/marabou_eval.py --version
Maraboupy version 2.0.0 Maraboupy version 2.0.0
$ caisar config -d $ caisar config -d
......
#!/usr/bin/env python3 #!/bin/sh -e
"""
Adapted for quick evaluation of Marabou for CAISAR case $1 in
--version)
Top contributors (to current version): echo "Maraboupy version 2.0.0"
- Andrew Wu ;;
*)
This file is part of the Marabou project. echo "PWD: $(pwd)"
Copyright (c) 2017-2021 by the authors listed in the file AUTHORS echo "NN: $2"
in the top-level source directory) and their institutional affiliations. test -e $2 || (echo "Cannot find the NN file" && exit 1)
All rights reserved. See the file COPYING in the top-level source echo "Goal:"
directory for licensing information. cat $4
""" echo "Result = Unknown"
esac
class WrongNetFormat(Exception):
def __init__(self, networkPath):
self.message = f"Network {networkPath} not in onnx, pb or nnet format!"
super().__init__(self.message)
import argparse
import os
import pathlib
import shutil
import subprocess
from importlib.metadata import version
import sys
import tempfile
from maraboupy import Marabou
from maraboupy import MarabouCore
sys.path.insert(0, os.path.join(str(pathlib.Path(__file__).parent.absolute()), "../"))
def main():
args, unknown = arguments().parse_known_args()
marabou_binary = args.marabou_binary
if not os.access(marabou_binary, os.X_OK):
sys.exit('"{}" does not exist or is not executable'.format(marabou_binary))
if args.display_version:
print(f"Maraboupy version {version('maraboupy')}")
else:
query, _network = createQuery(args)
temp = tempfile.NamedTemporaryFile(dir=args.temp_dir, delete=False)
name = temp.name
timeout = args.timeout
MarabouCore.saveQuery(query, name)
print("Running Marabou with the following arguments: ", unknown)
subprocess.run(
[marabou_binary]
+ ["--input-query={}".format(name)]
+ ["--timeout={}".format(timeout)]
+ unknown
)
os.remove(name)
def createQuery(args):
assert args.prop != None
networkPath = args.network
suffix = networkPath.split(".")[-1]
if suffix == "nnet":
network = Marabou.read_nnet(networkPath)
elif suffix == "pb":
network = Marabou.read_tf(networkPath)
elif suffix == "onnx":
network = Marabou.read_onnx(networkPath)
else:
raise WrongNetFormat(networkPath)
query = network.getInputQuery()
MarabouCore.loadProperty(query, args.prop)
return query, network
def arguments():
parser = argparse.ArgumentParser(
description="Thin wrapper around Maraboupy executable"
)
parser.add_argument(
"network",
type=str,
nargs="?",
default=None,
help="The network file name, the extension can be only .pb, .nnet, and .onnx",
)
parser.add_argument(
"prop", type=str, nargs="?", default=None, help="The property file name"
)
parser.add_argument(
"-t", "--timeout", type=int, default=10, help="Timeout in seconds"
)
parser.add_argument(
"--temp-dir", type=str, default="/tmp/", help="Temporary directory"
)
marabou_path = shutil.which("Marabou")
parser.add_argument(
"--marabou-binary",
type=str,
default=marabou_path,
help="The path to Marabou binary",
)
parser.add_argument(
"--display-version",
action="store_true",
help="Output Maraboupy version and exit",
)
parser.set_defaults(display_version=False)
return parser
if __name__ == "__main__":
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment