diff --git a/bin/abcrown.sh b/bin/abcrown.sh new file mode 100755 index 0000000000000000000000000000000000000000..2d02c29fb46caad6e549291f986dee3eb7c61cc2 --- /dev/null +++ b/bin/abcrown.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env sh +########################################################################### +# # +# This file is part of CAISAR. # +# # +# Copyright (C) 2022 # +# CEA (Commissariat à l'énergie atomique et aux énergies # +# alternatives) # +# # +# You can redistribute it and/or modify it under the terms of the GNU # +# Lesser General Public License as published by the Free Software # +# Foundation, version 2.1. # +# # +# It is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU Lesser General Public License for more details. # +# # +# See the GNU Lesser General Public License version 2.1 # +# for more details (enclosed in the file licenses/LGPLv2.1). # +# # +########################################################################### + +if [ "$1" = "--version" ]; then + SCRIPT_DIR=$( dirname -- "$0"; ) + $SCRIPT_DIR/findmodule.py "complete_verifier.abcrown" "dummy-version" +else + python3 -m complete_verifier.abcrown "$@" +fi diff --git a/bin/dune b/bin/dune index 45f3f6d1325e66f8f6a675268bdd0f51ebf1192d..b05c9abcfcdb286a5fae5f216698931c027652a7 100644 --- a/bin/dune +++ b/bin/dune @@ -3,4 +3,5 @@ (section bin) (files (findmodule.py as findmodule.py) - (nnenum.sh as nnenum.sh))) + (nnenum.sh as nnenum.sh) + (abcrown.sh as abcrown.sh))) diff --git a/bin/nnenum.sh b/bin/nnenum.sh index a4748a92bfe2b46b5f1135f65b65510207afbeaa..d58856b5277d9f16fac8dfdd56c72cb25dd210f8 100755 --- a/bin/nnenum.sh +++ b/bin/nnenum.sh @@ -22,8 +22,8 @@ ########################################################################### if [ "$1" = "--version" ]; then - DIRNAME=$( dirname -- "$0"; ) - $DIRNAME/findmodule.py "nnenum" "dummy" + SCRIPT_DIR=$( dirname -- "$0"; ) + $SCRIPT_DIR/findmodule.py "nnenum" "dummy-version" else OMP_NUM_THREADS=1 OPENBLAS_NUM_THREADS=1 python3 -m nnenum.nnenum "$@" fi diff --git a/config/caisar-detection-data.conf b/config/caisar-detection-data.conf index 133ab631e679fac8f51aa5bb8c9a4d26c3a542be..47825ee8ee8aba4c4be6ccdccf79129c051cac80 100644 --- a/config/caisar-detection-data.conf +++ b/config/caisar-detection-data.conf @@ -92,12 +92,22 @@ use_at_auto_level = 1 name = "nnenum" exec = "nnenum.sh" version_switch = "--version" -version_regexp = "\\(dummy\\)" -version_ok = "dummy" +version_regexp = "\\(dummy-version\\)" +version_ok = "dummy-version" command = "%e %{nnet-onnx} %f" driver = "caisar_drivers/nnenum.drv" use_at_auto_level = 1 +[ATP abcrown] +name = "alpha-beta-CROWN" +exec = "abcrown.sh" +version_switch = "--version" +version_regexp = "\\(dummy-version\\)" +version_ok = "dummy-version" +command = "%e --device cpu --onnx_path %{nnet-onnx} --vnnlib_path %f --timeout %t" +driver = "caisar_drivers/abcrown.drv" +use_at_auto_level = 1 + [ATP saver] name = "SAVer" exec = "saver" diff --git a/config/drivers/abcrown.drv b/config/drivers/abcrown.drv new file mode 100644 index 0000000000000000000000000000000000000000..ca8f21cd14af98200c9abc6867808ac2833ed8bb --- /dev/null +++ b/config/drivers/abcrown.drv @@ -0,0 +1,32 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +(* Why3 driver for alpha-beta-CROWN *) + +prelude ";;; produced by alpha-beta-CROWN driver" + +import "vnnlib.gen" + +valid "^Result: unsat" +invalid "^Result: sat" +timeout "^Result: timeout" +unknown "^Result: unknown" "" diff --git a/config/dune b/config/dune index 3a77163781bfdc0975e9f907538b6b8885c9779a..3bf2582372a65f9b69fd3891854e84d270a0dd46 100644 --- a/config/dune +++ b/config/dune @@ -11,5 +11,6 @@ (drivers/aimos.drv as drivers/aimos.drv) (drivers/vnnlib.gen as drivers/vnnlib.gen) (drivers/pyrat_vnnlib.drv as drivers/pyrat_vnnlib.drv) - (drivers/nnenum.drv as drivers/nnenum.drv)) + (drivers/nnenum.drv as drivers/nnenum.drv) + (drivers/abcrown.drv as drivers/abcrown.drv)) (package caisar)) diff --git a/src/interpretation.ml b/src/interpretation.ml index aa2b97ff3bd22e629bf5cdb9576d972664fa88d4..4d5af8bad9c4fa51374ba4900e9a3b8ae71b1a5c 100644 --- a/src/interpretation.ml +++ b/src/interpretation.ml @@ -44,18 +44,13 @@ type dataset = DS_csv of Csv.t [@printer fun fmt _ -> Fmt.pf fmt "<csv>"] type data = D_csv of string list [@@deriving show] -type vector = - (Term.lsymbol - [@printer - fun fmt v -> - Fmt.pf fmt "%a" Fmt.(option ~none:nop int) (Language.lookup_vector v)]) -[@@deriving show] - type caisar_op = | NeuralNetwork of nn | Dataset of dataset | Data of data - | Vector of vector + | Vector of Term.lsymbol + [@printer + fun fmt v -> Fmt.pf fmt "%a" Fmt.(option int) (Language.lookup_vector v)] [@@deriving show] type caisar_env = { @@ -258,8 +253,8 @@ let caisar_builtins : caisar_env CRE.built_in_theories list = let filename = Caml.Filename.concat cwd neural_network in let nn = match id_string with - | "NNet" -> NNet (Language.create_nnet_nn env filename) - | "ONNX" -> ONNX (Language.create_onnx_nn env filename) + | "NNet" -> NNet (Language.create_nn_nnet env filename) + | "ONNX" -> ONNX (Language.create_nn_onnx env filename) | _ -> failwith (Fmt.str "Unrecognized neural network format %s" id_string) in @@ -337,7 +332,7 @@ let bounded_quant engine vs ~cond : CRE.bounded_quant_result option = in let new_quant = List.init n ~f:(fun _ -> - let preid = Ident.id_fresh "caisar_v" in + let preid = Ident.id_fresh "caisar_x" in Term.create_vsymbol preid ty) in let args = List.map new_quant ~f:(fun vs -> (Term.t_var vs, ty)) in diff --git a/src/language.ml b/src/language.ml index d525b1f8602c0e17e9327574c3a16d7e0fc244e3..05ad16db554996d592a0978cbb52b4ea4ff9bf78 100644 --- a/src/language.ml +++ b/src/language.ml @@ -160,14 +160,14 @@ let register_ovo_support () = let vectors = Term.Hls.create 10 -let vector_elt_ty env = +let float64_t_ty env = let th = Env.read_theory env [ "ieee_float" ] "Float64" in Ty.ty_app (Theory.ns_find_ts th.th_export [ "t" ]) [] let create_vector = Env.Wenv.memoize 13 (fun env -> let h = Hashtbl.create (module Int) in - let ty_elt = vector_elt_ty env in + let ty_elt = float64_t_ty env in let ty = let th = Env.read_theory env [ "interpretation" ] "Vector" in Ty.ty_app (Theory.ns_find_ts th.th_export [ "vector" ]) [ ty_elt ] @@ -186,8 +186,8 @@ let mem_vector = Term.Hls.mem vectors (* -- Classifier *) type nn = { - nn_inputs : int; - nn_outputs : int; + nn_nb_inputs : int; + nn_nb_outputs : int; nn_ty_elt : Ty.ty; [@printer fun fmt ty -> Fmt.pf fmt "%a" Pretty.print_ty ty] nn_filename : string; nn_nier : Onnx.G.t option; [@opaque] @@ -204,13 +204,10 @@ let fresh_nn_ls env name = let id = Ident.id_fresh name in Term.create_fsymbol id [] ty -let create_nnet_nn = +let create_nn_nnet = Env.Wenv.memoize 13 (fun env -> let h = Hashtbl.create (module String) in - let ty_elt = - let th = Env.read_theory env [ "ieee_float" ] "Float64" in - Ty.ty_app (Theory.ns_find_ts th.th_export [ "t" ]) [] - in + let ty_elt = float64_t_ty env in Hashtbl.findi_or_add h ~default:(fun filename -> let ls = fresh_nn_ls env "nnet_nn" in let nn = @@ -219,8 +216,8 @@ let create_nnet_nn = | Error s -> Loc.errorm "%s" s | Ok { n_inputs; n_outputs; _ } -> { - nn_inputs = n_inputs; - nn_outputs = n_outputs; + nn_nb_inputs = n_inputs; + nn_nb_outputs = n_outputs; nn_ty_elt = ty_elt; nn_filename = filename; nn_nier = None; @@ -229,10 +226,10 @@ let create_nnet_nn = Term.Hls.add nets ls nn; ls)) -let create_onnx_nn = +let create_nn_onnx = Env.Wenv.memoize 13 (fun env -> let h = Hashtbl.create (module String) in - let ty_elt = vector_elt_ty env in + let ty_elt = float64_t_ty env in Hashtbl.findi_or_add h ~default:(fun filename -> let ls = fresh_nn_ls env "onnx_nn" in let onnx = @@ -249,8 +246,8 @@ let create_onnx_nn = | Ok nier -> Some nier in { - nn_inputs = n_inputs; - nn_outputs = n_outputs; + nn_nb_inputs = n_inputs; + nn_nb_outputs = n_outputs; nn_ty_elt = ty_elt; nn_filename = filename; nn_nier = nier; diff --git a/src/language.mli b/src/language.mli index 380465783b7c74eef609ec0f715de98c68dad43c..579255e573ea75eb655c249fd11ca95eaadd6cd2 100644 --- a/src/language.mli +++ b/src/language.mli @@ -72,15 +72,15 @@ val mem_vector : Term.lsymbol -> bool (** -- Neural Network *) type nn = private { - nn_inputs : int; - nn_outputs : int; + nn_nb_inputs : int; + nn_nb_outputs : int; nn_ty_elt : Ty.ty; nn_filename : string; nn_nier : Onnx.G.t option; } [@@deriving show] -val create_nnet_nn : Env.env -> string -> Term.lsymbol -val create_onnx_nn : Env.env -> string -> Term.lsymbol +val create_nn_nnet : Env.env -> string -> Term.lsymbol +val create_nn_onnx : Env.env -> string -> Term.lsymbol val lookup_nn : Term.lsymbol -> nn option val mem_nn : Term.lsymbol -> bool diff --git a/src/printers/pyrat.ml b/src/printers/pyrat.ml index 74d82183c7ca998985d8261bd5109b0c1e7c69d9..d39b7fb22b7e94019ac38d7473d769deab46ed32 100644 --- a/src/printers/pyrat.ml +++ b/src/printers/pyrat.ml @@ -137,11 +137,7 @@ let rec print_goal_term info fmt t = let print_decl info fmt d = match d.Decl.d_node with - | Dtype _ -> () - | Ddata _ -> () - | Dparam _ -> () - | Dlogic _ -> () - | Dind _ -> () + | Dtype _ | Ddata _ | Dparam _ | Dlogic _ | Dind _ -> () | Dprop (Decl.Plemma, _, _) -> assert false | Dprop (Decl.Paxiom, _, f) -> (* PyRAT supports simple axioms only, ie without logical operators. *) diff --git a/src/printers/vnnlib.ml b/src/printers/vnnlib.ml index 29758ee6ce16a920db6930c6c3e6fd5092e28258..45680a9b0371b9e94c53afb0ce8be590a607762d 100644 --- a/src/printers/vnnlib.ml +++ b/src/printers/vnnlib.ml @@ -477,11 +477,13 @@ let print_prop_decl info fmt prop_kind pr t = let print_param_decl info fmt ls = match Term.Hls.find_opt info.variables ls with + | None -> () | Some s -> Fmt.pf fmt ";; %s@\n" s; - Fmt.pf fmt "@[(declare-const %s %a)@]@\n@\n" s (print_type_value info) - ls.ls_value - | _ -> () + (* FIXME: The type should not be hardcoded as 'Real', but printed as: *) + (* Fmt.pf fmt "@[(declare-const %s %a)@]@\n@\n" s (print_type_value info) *) + (* ls.ls_value *) + Fmt.pf fmt "@[(declare-const %s Real)@]@\n@\n" s let print_decl info fmt d = match d.Decl.d_node with diff --git a/src/proof_strategy.ml b/src/proof_strategy.ml index ee690421141a1a2f910ef1772862bd122737ece2..e0fa3156946e7f63f4c76878db89e6f34a66660a 100644 --- a/src/proof_strategy.ml +++ b/src/proof_strategy.ml @@ -50,14 +50,10 @@ let apply_classic_prover env task = let trans = Nn2smt.trans env in do_apply_prover ~lookup ~trans [ task ] -let apply_native_nn_prover env task = +let apply_native_nn_prover task = let lookup = Language.lookup_nn in let trans = - Trans.seq - [ - Introduction.introduce_premises; - Native_nn_prover.trans_nn_application env; - ] + Trans.seq [ Introduction.introduce_premises; Native_nn_prover.trans ] in let tasks = Trans.apply Split_goal.split_goal_full task in do_apply_prover ~lookup ~trans tasks diff --git a/src/proof_strategy.mli b/src/proof_strategy.mli index 1ec1f823061f8bf766216d777309cc68601a5888..3d9717be700498eaec69960f83488e035d407d0f 100644 --- a/src/proof_strategy.mli +++ b/src/proof_strategy.mli @@ -25,5 +25,5 @@ open Why3 val apply_classic_prover : Env.env -> Task.task -> Task.task list (** Detect and translate applications of neural networks into SMT-LIB. *) -val apply_native_nn_prover : Env.env -> Task.task -> Task.task list +val apply_native_nn_prover : Task.task -> Task.task list (** Detect and execute applications of neural networks. *) diff --git a/src/prover.ml b/src/prover.ml index ee7908d7b7913036b38bf5f512dbfa910b49eda4..bb9ee6ec95178f1d989df85d22d35fc1998b5596 100644 --- a/src/prover.ml +++ b/src/prover.ml @@ -27,9 +27,10 @@ type t = | Aimos [@name "AIMOS"] | CVC5 [@name "cvc5"] | Nnenum [@name "nnenum"] + | ABCrown [@name "alpha-beta-CROWN"] [@@deriving yojson, show] -let list_available () = [ Marabou; Pyrat; Saver; Aimos; CVC5; Nnenum ] +let list_available () = [ Marabou; Pyrat; Saver; Aimos; CVC5; Nnenum; ABCrown ] let of_string prover = let prover = String.lowercase_ascii prover in @@ -40,6 +41,7 @@ let of_string prover = | "aimos" -> Some Aimos | "cvc5" -> Some CVC5 | "nnenum" -> Some Nnenum + | "abcrown" -> Some ABCrown | _ -> None let to_string = function @@ -49,7 +51,8 @@ let to_string = function | Aimos -> "AIMOS" | CVC5 -> "CVC5" | Nnenum -> "nnenum" + | ABCrown -> "alpha-beta-CROWN" let has_vnnlib_support = function | Pyrat -> true - | Marabou | Saver | Aimos | CVC5 | Nnenum -> false + | Marabou | Saver | Aimos | CVC5 | Nnenum | ABCrown -> false diff --git a/src/prover.mli b/src/prover.mli index 38b02037e6cbda8b97e9593f59fa0968babad78f..b993523e327e11a5e418db531f159e647a0829e4 100644 --- a/src/prover.mli +++ b/src/prover.mli @@ -27,6 +27,7 @@ type t = private | Aimos [@name "AIMOS"] | CVC5 [@name "cvc5"] | Nnenum [@name "nnenum"] + | ABCrown [@name "alpha-beta-CROWN"] [@@deriving yojson, show] val list_available : unit -> t list diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 8f7218086b62382be44522e3348664f661cda166..8d3d9439f341d4593c3436ec7022b68806de12c1 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -23,105 +23,135 @@ open Why3 open Base -let get_input_variables = - let add i acc = function - | { Term.t_node = Tapp (ls, []); _ } -> Term.Mls.add ls i acc - | arg -> - invalid_arg - (Fmt.str "No direct variable in application: %a" Pretty.print_term arg) +(* Collects in a map the input variables, already declared in a task, and their + indices of appearance inside respective input vectors. Such collecting + process is memoized wrt lsymbols corresponding to input vectors. *) +let collect_input_vars = + let hls = Term.Hls.create 13 in + let add index mls = function + | { Term.t_node = Tapp (ls, []); _ } -> Term.Mls.add ls index mls + | t -> failwith (Fmt.str "Not an input variable: %a" Pretty.print_term t) in - let rec aux acc (term : Term.term) = + let rec do_collect mls (term : Term.term) = match term.t_node with | Term.Tapp - ( { ls_name; _ }, - [ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] ) - when String.equal ls_name.id_string (Ident.op_infix "@@") -> ( - match (Language.lookup_nn ls1, Language.lookup_vector ls2) with - | Some { nn_inputs; _ }, Some n -> - assert (nn_inputs = n && n = List.length args); - List.foldi ~init:acc ~f:add args - | _ -> acc) - | _ -> Term.t_fold aux acc term + ( ls1 (* @@ *), + [ + { t_node = Tapp (ls2 (* nn *), _); _ }; + { t_node = Tapp (ls3 (* input vector *), tl (* input vars *)); _ }; + ] ) + when String.equal ls1.ls_name.id_string (Ident.op_infix "@@") -> ( + match (Language.lookup_nn ls2, Language.lookup_vector ls3) with + | Some { nn_nb_inputs; _ }, Some vector_length -> + assert (nn_nb_inputs = vector_length && vector_length = List.length tl); + if Term.Hls.mem hls ls3 then mls else List.foldi ~init:mls ~f:add tl + | _, _ -> mls) + | _ -> Term.t_fold do_collect mls term in - Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) Term.Mls.empty + Trans.fold_decl + (fun decl mls -> Decl.decl_fold do_collect mls decl) + Term.Mls.empty -(* Create logic symbols for output variables and simplify the formula. *) -let simplify_goal _env input_variables = - let rec aux hls (term : Term.term) = +(* Creates a list of pairs made of output variables and respective indices in + the list, for each neural network application to an input vector appearing in + a task. Such a list stands for the resulting output vector of a neural + network application to an input vector (ie, something of the form: nn@@v). + The creation process is memoized wrt terms corresponding to neural network + applications to input vectors. *) +let create_output_vars = + let rec do_create mt (term : Term.term) = + match term.t_node with + | Term.Tapp (ls1 (* @@ *), [ { t_node = Tapp (ls2 (* nn *), _); _ }; _ ]) + when String.equal ls1.ls_name.id_string (Ident.op_infix "@@") -> ( + match Language.lookup_nn ls2 with + | Some { nn_nb_outputs; nn_ty_elt; _ } -> + if Term.Mterm.mem term mt + then mt + else + let output_vars = + List.init nn_nb_outputs ~f:(fun index -> + ( index, + Term.create_fsymbol (Ident.id_fresh "caisar_y") [] nn_ty_elt )) + in + Term.Mterm.add term output_vars mt + | None -> mt) + | _ -> Term.t_fold do_create mt term + in + Trans.fold_decl + (fun decl mt -> Decl.decl_fold do_create mt decl) + Term.Mterm.empty + +(* Simplifies a task goal exhibiting a vector selection on a neural network + application to an input vector (ie, (nn@@v)[_]) by the corresponding output + variable. Morevoer, each input variable declaration is annotated with a meta + that describes the respective index in the input vector. Ouput variables are + all declared, each with a meta that describes the respective index in the + output vector. *) +let simplify_nn_application input_vars output_vars = + let rec do_simplify (term : Term.term) = match term.t_node with | Term.Tapp - ( ls_vget, + ( ls1 (* [_] *), [ - ({ - t_node = - Tapp - ( ls_apply_nn, - [ - { t_node = Tapp (ls_nn, _); _ }; - { t_node = Tapp (ls_vector, _); _ }; - ] ); - _; - } as _t1); - ({ t_node = Tconst (ConstInt i); _ } as _t2); + ({ t_node = Tapp (ls2 (* @@ *), _); _ } as t1); + ({ t_node = Tconst (ConstInt index); _ } as _t2); ] ) - when String.equal ls_vget.ls_name.id_string (Ident.op_get "") - && String.equal ls_apply_nn.ls_name.id_string (Ident.op_infix "@@") - -> ( - match (Language.lookup_nn ls_nn, Language.lookup_vector ls_vector) with - | Some nn, Some _ -> - let index = Number.to_small_integer i in - let hout = - Hashtbl.update_and_return hls nn.nn_filename ~f:(fun hout -> - let create_ls_output () = - let id = Ident.id_fresh "y" in - Term.create_fsymbol id [] nn.nn_ty_elt - in - match hout with - | None -> - let hout = Hashtbl.create (module Int) in - let ls = create_ls_output () in - Hashtbl.add_exn hout ~key:index ~data:ls; - hout - | Some hout -> - Hashtbl.update hout index ~f:(fun lsout -> - match lsout with - | None -> - let ls = create_ls_output () in - Hashtbl.add_exn hout ~key:index ~data:ls; - ls - | Some ls -> ls); - hout) - in - let ls_output = Hashtbl.find_exn hout index in - Term.fs_app ls_output [] nn.nn_ty_elt - | _ -> Term.t_map (aux hls) term) - | _ -> Term.t_map (aux hls) term + when String.equal ls1.ls_name.id_string (Ident.op_get "") + && String.equal ls2.ls_name.id_string (Ident.op_infix "@@") -> ( + match Term.Mterm.find_opt t1 output_vars with + | None -> Term.t_map do_simplify term + | Some output_vars -> + let index = Number.to_small_integer index in + assert (index < List.length output_vars); + let ls = Caml.List.assoc index output_vars in + Term.t_app_infer ls []) + | _ -> Term.t_map do_simplify term in - let htbl = Hashtbl.create (module String) in Trans.fold - (fun task_hd acc -> + (fun task_hd task -> match task_hd.task_decl.td_node with - | Use _ | Clone _ | Meta _ -> Task.add_tdecl acc task_hd.task_decl - | Decl { d_node = Dparam ls; _ } -> ( - let task = Task.add_tdecl acc task_hd.task_decl in - match Term.Mls.find_opt ls input_variables with - | None -> task - | Some pos -> Task.add_meta task Utils.meta_input [ MAls ls; MAint pos ] - ) - | Decl decl -> - let decl = Decl.decl_map (fun term -> aux htbl term) decl in - let acc = - Hashtbl.fold htbl ~init:acc ~f:(fun ~key ~data acc -> - let acc = Task.add_meta acc Utils.meta_nn_filename [ MAstr key ] in - Hashtbl.fold data ~init:acc ~f:(fun ~key ~data acc -> - let acc = - let decl = Decl.create_param_decl data in - Task.add_decl acc decl - in - Task.add_meta acc Utils.meta_output [ MAls data; MAint key ])) + | Decl { d_node = Dparam ls; _ } -> + (* Add meta for neural network and input variable declarations. Note + that each meta needs to appear before the corresponding declaration + in order to be leveraged by prover printers. *) + let task = + match (Term.Mls.find_opt ls input_vars, Language.lookup_nn ls) with + | None, None -> task + | Some index, None -> + Task.add_meta task Utils.meta_input [ MAls ls; MAint index ] + | None, Some { nn_filename; _ } -> + Task.add_meta task Utils.meta_nn_filename [ MAstr nn_filename ] + | Some _, Some _ -> + (* [ls] cannot be an input variable and a nn at the same time. *) + assert false + in + Task.add_tdecl task task_hd.task_decl + | Decl ({ d_node = Dprop (Pgoal, _, _); _ } as decl) -> + let decl = Decl.decl_map do_simplify decl in + let task = + (* Output variables are not declared yet in the task as they are + created on the fly for each (different) neural network application + on an input vector. We add here their declarations in the task. *) + Term.Mterm.fold + (fun _t output_vars task -> + (* Again, for each output variable, add the meta first, then its + actual declaration. *) + List.fold output_vars ~init:task + ~f:(fun task (index, output_var) -> + let task = + Task.add_meta task Utils.meta_output + [ MAls output_var; MAint index ] + in + let decl = Decl.create_param_decl output_var in + Task.add_decl task decl)) + output_vars task in - Task.add_decl acc decl) + Task.add_decl task decl + | Use _ | Clone _ | Meta _ | Decl _ -> + Task.add_tdecl task task_hd.task_decl) None -let trans_nn_application env = - Trans.bind get_input_variables (simplify_goal env) +let trans = + Trans.bind collect_input_vars (fun input_vars -> + Trans.bind create_output_vars (fun output_vars -> + simplify_nn_application input_vars output_vars)) diff --git a/src/transformations/native_nn_prover.mli b/src/transformations/native_nn_prover.mli index 82cc9c71c748345fbdaa1ec8f40c6f9be708fc3c..2978f4fd9e12d191801991e334cd121fd8c69530 100644 --- a/src/transformations/native_nn_prover.mli +++ b/src/transformations/native_nn_prover.mli @@ -20,4 +20,4 @@ (* *) (**************************************************************************) -val trans_nn_application : Why3.Env.env -> Why3.Task.task Why3.Trans.trans +val trans : Why3.Task.task Why3.Trans.trans diff --git a/src/verification.ml b/src/verification.ml index 8488a8742715fad4844811df1d2c2511ce98ecfb..ce6f2f785397ca3674d272bdd8b631c13fc366a6 100644 --- a/src/verification.ml +++ b/src/verification.ml @@ -223,9 +223,9 @@ let answer_dataset limit config env prover config_prover driver dataset task = in (prover_answer, additional_info) -let answer_generic limit config env prover config_prover driver ~proof_strategy - task = - let tasks = proof_strategy env task in +let answer_generic limit config prover config_prover driver ~proof_strategy task + = + let tasks = proof_strategy task in let answers = List.concat_map tasks ~f:(fun task -> let task = Driver.prepare_task driver task in @@ -240,7 +240,7 @@ let answer_generic limit config env prover config_prover driver ~proof_strategy tasks. *) match prover with | Prover.Marabou -> Trans.apply Split.split_all task - | Pyrat | Nnenum -> Trans.apply Split.split_premises task + | Pyrat | Nnenum | ABCrown -> Trans.apply Split.split_premises task | _ -> [ task ] in let command = @@ -259,19 +259,19 @@ let call_prover ~cwd ~limit config env prover config_prover driver ?dataset task match prover with | Prover.Saver -> answer_saver limit config env config_prover dataset task | Aimos -> answer_aimos limit config env config_prover dataset task - | (Marabou | Pyrat | Nnenum) when Option.is_some dataset -> + | (Marabou | Pyrat | Nnenum | ABCrown) when Option.is_some dataset -> let dataset = Unix.realpath (Option.value_exn dataset) in answer_dataset limit config env prover config_prover driver dataset task - | Marabou | Pyrat | Nnenum -> + | Marabou | Pyrat | Nnenum | ABCrown -> let task = Interpretation.interpret_task ~cwd env task in let proof_strategy = Proof_strategy.apply_native_nn_prover in - answer_generic limit config env prover config_prover driver - ~proof_strategy task + answer_generic limit config prover config_prover driver ~proof_strategy + task | CVC5 -> let task = Interpretation.interpret_task ~cwd env task in - let proof_strategy = Proof_strategy.apply_classic_prover in - answer_generic limit config env prover config_prover driver - ~proof_strategy task + let proof_strategy = Proof_strategy.apply_classic_prover env in + answer_generic limit config prover config_prover driver ~proof_strategy + task in let id = Task.task_goal task in { id; prover_answer; additional_info } diff --git a/tests/autodetect.t b/tests/autodetect.t index 702408e7b889e4709ef011732db70c86728eba9d..6c4a2f42f3537c048f43893961845f5d02870233 100644 --- a/tests/autodetect.t +++ b/tests/autodetect.t @@ -4,7 +4,7 @@ Test autodetect > echo "2.4.0" > EOF - $ chmod u+x bin/alt-ergo bin/pyrat.py bin/Marabou bin/saver bin/aimos bin/cvc5 bin/nnenum.sh + $ chmod u+x bin/alt-ergo bin/pyrat.py bin/Marabou bin/saver bin/aimos bin/cvc5 bin/nnenum.sh bin/abcrown.sh $ bin/alt-ergo 2.4.0 @@ -25,19 +25,23 @@ Test autodetect This is cvc5 version 1.0.2 [git tag 1.0.2 branch HEAD] $ bin/nnenum.sh --version - dummy + dummy-version + + $ bin/abcrown.sh --version + dummy-version $ PATH=$(pwd)/bin:$PATH $ caisar config -d -vv 2>&1 | ./filter_tmpdir.sh [caisar][DEBUG] Execution of command 'config' [caisar][DEBUG] Automatic detection + <autodetect>Run: ($TESTCASE_ROOT/bin/nnenum.sh --version) > $TMPFILE 2>&1 <autodetect>Run: ($TESTCASE_ROOT/bin/pyrat.py --version) > $TMPFILE 2>&1 <autodetect>Run: ($TESTCASE_ROOT/bin/saver --version 2>&1 | head -n1 && (which saver > /dev/null 2>&1)) > $TMPFILE 2>&1 <autodetect>Run: ($TESTCASE_ROOT/bin/aimos --version) > $TMPFILE 2>&1 + <autodetect>Run: ($TESTCASE_ROOT/bin/abcrown.sh --version) > $TMPFILE 2>&1 <autodetect>Run: ($TESTCASE_ROOT/bin/alt-ergo --version) > $TMPFILE 2>&1 <autodetect>Run: ($TESTCASE_ROOT/bin/Marabou --version) > $TMPFILE 2>&1 - <autodetect>Run: ($TESTCASE_ROOT/bin/nnenum.sh --version) > $TMPFILE 2>&1 <autodetect>Run: ($TESTCASE_ROOT/bin/cvc5 --version) > $TMPFILE 2>&1 <autodetect>0 prover(s) added <autodetect>Generating strategies: @@ -47,10 +51,11 @@ Test autodetect <autodetect>Found prover PyRAT version 1.1, OK. <autodetect>Found prover PyRAT version 1.1 (alternative: VNNLIB) <autodetect>Found prover PyRAT version 1.1 (alternative: ACAS) - <autodetect>Found prover nnenum version dummy, OK. + <autodetect>Found prover nnenum version dummy-version, OK. + <autodetect>Found prover alpha-beta-CROWN version dummy-version, OK. <autodetect>Found prover SAVer version v1.0, OK. <autodetect>Found prover AIMOS version 1.0, OK. - <autodetect>9 prover(s) added + <autodetect>10 prover(s) added [caisar] AIMOS 1.0 Alt-Ergo 2.4.0 CVC5 1.0.2 @@ -59,4 +64,5 @@ Test autodetect PyRAT 1.1 (ACAS) PyRAT 1.1 (VNNLIB) SAVer v1.0 - nnenum dummy + alpha-beta-CROWN dummy-version + nnenum dummy-version diff --git a/tests/bin/abcrown.sh b/tests/bin/abcrown.sh new file mode 100644 index 0000000000000000000000000000000000000000..036b3d59bd303e4f681d1d39992fd3c4b37ba2b7 --- /dev/null +++ b/tests/bin/abcrown.sh @@ -0,0 +1,15 @@ +#!/bin/sh -e + + +case $1 in + --version) + echo "dummy-version" + ;; + *) + echo "PWD: $(pwd)" + echo "NN: $1" + test -e $1 || (echo "Cannot find the NN file" && exit 1) + echo "Goal:" + cat $2 + echo "Unknown" +esac diff --git a/tests/bin/nnenum.sh b/tests/bin/nnenum.sh index 6b0a76a8dc230b662040f7751343b412ff848edd..036b3d59bd303e4f681d1d39992fd3c4b37ba2b7 100644 --- a/tests/bin/nnenum.sh +++ b/tests/bin/nnenum.sh @@ -3,7 +3,7 @@ case $1 in --version) - echo "dummy" + echo "dummy-version" ;; *) echo "PWD: $(pwd)" diff --git a/tests/dune b/tests/dune index 2aed42f48c9297aa4e64908c24103d863736b8af..09ce2af3f16be707a650459cb34e4a57dc6ccf71 100644 --- a/tests/dune +++ b/tests/dune @@ -10,6 +10,7 @@ bin/aimos bin/cvc5 bin/nnenum.sh + bin/abcrown.sh filter_tmpdir.sh ../lib/xgboost/example/california.csv ../lib/xgboost/example/california.json)