Skip to content
Snippets Groups Projects
Commit c1801de3 authored by Aymeric Varasse's avatar Aymeric Varasse :innocent:
Browse files

Merge branch 'feature/michele/maraboupy' into 'master'

Improve integration of maraboupy (Python interface to Marabou)

See merge request laiser/caisar!137
parents e52bc49b 99e351a6
No related branches found
No related tags found
No related merge requests found
...@@ -4,5 +4,5 @@ ...@@ -4,5 +4,5 @@
(files (files
(dummyversion.py as dummyversion.py) (dummyversion.py as dummyversion.py)
(nnenum.sh as nnenum.sh) (nnenum.sh as nnenum.sh)
(marabou_eval.py as marabou_eval.py) (runMarabou.py as runMarabou.py)
(abcrown.sh as abcrown.sh))) (abcrown.sh as abcrown.sh)))
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
Adapted for quick evaluation of Marabou for CAISAR Modified by the AISER team, Software Safety and Security Laboratory, CEA-List.
This file is part of CAISAR.
This file is used for integrating Marabou, via its Python interface, in CAISAR.
Top contributors (to current version): Top contributors (to current version):
- Andrew Wu - Andrew Wu
...@@ -12,33 +15,46 @@ All rights reserved. See the file COPYING in the top-level source ...@@ -12,33 +15,46 @@ All rights reserved. See the file COPYING in the top-level source
directory for licensing information. directory for licensing information.
""" """
class WrongNetFormat(Exception):
def __init__(self, networkPath):
self.message = f"Network {networkPath} not in onnx, pb or nnet format!"
super().__init__(self.message)
import argparse import argparse
from importlib.metadata import version
import os import os
import pathlib import pathlib
import shutil import shutil
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
from importlib.metadata import version
from maraboupy import Marabou, MarabouCore # type: ignore
from maraboupy import Marabou class WrongNetFormat(Exception):
from maraboupy import MarabouCore def __init__(self, networkPath):
self.message = (
f"Network {networkPath} has an unrecognized extension."
f"The network must be in .pb, .nnet or .onnx format."
)
super().__init__(self.message)
sys.path.insert(0, os.path.join(str(pathlib.Path(__file__).parent.absolute()), "../")) sys.path.insert(0, os.path.join(str(pathlib.Path(__file__).parent.absolute()), "../"))
def maraboupy_version(marabou_binary: str):
result = subprocess.run([marabou_binary, "--version"], capture_output=True, text=True)
if result.returncode != 0:
print(f"Error running {marabou_binary} --version")
sys.exit(1)
maraboupy_version = version("maraboupy")
output = result.stdout.strip()
if maraboupy_version in output:
return f"maraboupy {maraboupy_version}"
return sys.exit(1)
def arguments(): def arguments():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="Thin wrapper around Maraboupy executable")
description="Thin wrapper around Maraboupy executable"
)
parser.add_argument( parser.add_argument(
"network", "network",
type=str, type=str,
...@@ -46,15 +62,9 @@ def arguments(): ...@@ -46,15 +62,9 @@ def arguments():
default=None, default=None,
help="The network file name, the extension can be only .pb, .nnet, and .onnx", help="The network file name, the extension can be only .pb, .nnet, and .onnx",
) )
parser.add_argument( parser.add_argument("prop", type=str, nargs="?", default=None, help="The property file name")
"prop", type=str, nargs="?", default=None, help="The property file name" parser.add_argument("-t", "--timeout", type=int, default=10, help="Timeout in seconds")
) parser.add_argument("--temp-dir", type=str, default="/tmp/", help="Temporary directory")
parser.add_argument(
"-t", "--timeout", type=int, default=10, help="Timeout in seconds"
)
parser.add_argument(
"--temp-dir", type=str, default="/tmp/", help="Temporary directory"
)
marabou_path = shutil.which("Marabou") marabou_path = shutil.which("Marabou")
parser.add_argument( parser.add_argument(
"--marabou-binary", "--marabou-binary",
...@@ -74,15 +84,17 @@ def arguments(): ...@@ -74,15 +84,17 @@ def arguments():
def main(): def main():
args, unknown = arguments().parse_known_args() args, unknown = arguments().parse_known_args()
marabou_binary = args.marabou_binary marabou_binary = args.marabou_binary
if not os.access(marabou_binary, os.X_OK): if not os.access(marabou_binary, os.X_OK):
sys.exit('"{}" does not exist or is not executable'.format(marabou_binary)) sys.exit('"{}" does not exist or is not executable'.format(marabou_binary))
else: else:
if args.version: if args.version:
print(f"Maraboupy version {version('maraboupy')}") print(f"{maraboupy_version(marabou_binary)}")
else: else:
assert args.network is not None
assert args.prop is not None
assert args.prop != None
networkPath = args.network networkPath = args.network
suffix = networkPath.split(".")[-1] suffix = networkPath.split(".")[-1]
if suffix == "nnet": if suffix == "nnet":
...@@ -93,18 +105,17 @@ def main(): ...@@ -93,18 +105,17 @@ def main():
network = Marabou.read_onnx(networkPath) network = Marabou.read_onnx(networkPath)
else: else:
raise WrongNetFormat(networkPath) raise WrongNetFormat(networkPath)
query = network.getInputQuery() query = network.getInputQuery()
MarabouCore.loadProperty(query, args.prop) MarabouCore.loadProperty(query, args.prop)
temp = tempfile.NamedTemporaryFile(dir=args.temp_dir, delete=False) temp = tempfile.NamedTemporaryFile(dir=args.temp_dir, delete=False)
name = temp.name name = temp.name
timeout = args.timeout timeout = args.timeout
MarabouCore.saveQuery(query, name) MarabouCore.saveQuery(query, name)
print("Running Marabou with the following arguments: ", unknown) print("Running Marabou with the following arguments: ", unknown)
subprocess.run( subprocess.run(
[marabou_binary] [marabou_binary] + ["--input-query={}".format(name)] + ["--timeout={}".format(timeout)] + unknown
+ ["--input-query={}".format(name)]
+ ["--timeout={}".format(timeout)]
+ unknown
) )
os.remove(name) os.remove(name)
......
...@@ -3,10 +3,8 @@ opam-version: "2.0" ...@@ -3,10 +3,8 @@ opam-version: "2.0"
version: "1.0" version: "1.0"
synopsis: synopsis:
"A platform for characterizing the safety and robustness of artificial intelligence based software" "A platform for characterizing the safety and robustness of artificial intelligence based software"
maintainer: [ maintainer: ["AISER team, Software Safety and Security Laboratory, CEA-List"]
"LAISER team, Software Safety and Security Laboratory, CEA-List" authors: ["AISER team, Software Safety and Security Laboratory, CEA-List"]
]
authors: ["LAISER team, Software Safety and Security Laboratory, CEA-List"]
license: "LGPL-2.1-only" license: "LGPL-2.1-only"
homepage: "https://git.frama-c.com/pub/caisar" homepage: "https://git.frama-c.com/pub/caisar"
doc: "https://git.frama-c.com/pub/caisar" doc: "https://git.frama-c.com/pub/caisar"
......
...@@ -54,12 +54,12 @@ driver = "%{config}/drivers/marabou.drv" ...@@ -54,12 +54,12 @@ driver = "%{config}/drivers/marabou.drv"
use_at_auto_level = 1 use_at_auto_level = 1
[ATP maraboupy] [ATP maraboupy]
name = "Maraboupy" name = "maraboupy"
exec = "marabou_eval.py" exec = "runMarabou.py"
version_switch = "--version" version_switch = "--version"
version_regexp = "Maraboupy version \\([0-9.]+\\)" version_regexp = "maraboupy \\([0-9.]+\\)"
version_ok = "2.0.0" version_ok = "2.0.0"
command = "%e %{nnet-onnx} %f " command = "%e %{nnet-onnx} %f --timeout %t"
driver = "%{config}/drivers/marabou.drv" driver = "%{config}/drivers/marabou.drv"
use_at_auto_level = 1 use_at_auto_level = 1
......
...@@ -10,8 +10,8 @@ ...@@ -10,8 +10,8 @@
(generate_opam_files true) (generate_opam_files true)
(license LGPL-2.1-only) (license LGPL-2.1-only)
(authors "LAISER team, Software Safety and Security Laboratory, CEA-List") (authors "AISER team, Software Safety and Security Laboratory, CEA-List")
(maintainers "LAISER team, Software Safety and Security Laboratory, CEA-List") (maintainers "AISER team, Software Safety and Security Laboratory, CEA-List")
(source (uri "git+https://git.frama-c.com/pub/caisar.git")) (source (uri "git+https://git.frama-c.com/pub/caisar.git"))
(bug_reports https://git.frama-c.com/pub/caisar/issues) (bug_reports https://git.frama-c.com/pub/caisar/issues)
(homepage https://git.frama-c.com/pub/caisar) (homepage https://git.frama-c.com/pub/caisar)
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
type t = type t =
| Marabou | Marabou
| Maraboupy [@name "Maraboupy"] | Maraboupy [@name "maraboupy"]
| Pyrat [@name "PyRAT"] | Pyrat [@name "PyRAT"]
| Saver [@name "SAVer"] | Saver [@name "SAVer"]
| Aimos [@name "AIMOS"] | Aimos [@name "AIMOS"]
...@@ -38,6 +38,7 @@ let of_string prover = ...@@ -38,6 +38,7 @@ let of_string prover =
let prover = String.lowercase_ascii prover in let prover = String.lowercase_ascii prover in
match prover with match prover with
| "marabou" -> Some Marabou | "marabou" -> Some Marabou
| "maraboupy" -> Some Maraboupy
| "pyrat" -> Some Pyrat | "pyrat" -> Some Pyrat
| "saver" -> Some Saver | "saver" -> Some Saver
| "aimos" -> Some Aimos | "aimos" -> Some Aimos
...@@ -48,7 +49,7 @@ let of_string prover = ...@@ -48,7 +49,7 @@ let of_string prover =
let to_string = function let to_string = function
| Marabou -> "Marabou" | Marabou -> "Marabou"
| Maraboupy -> "Maraboupy" | Maraboupy -> "maraboupy"
| Pyrat -> "PyRAT" | Pyrat -> "PyRAT"
| Saver -> "SAVer" | Saver -> "SAVer"
| Aimos -> "AIMOS" | Aimos -> "AIMOS"
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
type t = private type t = private
| Marabou | Marabou
| Maraboupy [@name "Maraboupy"] | Maraboupy [@name "maraboupy"]
| Pyrat [@name "PyRAT"] | Pyrat [@name "PyRAT"]
| Saver [@name "SAVer"] | Saver [@name "SAVer"]
| Aimos [@name "AIMOS"] | Aimos [@name "AIMOS"]
......
...@@ -237,7 +237,7 @@ let answer_dataset limit config env prover config_prover driver dataset task = ...@@ -237,7 +237,7 @@ let answer_dataset limit config env prover config_prover driver dataset task =
in in
let dataset_answers = let dataset_answers =
match prover with match prover with
| Prover.Marabou -> | Prover.Marabou | Maraboupy ->
(* We turn each task in [dataset_tasks] into a list (ie, conjunction) of (* We turn each task in [dataset_tasks] into a list (ie, conjunction) of
disjunctions of tasks. *) disjunctions of tasks. *)
let tasks = List.map ~f:(Trans.apply Split.split_all) dataset_tasks in let tasks = List.map ~f:(Trans.apply Split.split_all) dataset_tasks in
...@@ -345,7 +345,7 @@ let answer_generic limit config prover config_prover driver ~proof_strategy task ...@@ -345,7 +345,7 @@ let answer_generic limit config prover config_prover driver ~proof_strategy task
(* Turn [task] into a list (ie, conjunction) of disjunctions of (* Turn [task] into a list (ie, conjunction) of disjunctions of
tasks. *) tasks. *)
match prover with match prover with
| Prover.Marabou -> Trans.apply Split.split_all task | Prover.Marabou | Maraboupy -> Trans.apply Split.split_all task
| Pyrat | Nnenum | ABCrown -> Trans.apply Split.split_premises task | Pyrat | Nnenum | ABCrown -> Trans.apply Split.split_premises task
| _ -> [ task ] | _ -> [ task ]
in in
...@@ -368,7 +368,8 @@ let call_prover ~cwd ~limit config env prover config_prover driver ?dataset ...@@ -368,7 +368,8 @@ let call_prover ~cwd ~limit config env prover config_prover driver ?dataset
let proof_strategy = Proof_strategy.apply_svm_prover_strategy in let proof_strategy = Proof_strategy.apply_svm_prover_strategy in
answer_saver limit config env config_prover ~proof_strategy task answer_saver limit config env config_prover ~proof_strategy task
| Aimos -> answer_aimos limit config env config_prover dataset task | Aimos -> answer_aimos limit config env config_prover dataset task
| (Marabou | Pyrat | Nnenum | ABCrown) when Option.is_some dataset -> | (Marabou | Maraboupy | Pyrat | Nnenum | ABCrown)
when Option.is_some dataset ->
let dataset = Unix.realpath (Option.value_exn dataset) in let dataset = Unix.realpath (Option.value_exn dataset) in
answer_dataset limit config env prover config_prover driver dataset task answer_dataset limit config env prover config_prover driver dataset task
| Marabou | Maraboupy | Pyrat | Nnenum | ABCrown -> | Marabou | Maraboupy | Pyrat | Nnenum | ABCrown ->
......
...@@ -30,15 +30,14 @@ Test autodetect ...@@ -30,15 +30,14 @@ Test autodetect
$ bin/abcrown.sh --version $ bin/abcrown.sh --version
dummy-version dummy-version
$ bin/marabou_eval.py --version $ bin/runMarabou.py --version
Maraboupy version 2.0.0 maraboupy 2.0.0
$ caisar config -d $ caisar config -d
AIMOS 1.0 AIMOS 1.0
Alt-Ergo 2.4.0 Alt-Ergo 2.4.0
CVC5 1.0.2 CVC5 1.0.2
Marabou 1.0.+ Marabou 1.0.+
Maraboupy 2.0.0
PyRAT 1.1 PyRAT 1.1
PyRAT 1.1 (ACAS) PyRAT 1.1 (ACAS)
PyRAT 1.1 (ACASd) PyRAT 1.1 (ACASd)
...@@ -47,4 +46,5 @@ Test autodetect ...@@ -47,4 +46,5 @@ Test autodetect
SAVer v1.0 SAVer v1.0
alpha-beta-CROWN dummy-version alpha-beta-CROWN dummy-version
alpha-beta-CROWN dummy-version (ACAS) alpha-beta-CROWN dummy-version (ACAS)
maraboupy 2.0.0
nnenum dummy-version nnenum dummy-version
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
case $1 in case $1 in
--version) --version)
echo "Maraboupy version 2.0.0" echo "maraboupy 2.0.0"
;; ;;
*) *)
echo "PWD: $(pwd)" echo "PWD: $(pwd)"
echo "NN: $2" echo "NN: $1"
test -e $2 || (echo "Cannot find the NN file" && exit 1) test -e $1 || (echo "Cannot find the NN file" && exit 1)
echo "Goal:" echo "Goal:"
cat $4 cat $2
echo "Result = Unknown" echo "Unknown"
esac esac
...@@ -4,7 +4,7 @@ Test verify ...@@ -4,7 +4,7 @@ Test verify
$ bin/Marabou --version $ bin/Marabou --version
1.0.+ 1.0.+
$ caisar verify --format whyml --prover=Marabou --ltag=ProverSpec - <<EOF $ cat > file.mlw <<EOF
> theory T > theory T
> use ieee_float.Float64 > use ieee_float.Float64
> use caisar.types.Vector > use caisar.types.Vector
...@@ -38,6 +38,70 @@ Test verify ...@@ -38,6 +38,70 @@ Test verify
> (nn @@ i)[1] .< (nn @@ i)[0] \/ (nn @@ i)[0] .< (nn @@ i)[1] > (nn @@ i)[1] .< (nn @@ i)[0] \/ (nn @@ i)[0] .< (nn @@ i)[1]
> end > end
> EOF > EOF
$ caisar verify --prover=Marabou --ltag=ProverSpec file.mlw
[DEBUG]{ProverSpec} Prover-tailored specification:
x0 >= 0.0
x0 <= 0.5
y0 <= 0.0
[DEBUG]{ProverSpec} Prover-tailored specification:
x0 >= 0.0
x0 <= 0.5
y0 >= 0.5
[DEBUG]{ProverSpec} Prover-tailored specification:
x0 >= 0.0
x0 <= 0.5
x1 >= 0.5
x1 <= 1.0
y0 <= 0.0
y0 <= 0.5
[DEBUG]{ProverSpec} Prover-tailored specification:
x0 >= 0.0
x0 <= 0.5
x1 >= 0.5
x1 <= 1.0
y1 <= 0.0
[DEBUG]{ProverSpec} Prover-tailored specification:
x0 >= 0.0
x0 <= 0.5
x1 >= 0.5
x1 <= 1.0
y1 >= 0.5
[DEBUG]{ProverSpec} Prover-tailored specification:
x0 >= 0.0
x0 <= 0.5
x1 >= 0.5
x1 <= 1.0
+y1 -y0 >= 0
+y0 -y1 >= 0
[DEBUG]{ProverSpec} Prover-tailored specification:
x0 >= 0.0
x0 <= 0.5
x1 >= 0.5
x1 <= 1.0
+y1 -y0 >= 0
+y0 -y1 >= 0
[DEBUG]{ProverSpec} Prover-tailored specification:
x0 >= 0.75
x0 <= 1.0
x1 >= 0.5
x1 <= 1.0
+y1 -y0 >= 0
+y0 -y1 >= 0
Goal G: Unknown ()
Goal H: Unknown ()
Goal I: Unknown ()
Goal J: Unknown ()
$ caisar verify --prover=maraboupy --ltag=ProverSpec file.mlw
[DEBUG]{ProverSpec} Prover-tailored specification: [DEBUG]{ProverSpec} Prover-tailored specification:
x0 >= 0.0 x0 >= 0.0
x0 <= 0.5 x0 <= 0.5
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment