Skip to content
Snippets Groups Projects
Commit fda6e3b3 authored by Julien Girard-Satabin's avatar Julien Girard-Satabin
Browse files

[config][test] Maraboupy is its own prover now

parent 1377d844
No related branches found
No related tags found
No related merge requests found
...@@ -14,7 +14,9 @@ directory for licensing information. ...@@ -14,7 +14,9 @@ directory for licensing information.
class WrongNetFormat(Exception): class WrongNetFormat(Exception):
print("The network must be in .pb, .nnet, or .onnx format!") def __init__(self, networkPath):
self.message = f"Network {networkPath} not in onnx, pb or nnet format!"
super().__init__(self.message)
import argparse import argparse
...@@ -22,8 +24,11 @@ import os ...@@ -22,8 +24,11 @@ import os
import pathlib import pathlib
import shutil import shutil
import subprocess import subprocess
from importlib.metadata import version
import sys import sys
import tempfile import tempfile
import warnings
warnings.filterwarnings("ignore")
from maraboupy import Marabou from maraboupy import Marabou
from maraboupy import MarabouCore from maraboupy import MarabouCore
...@@ -36,10 +41,8 @@ def main(): ...@@ -36,10 +41,8 @@ def main():
marabou_binary = args.marabou_binary marabou_binary = args.marabou_binary
if not os.access(marabou_binary, os.X_OK): if not os.access(marabou_binary, os.X_OK):
sys.exit('"{}" does not exist or is not executable'.format(marabou_binary)) sys.exit('"{}" does not exist or is not executable'.format(marabou_binary))
print(f"Arguments to the script: {args}")
print(f"Arguments to the Marabou binary: {unknown}")
if args.display_version: if args.display_version:
subprocess.run([marabou_binary] + ["--version"]) print(f"Maraboupy version {version('maraboupy')}")
else: else:
query, _network = createQuery(args) query, _network = createQuery(args)
temp = tempfile.NamedTemporaryFile(dir=args.temp_dir, delete=False) temp = tempfile.NamedTemporaryFile(dir=args.temp_dir, delete=False)
...@@ -67,7 +70,7 @@ def createQuery(args): ...@@ -67,7 +70,7 @@ def createQuery(args):
elif suffix == "onnx": elif suffix == "onnx":
network = Marabou.read_onnx(networkPath) network = Marabou.read_onnx(networkPath)
else: else:
raise WrongNetFormat raise WrongNetFormat(networkPath)
query = network.getInputQuery() query = network.getInputQuery()
MarabouCore.loadProperty(query, args.prop) MarabouCore.loadProperty(query, args.prop)
return query, network return query, network
......
...@@ -54,11 +54,10 @@ driver = "%{config}/drivers/marabou.drv" ...@@ -54,11 +54,10 @@ driver = "%{config}/drivers/marabou.drv"
use_at_auto_level = 1 use_at_auto_level = 1
[ATP maraboupy] [ATP maraboupy]
name = "Marabou" name = "Maraboupy"
exec = "marabou_eval.py" exec = "marabou_eval.py"
alternative = "maraboupy"
version_switch = "--display-version" version_switch = "--display-version"
version_regexp = "Marabou version \\([0-9.+]+\\)" version_regexp = "Maraboupy version \\([0-9.+]+\\)"
version_ok = "2.0.0" version_ok = "2.0.0"
command = "%e %{nnet-onnx} -q %f " command = "%e %{nnet-onnx} -q %f "
driver = "%{config}/drivers/marabou.drv" driver = "%{config}/drivers/marabou.drv"
......
...@@ -30,11 +30,15 @@ Test autodetect ...@@ -30,11 +30,15 @@ Test autodetect
$ bin/abcrown.sh --version $ bin/abcrown.sh --version
dummy-version dummy-version
$ bin/marabou_eval.py --display-version 2>/dev/null
Maraboupy version 2.0.0
$ caisar config -d $ caisar config -d
AIMOS 1.0 AIMOS 1.0
Alt-Ergo 2.4.0 Alt-Ergo 2.4.0
CVC5 1.0.2 CVC5 1.0.2
Marabou 1.0.+ Marabou 1.0.+
Maraboupy 2.0.0
PyRAT 1.1 PyRAT 1.1
PyRAT 1.1 (ACAS) PyRAT 1.1 (ACAS)
PyRAT 1.1 (ACASd) PyRAT 1.1 (ACASd)
......
#!/usr/bin/env python3
"""
Adapted for quick evaluation of Marabou for CAISAR
Top contributors (to current version):
- Andrew Wu
This file is part of the Marabou project.
Copyright (c) 2017-2021 by the authors listed in the file AUTHORS
in the top-level source directory) and their institutional affiliations.
All rights reserved. See the file COPYING in the top-level source
directory for licensing information.
"""
class WrongNetFormat(Exception):
def __init__(self, networkPath):
self.message = f"Network {networkPath} not in onnx, pb or nnet format!"
super().__init__(self.message)
import argparse
import os
import pathlib
import shutil
import subprocess
from importlib.metadata import version
import sys
import tempfile
from maraboupy import Marabou
from maraboupy import MarabouCore
sys.path.insert(0, os.path.join(str(pathlib.Path(__file__).parent.absolute()), "../"))
def main():
args, unknown = arguments().parse_known_args()
marabou_binary = args.marabou_binary
if not os.access(marabou_binary, os.X_OK):
sys.exit('"{}" does not exist or is not executable'.format(marabou_binary))
if args.display_version:
print(f"Maraboupy version {version('maraboupy')}")
else:
query, _network = createQuery(args)
temp = tempfile.NamedTemporaryFile(dir=args.temp_dir, delete=False)
name = temp.name
timeout = args.timeout
MarabouCore.saveQuery(query, name)
print("Running Marabou with the following arguments: ", unknown)
subprocess.run(
[marabou_binary]
+ ["--input-query={}".format(name)]
+ ["--timeout={}".format(timeout)]
+ unknown
)
os.remove(name)
def createQuery(args):
assert args.prop != None
networkPath = args.network
suffix = networkPath.split(".")[-1]
if suffix == "nnet":
network = Marabou.read_nnet(networkPath)
elif suffix == "pb":
network = Marabou.read_tf(networkPath)
elif suffix == "onnx":
network = Marabou.read_onnx(networkPath)
else:
raise WrongNetFormat(networkPath)
query = network.getInputQuery()
MarabouCore.loadProperty(query, args.prop)
return query, network
def arguments():
parser = argparse.ArgumentParser(
description="Thin wrapper around Maraboupy executable"
)
parser.add_argument(
"network",
type=str,
nargs="?",
default=None,
help="The network file name, the extension can be only .pb, .nnet, and .onnx",
)
parser.add_argument(
"prop", type=str, nargs="?", default=None, help="The property file name"
)
parser.add_argument(
"-t", "--timeout", type=int, default=10, help="Timeout in seconds"
)
parser.add_argument(
"--temp-dir", type=str, default="/tmp/", help="Temporary directory"
)
marabou_path = shutil.which("Marabou")
parser.add_argument(
"--marabou-binary",
type=str,
default=marabou_path,
help="The path to Marabou binary",
)
parser.add_argument(
"--display-version",
action="store_true",
help="Output Maraboupy version and exit",
)
parser.set_defaults(display_version=False)
return parser
if __name__ == "__main__":
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment