diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index c12a5e724500815e81dd241439f85fc92e85c7bd..967d8303d5f3af68fb9b52461a2d75654cf81613 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -45,7 +45,7 @@ test: stage: test script: - nix --extra-experimental-features "nix-command flakes" build - - nix --extra-experimental-features "nix-command flakes" flake check + - nix --extra-experimental-features "nix-command flakes" flake check -L ## Manual generation of the documentation diff --git a/bin/marabou_eval.py b/bin/marabou_eval.py index ebc3389a49598e7e72dc84b5f3b27701cc13c61a..2683464293326a398d4c8c97a34fc0c0426812f4 100755 --- a/bin/marabou_eval.py +++ b/bin/marabou_eval.py @@ -20,60 +20,19 @@ class WrongNetFormat(Exception): import argparse +from importlib.metadata import version import os import pathlib import shutil import subprocess -from importlib.metadata import version import sys import tempfile -import warnings -warnings.filterwarnings("ignore") from maraboupy import Marabou from maraboupy import MarabouCore -sys.path.insert(0, os.path.join(str(pathlib.Path(__file__).parent.absolute()), "../")) - -def main(): - args, unknown = arguments().parse_known_args() - marabou_binary = args.marabou_binary - if not os.access(marabou_binary, os.X_OK): - sys.exit('"{}" does not exist or is not executable'.format(marabou_binary)) - if args.display_version: - print(f"Maraboupy version {version('maraboupy')}") - else: - query, _network = createQuery(args) - temp = tempfile.NamedTemporaryFile(dir=args.temp_dir, delete=False) - name = temp.name - timeout = args.timeout - MarabouCore.saveQuery(query, name) - print("Running Marabou with the following arguments: ", unknown) - subprocess.run( - [marabou_binary] - + ["--input-query={}".format(name)] - + ["--timeout={}".format(timeout)] - + unknown - ) - os.remove(name) - - -def createQuery(args): - assert args.prop != None - networkPath = args.network - suffix = networkPath.split(".")[-1] - if suffix == "nnet": - network = Marabou.read_nnet(networkPath) - elif suffix == "pb": - network = Marabou.read_tf(networkPath) - elif suffix == "onnx": - network = Marabou.read_onnx(networkPath) - else: - raise WrongNetFormat(networkPath) - query = network.getInputQuery() - MarabouCore.loadProperty(query, args.prop) - return query, network +sys.path.insert(0, os.path.join(str(pathlib.Path(__file__).parent.absolute()), "../")) def arguments(): @@ -104,14 +63,51 @@ def arguments(): help="The path to Marabou binary", ) parser.add_argument( - "--display-version", + "--version", action="store_true", - help="Output Maraboupy version and exit", + help="Output a version string if Maraboupy is present and exit", ) - parser.set_defaults(display_version=False) + parser.set_defaults(version=False) return parser +def main(): + args, unknown = arguments().parse_known_args() + marabou_binary = args.marabou_binary + if not os.access(marabou_binary, os.X_OK): + sys.exit('"{}" does not exist or is not executable'.format(marabou_binary)) + else: + if args.version: + print(f"Maraboupy version {version('maraboupy')}") + else: + + assert args.prop != None + networkPath = args.network + suffix = networkPath.split(".")[-1] + if suffix == "nnet": + network = Marabou.read_nnet(networkPath) + elif suffix == "pb": + network = Marabou.read_tf(networkPath) + elif suffix == "onnx": + network = Marabou.read_onnx(networkPath) + else: + raise WrongNetFormat(networkPath) + query = network.getInputQuery() + MarabouCore.loadProperty(query, args.prop) + temp = tempfile.NamedTemporaryFile(dir=args.temp_dir, delete=False) + name = temp.name + timeout = args.timeout + MarabouCore.saveQuery(query, name) + print("Running Marabou with the following arguments: ", unknown) + subprocess.run( + [marabou_binary] + + ["--input-query={}".format(name)] + + ["--timeout={}".format(timeout)] + + unknown + ) + os.remove(name) + + if __name__ == "__main__": main() diff --git a/config/caisar-detection-data.conf b/config/caisar-detection-data.conf index 813f0c86cc10f133c02abe5ee64bbe023b4a008d..7eb58a8e784beb4d2620006caa45e6fcc6fb12e2 100644 --- a/config/caisar-detection-data.conf +++ b/config/caisar-detection-data.conf @@ -56,10 +56,10 @@ use_at_auto_level = 1 [ATP maraboupy] name = "Maraboupy" exec = "marabou_eval.py" -version_switch = "--display-version" -version_regexp = "Maraboupy version \\([0-9.+]+\\)" +version_switch = "--version" +version_regexp = "Maraboupy version \\([0-9.]+\\)" version_ok = "2.0.0" -command = "%e %{nnet-onnx} -q %f " +command = "%e %{nnet-onnx} %f " driver = "%{config}/drivers/marabou.drv" use_at_auto_level = 1 diff --git a/tests/autodetect.t b/tests/autodetect.t index d825d5fd0457d043d6e306ea2c71e63253d1eaae..aa34864e8f786e1b755212f587977d2422c589ea 100644 --- a/tests/autodetect.t +++ b/tests/autodetect.t @@ -30,7 +30,7 @@ Test autodetect $ bin/abcrown.sh --version dummy-version - $ bin/marabou_eval.py --display-version 2>/dev/null + $ bin/marabou_eval.py --version Maraboupy version 2.0.0 $ caisar config -d diff --git a/tests/bin/marabou_eval.py b/tests/bin/marabou_eval.py index 0b084496d6d173d3fee1ad63767553dd2c2fc07f..4ba5fbd8125758a81063b086bcdde7bf2f5c3bac 100755 --- a/tests/bin/marabou_eval.py +++ b/tests/bin/marabou_eval.py @@ -1,115 +1,14 @@ -#!/usr/bin/env python3 -""" -Adapted for quick evaluation of Marabou for CAISAR - -Top contributors (to current version): - - Andrew Wu - -This file is part of the Marabou project. -Copyright (c) 2017-2021 by the authors listed in the file AUTHORS -in the top-level source directory) and their institutional affiliations. -All rights reserved. See the file COPYING in the top-level source -directory for licensing information. -""" - - -class WrongNetFormat(Exception): - def __init__(self, networkPath): - self.message = f"Network {networkPath} not in onnx, pb or nnet format!" - super().__init__(self.message) - - -import argparse -import os -import pathlib -import shutil -import subprocess -from importlib.metadata import version -import sys -import tempfile - -from maraboupy import Marabou -from maraboupy import MarabouCore - -sys.path.insert(0, os.path.join(str(pathlib.Path(__file__).parent.absolute()), "../")) - - -def main(): - args, unknown = arguments().parse_known_args() - marabou_binary = args.marabou_binary - if not os.access(marabou_binary, os.X_OK): - sys.exit('"{}" does not exist or is not executable'.format(marabou_binary)) - if args.display_version: - print(f"Maraboupy version {version('maraboupy')}") - else: - query, _network = createQuery(args) - temp = tempfile.NamedTemporaryFile(dir=args.temp_dir, delete=False) - name = temp.name - timeout = args.timeout - MarabouCore.saveQuery(query, name) - print("Running Marabou with the following arguments: ", unknown) - subprocess.run( - [marabou_binary] - + ["--input-query={}".format(name)] - + ["--timeout={}".format(timeout)] - + unknown - ) - os.remove(name) - - -def createQuery(args): - assert args.prop != None - networkPath = args.network - suffix = networkPath.split(".")[-1] - if suffix == "nnet": - network = Marabou.read_nnet(networkPath) - elif suffix == "pb": - network = Marabou.read_tf(networkPath) - elif suffix == "onnx": - network = Marabou.read_onnx(networkPath) - else: - raise WrongNetFormat(networkPath) - query = network.getInputQuery() - MarabouCore.loadProperty(query, args.prop) - return query, network - - -def arguments(): - parser = argparse.ArgumentParser( - description="Thin wrapper around Maraboupy executable" - ) - parser.add_argument( - "network", - type=str, - nargs="?", - default=None, - help="The network file name, the extension can be only .pb, .nnet, and .onnx", - ) - parser.add_argument( - "prop", type=str, nargs="?", default=None, help="The property file name" - ) - parser.add_argument( - "-t", "--timeout", type=int, default=10, help="Timeout in seconds" - ) - parser.add_argument( - "--temp-dir", type=str, default="/tmp/", help="Temporary directory" - ) - marabou_path = shutil.which("Marabou") - parser.add_argument( - "--marabou-binary", - type=str, - default=marabou_path, - help="The path to Marabou binary", - ) - parser.add_argument( - "--display-version", - action="store_true", - help="Output Maraboupy version and exit", - ) - parser.set_defaults(display_version=False) - - return parser - - -if __name__ == "__main__": - main() +#!/bin/sh -e + +case $1 in + --version) + echo "Maraboupy version 2.0.0" + ;; + *) + echo "PWD: $(pwd)" + echo "NN: $2" + test -e $2 || (echo "Cannot find the NN file" && exit 1) + echo "Goal:" + cat $4 + echo "Result = Unknown" +esac