From fda6e3b36ff8d30f750b16f4bc5c90bcd2c1d613 Mon Sep 17 00:00:00 2001 From: Julien Girard <julien.girard2@cea.fr> Date: Wed, 15 May 2024 14:56:44 +0200 Subject: [PATCH] [config][test] Maraboupy is its own prover now --- bin/marabou_eval.py | 13 ++-- config/caisar-detection-data.conf | 5 +- tests/autodetect.t | 4 ++ tests/bin/marabou_eval.py | 115 ++++++++++++++++++++++++++++++ 4 files changed, 129 insertions(+), 8 deletions(-) mode change 100644 => 100755 bin/marabou_eval.py create mode 100755 tests/bin/marabou_eval.py diff --git a/bin/marabou_eval.py b/bin/marabou_eval.py old mode 100644 new mode 100755 index edddb76..ebc3389 --- a/bin/marabou_eval.py +++ b/bin/marabou_eval.py @@ -14,7 +14,9 @@ directory for licensing information. class WrongNetFormat(Exception): - print("The network must be in .pb, .nnet, or .onnx format!") + def __init__(self, networkPath): + self.message = f"Network {networkPath} not in onnx, pb or nnet format!" + super().__init__(self.message) import argparse @@ -22,8 +24,11 @@ import os import pathlib import shutil import subprocess +from importlib.metadata import version import sys import tempfile +import warnings +warnings.filterwarnings("ignore") from maraboupy import Marabou from maraboupy import MarabouCore @@ -36,10 +41,8 @@ def main(): marabou_binary = args.marabou_binary if not os.access(marabou_binary, os.X_OK): sys.exit('"{}" does not exist or is not executable'.format(marabou_binary)) - print(f"Arguments to the script: {args}") - print(f"Arguments to the Marabou binary: {unknown}") if args.display_version: - subprocess.run([marabou_binary] + ["--version"]) + print(f"Maraboupy version {version('maraboupy')}") else: query, _network = createQuery(args) temp = tempfile.NamedTemporaryFile(dir=args.temp_dir, delete=False) @@ -67,7 +70,7 @@ def createQuery(args): elif suffix == "onnx": network = Marabou.read_onnx(networkPath) else: - raise WrongNetFormat + raise WrongNetFormat(networkPath) query = network.getInputQuery() MarabouCore.loadProperty(query, args.prop) return query, network diff --git a/config/caisar-detection-data.conf b/config/caisar-detection-data.conf index d925add..813f0c8 100644 --- a/config/caisar-detection-data.conf +++ b/config/caisar-detection-data.conf @@ -54,11 +54,10 @@ driver = "%{config}/drivers/marabou.drv" use_at_auto_level = 1 [ATP maraboupy] -name = "Marabou" +name = "Maraboupy" exec = "marabou_eval.py" -alternative = "maraboupy" version_switch = "--display-version" -version_regexp = "Marabou version \\([0-9.+]+\\)" +version_regexp = "Maraboupy version \\([0-9.+]+\\)" version_ok = "2.0.0" command = "%e %{nnet-onnx} -q %f " driver = "%{config}/drivers/marabou.drv" diff --git a/tests/autodetect.t b/tests/autodetect.t index ba5051c..d825d5f 100644 --- a/tests/autodetect.t +++ b/tests/autodetect.t @@ -30,11 +30,15 @@ Test autodetect $ bin/abcrown.sh --version dummy-version + $ bin/marabou_eval.py --display-version 2>/dev/null + Maraboupy version 2.0.0 + $ caisar config -d AIMOS 1.0 Alt-Ergo 2.4.0 CVC5 1.0.2 Marabou 1.0.+ + Maraboupy 2.0.0 PyRAT 1.1 PyRAT 1.1 (ACAS) PyRAT 1.1 (ACASd) diff --git a/tests/bin/marabou_eval.py b/tests/bin/marabou_eval.py new file mode 100755 index 0000000..0b08449 --- /dev/null +++ b/tests/bin/marabou_eval.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +""" +Adapted for quick evaluation of Marabou for CAISAR + +Top contributors (to current version): + - Andrew Wu + +This file is part of the Marabou project. +Copyright (c) 2017-2021 by the authors listed in the file AUTHORS +in the top-level source directory) and their institutional affiliations. +All rights reserved. See the file COPYING in the top-level source +directory for licensing information. +""" + + +class WrongNetFormat(Exception): + def __init__(self, networkPath): + self.message = f"Network {networkPath} not in onnx, pb or nnet format!" + super().__init__(self.message) + + +import argparse +import os +import pathlib +import shutil +import subprocess +from importlib.metadata import version +import sys +import tempfile + +from maraboupy import Marabou +from maraboupy import MarabouCore + +sys.path.insert(0, os.path.join(str(pathlib.Path(__file__).parent.absolute()), "../")) + + +def main(): + args, unknown = arguments().parse_known_args() + marabou_binary = args.marabou_binary + if not os.access(marabou_binary, os.X_OK): + sys.exit('"{}" does not exist or is not executable'.format(marabou_binary)) + if args.display_version: + print(f"Maraboupy version {version('maraboupy')}") + else: + query, _network = createQuery(args) + temp = tempfile.NamedTemporaryFile(dir=args.temp_dir, delete=False) + name = temp.name + timeout = args.timeout + MarabouCore.saveQuery(query, name) + print("Running Marabou with the following arguments: ", unknown) + subprocess.run( + [marabou_binary] + + ["--input-query={}".format(name)] + + ["--timeout={}".format(timeout)] + + unknown + ) + os.remove(name) + + +def createQuery(args): + assert args.prop != None + networkPath = args.network + suffix = networkPath.split(".")[-1] + if suffix == "nnet": + network = Marabou.read_nnet(networkPath) + elif suffix == "pb": + network = Marabou.read_tf(networkPath) + elif suffix == "onnx": + network = Marabou.read_onnx(networkPath) + else: + raise WrongNetFormat(networkPath) + query = network.getInputQuery() + MarabouCore.loadProperty(query, args.prop) + return query, network + + +def arguments(): + parser = argparse.ArgumentParser( + description="Thin wrapper around Maraboupy executable" + ) + parser.add_argument( + "network", + type=str, + nargs="?", + default=None, + help="The network file name, the extension can be only .pb, .nnet, and .onnx", + ) + parser.add_argument( + "prop", type=str, nargs="?", default=None, help="The property file name" + ) + parser.add_argument( + "-t", "--timeout", type=int, default=10, help="Timeout in seconds" + ) + parser.add_argument( + "--temp-dir", type=str, default="/tmp/", help="Temporary directory" + ) + marabou_path = shutil.which("Marabou") + parser.add_argument( + "--marabou-binary", + type=str, + default=marabou_path, + help="The path to Marabou binary", + ) + parser.add_argument( + "--display-version", + action="store_true", + help="Output Maraboupy version and exit", + ) + parser.set_defaults(display_version=False) + + return parser + + +if __name__ == "__main__": + main() -- GitLab