diff --git a/config/drivers/marabou.drv b/config/drivers/marabou.drv index 098c2aa6820feda3b7fdf694bc69df06fd03a984..3b14d16a0c57620bd06bbbd438bf7e99d6b37152 100644 --- a/config/drivers/marabou.drv +++ b/config/drivers/marabou.drv @@ -34,6 +34,7 @@ transformation "inline_trivial" transformation "introduce_premises" transformation "eliminate_builtin" transformation "simplify_formula" +transformation "simplify_rel" transformation "vars_on_lhs" theory BuiltIn @@ -44,11 +45,9 @@ theory BuiltIn meta "eliminate_algebraic" "keep_enums" meta "eliminate_algebraic" "keep_recs" - end theory int.Int - prelude "(* this is a prelude for Alt-Ergo integer arithmetic *)" syntax function zero "0" @@ -87,25 +86,19 @@ theory int.Int remove prop NonTrivialRing remove prop CompatOrderAdd remove prop ZeroLessOne - end theory int.EuclideanDivision - syntax function div "(%1 / %2)" syntax function mod "(%1 % %2)" - end theory int.ComputerDivision - use for_drivers.ComputerOfEuclideanDivision - end theory real.Real - prelude "(* this is a prelude for Alt-Ergo real arithmetic *)" syntax function zero "0.0" @@ -147,10 +140,25 @@ theory real.Real remove prop NonTrivialRing remove prop CompatOrderAdd remove prop ZeroLessOne +end +theory ieee_float.RoundingMode + syntax type mode "RoundingMode" + syntax function RNE "RNE" + syntax function RNA "RNA" + syntax function RTP "RTP" + syntax function RTN "RTN" + syntax function RTZ "RTZ" + + syntax predicate to_nearest "(or (= %1 RNE) (= %1 RNA))" end theory ieee_float.Float64 + syntax function add "(%2 + %3)" + syntax function sub "(%2 - %3)" + syntax function mul "(%2 * %3)" + syntax function div "(%2 / %3)" + syntax function neg "(-%1)" syntax function (.+) "(%1 + %2)" syntax function (.-) "(%1 - %2)" @@ -162,12 +170,9 @@ theory ieee_float.Float64 syntax predicate lt "%1 < %2" syntax predicate ge "%1 >= %2" syntax predicate gt "%1 > %2" - - end theory real.RealInfix - syntax function (+.) "(%1 + %2)" syntax function (-.) "(%1 - %2)" syntax function ( *.) "(%1 * %2)" @@ -183,7 +188,6 @@ theory real.RealInfix syntax predicate (<.) "(%1 < %2)" syntax predicate (>=.) "(%1 >= %2)" syntax predicate (>.) 
"(%1 > %2)" - end theory Bool diff --git a/config/drivers/pyrat.drv b/config/drivers/pyrat.drv index e8f8c3abcde800f0b1f1f82a99c371200ccf2c4c..ecac7d2a42575d57f16d406541c2c3eb458ea597 100644 --- a/config/drivers/pyrat.drv +++ b/config/drivers/pyrat.drv @@ -34,6 +34,7 @@ transformation "inline_trivial" transformation "introduce_premises" transformation "eliminate_builtin" transformation "simplify_formula" +transformation "simplify_rel" theory BuiltIn syntax type int "int" @@ -46,7 +47,6 @@ theory BuiltIn end theory int.Int - prelude "(* this is a prelude for Alt-Ergo integer arithmetic *)" syntax function zero "0" @@ -85,25 +85,18 @@ theory int.Int remove prop NonTrivialRing remove prop CompatOrderAdd remove prop ZeroLessOne - end theory int.EuclideanDivision - syntax function div "(%1 / %2)" syntax function mod "(%1 % %2)" - end theory int.ComputerDivision - use for_drivers.ComputerOfEuclideanDivision - end - theory real.Real - prelude "(* this is a prelude for Alt-Ergo real arithmetic *)" syntax function zero "0.0" @@ -145,10 +138,25 @@ theory real.Real remove prop NonTrivialRing remove prop CompatOrderAdd remove prop ZeroLessOne +end + +theory ieee_float.RoundingMode + syntax type mode "RoundingMode" + syntax function RNE "RNE" + syntax function RNA "RNA" + syntax function RTP "RTP" + syntax function RTN "RTN" + syntax function RTZ "RTZ" + syntax predicate to_nearest "(or (= %1 RNE) (= %1 RNA))" end theory ieee_float.Float64 + syntax function add "(%2 + %3)" + syntax function sub "(%2 - %3)" + syntax function mul "(%2 * %3)" + syntax function div "(%2 / %3)" + syntax function neg "(-%1)" syntax function (.+) "(%1 + %2)" syntax function (.-) "(%1 - %2)" @@ -157,14 +165,12 @@ theory ieee_float.Float64 syntax function (.-_) "(-%1)" syntax predicate le "%1 <= %2" - syntax predicate lt "%1 < %2" + syntax predicate lt "%1 < %2" syntax predicate ge "%1 >= %2" - syntax predicate gt "%1 > %2" - + syntax predicate gt "%1 > %2" end theory real.RealInfix - syntax function (+.) "(%1 + %2)" syntax function (-.) "(%1 - %2)" syntax function ( *.) "(%1 * %2)" @@ -180,7 +186,6 @@ theory real.RealInfix syntax predicate (<.) "(%1 < %2)" syntax predicate (>=.) "(%1 >= %2)" syntax predicate (>.) 
"(%1 > %2)" - end theory Bool diff --git a/config/drivers/saver.drv b/config/drivers/saver.drv index 9f66622300d45f270fb59971b39834b465b677da..730f2d938300b3f32a2ec795177a90f3d0f62fa3 100644 --- a/config/drivers/saver.drv +++ b/config/drivers/saver.drv @@ -21,158 +21,3 @@ (**************************************************************************) (* Why3 drivers for SAVER *) - -theory BuiltIn - syntax type int "int" - syntax type real "real" - - syntax predicate (=) "(%1 = %2)" - - meta "eliminate_algebraic" "keep_enums" - meta "eliminate_algebraic" "keep_recs" - -end - -theory int.Int - - prelude "(* this is a prelude for Alt-Ergo integer arithmetic *)" - - syntax function zero "0" - syntax function one "1" - - syntax function (+) "(%1 + %2)" - syntax function (-) "(%1 - %2)" - syntax function (*) "(%1 * %2)" - syntax function (-_) "(-%1)" - - meta "invalid trigger" predicate (<=) - meta "invalid trigger" predicate (<) - meta "invalid trigger" predicate (>=) - meta "invalid trigger" predicate (>) - - syntax predicate (<=) "(%1 <= %2)" - syntax predicate (<) "(%1 < %2)" - syntax predicate (>=) "(%1 >= %2)" - syntax predicate (>) "(%1 > %2)" - - remove prop MulComm.Comm - remove prop MulAssoc.Assoc - remove prop Unit_def_l - remove prop Unit_def_r - remove prop Inv_def_l - remove prop Inv_def_r - remove prop Assoc - remove prop Mul_distr_l - remove prop Mul_distr_r - remove prop Comm - remove prop Unitary - remove prop Refl - remove prop Trans - remove prop Total - remove prop Antisymm - remove prop NonTrivialRing - remove prop CompatOrderAdd - remove prop ZeroLessOne - -end - -theory int.EuclideanDivision - - syntax function div "(%1 / %2)" - syntax function mod "(%1 % %2)" - -end - -theory int.ComputerDivision - - use for_drivers.ComputerOfEuclideanDivision - -end - - -theory real.Real - - prelude "(* this is a prelude for Alt-Ergo real arithmetic *)" - - syntax function zero "0.0" - syntax function one "1.0" - - syntax function (+) "(%1 + %2)" - syntax function (-) "(%1 - %2)" - syntax function (*) "(%1 * %2)" - syntax function (/) "(%1 / %2)" - syntax function (-_) "(-%1)" - syntax function inv "(1.0 / %1)" - - meta "invalid trigger" predicate (<=) - meta "invalid trigger" predicate (<) - meta "invalid trigger" predicate (>=) - meta "invalid trigger" predicate (>) - - syntax predicate (<=) "(%1 <= %2)" - syntax predicate (<) "(%1 < %2)" - syntax predicate (>=) "(%1 >= %2)" - syntax predicate (>) "(%1 > %2)" - - remove prop MulComm.Comm - remove prop MulAssoc.Assoc - remove prop Unit_def_l - remove prop Unit_def_r - remove prop Inv_def_l - remove prop Inv_def_r - remove prop Assoc - remove prop Mul_distr_l - remove prop Mul_distr_r - remove prop Comm - remove prop Unitary - remove prop Refl - remove prop Trans - remove prop Total - remove prop Antisymm - remove prop Inverse - remove prop NonTrivialRing - remove prop CompatOrderAdd - remove prop ZeroLessOne - -end - -theory ieee_float.Float64 - - syntax function (.+) "(%1 + %2)" - syntax function (.-) "(%1 - %2)" - syntax function (.*) "(%1 * %2)" - syntax function (./) "(%1 / %2)" - syntax function (.-_) "(-%1)" - - syntax predicate le "%1 <= %2" - syntax predicate lt "%1 < %2" - syntax predicate ge "%1 >= %2" - syntax predicate gt "%1 > %2" - - -end - -theory real.RealInfix - - syntax function (+.) "(%1 + %2)" - syntax function (-.) "(%1 - %2)" - syntax function ( *.) "(%1 * %2)" - syntax function (/.) "(%1 / %2)" - syntax function (-._) "(-%1)" - - meta "invalid trigger" predicate (<=.) - meta "invalid trigger" predicate (<.) 
- meta "invalid trigger" predicate (>=.) - meta "invalid trigger" predicate (>.) - - syntax predicate (<=.) "(%1 <= %2)" - syntax predicate (<.) "(%1 < %2)" - syntax predicate (>=.) "(%1 >= %2)" - syntax predicate (>.) "(%1 > %2)" - -end - -theory ieee_float.Float64 - syntax predicate is_not_nan "" - remove allprops -end diff --git a/config/drivers/vnnlib.gen b/config/drivers/vnnlib.gen index 622e60a778b00dac53b72409ede34d62277390e0..839e7c0b5f1c82667b567460056f96a07571a9b3 100644 --- a/config/drivers/vnnlib.gen +++ b/config/drivers/vnnlib.gen @@ -32,6 +32,7 @@ transformation "inline_trivial" transformation "introduce_premises" transformation "eliminate_builtin" transformation "simplify_formula" +transformation "simplify_rel" transformation "vars_on_lhs" theory BuiltIn @@ -97,7 +98,24 @@ theory Bool meta "encoding:kept" type bool end +theory ieee_float.RoundingMode + syntax type mode "RoundingMode" + syntax function RNE "RNE" + syntax function RNA "RNA" + syntax function RTP "RTP" + syntax function RTN "RTN" + syntax function RTZ "RTZ" + + syntax predicate to_nearest "(or (= %1 RNE) (= %1 RNA))" +end + theory ieee_float.Float64 + syntax function add "(%2 + %3)" + syntax function sub "(%2 - %3)" + syntax function mul "(%2 * %3)" + syntax function div "(%2 / %3)" + syntax function neg "(-%1)" + syntax function (.+) "(%1 + %2)" syntax function (.-) "(%1 - %2)" syntax function (.*) "(%1 * %2)" diff --git a/src/interpretation.ml b/src/interpretation.ml index eb1cdc00bf9db887ab74764ad998bcd5b912beca..aa2b97ff3bd22e629bf5cdb9576d972664fa88d4 100644 --- a/src/interpretation.ml +++ b/src/interpretation.ml @@ -20,80 +20,348 @@ (* *) (**************************************************************************) -module CRE = Reduction_engine (* Caisar Reduction Engine *) +module CRE = Reduction_engine (* CAISAR Reduction Engine *) open Why3 open Base -type dataset = { dataset : string } [@@deriving show] -type caisar_op = Dataset of dataset [@@deriving show] +type nn = + | NNet of Term.lsymbol + [@printer + fun fmt nn -> + Fmt.pf fmt "NNet: %a" + Fmt.(option Language.pp_nn) + (Language.lookup_nn nn)] + | ONNX of Term.lsymbol + [@printer + fun fmt nn -> + Fmt.pf fmt "ONNX: %a" + Fmt.(option Language.pp_nn) + (Language.lookup_nn nn)] +[@@deriving show] + +type dataset = DS_csv of Csv.t [@printer fun fmt _ -> Fmt.pf fmt "<csv>"] +[@@deriving show] + +type data = D_csv of string list [@@deriving show] + +type vector = + (Term.lsymbol + [@printer + fun fmt v -> + Fmt.pf fmt "%a" Fmt.(option ~none:nop int) (Language.lookup_vector v)]) +[@@deriving show] + +type caisar_op = + | NeuralNetwork of nn + | Dataset of dataset + | Data of data + | Vector of vector +[@@deriving show] type caisar_env = { - dataset_ty : Ty.ty; caisar_op_of_ls : caisar_op Term.Hls.t; ls_of_caisar_op : (caisar_op, Term.lsymbol) Hashtbl.t; + env : Env.env; cwd : string; } -let ls_of_caisar_op engine op = +let ls_of_caisar_op engine caisar_op ty_args ty = let caisar_env = CRE.user_env engine in - Hashtbl.find_or_add caisar_env.ls_of_caisar_op op ~default:(fun () -> + Hashtbl.find_or_add caisar_env.ls_of_caisar_op caisar_op ~default:(fun () -> let id = Ident.id_fresh "caisar_op" in - let ty = match op with Dataset _ -> caisar_env.dataset_ty in - let ls = Term.create_fsymbol id [] ty in - Hashtbl.Poly.add_exn caisar_env.ls_of_caisar_op ~key:op ~data:ls; - Term.Hls.add caisar_env.caisar_op_of_ls ls op; + let ls = + match caisar_op with + | NeuralNetwork (NNet n | ONNX n) -> n + | Vector v -> v + | _ -> Term.create_lsymbol id ty_args ty + in + 
Hashtbl.Poly.add_exn caisar_env.ls_of_caisar_op ~key:caisar_op ~data:ls; + Term.Hls.add caisar_env.caisar_op_of_ls ls caisar_op; ls) let caisar_op_of_ls engine ls = let caisar_env = CRE.user_env engine in Term.Hls.find caisar_env.caisar_op_of_ls ls -let term_of_caisar_op engine op = - Term.t_app_infer (ls_of_caisar_op engine op) [] +let term_of_caisar_op ?(args = []) engine caisar_op ty = + let t_args, ty_args = List.unzip args in + Term.t_app_infer (ls_of_caisar_op engine caisar_op ty_args ty) t_args let caisar_env env cwd = - let th = Env.read_theory env [ "caisar" ] "Interpretation" in - let ts_dataset = Theory.ns_find_ts th.Theory.th_export [ "dataset" ] in { - dataset_ty = Ty.ty_app ts_dataset []; ls_of_caisar_op = Hashtbl.Poly.create (); caisar_op_of_ls = Term.Hls.create 10; + env; cwd; } -let print_caisar_op fmt caisar_env = +let print_caisar_op_of_ls fmt caisar_env = Pp.print_iter2 Term.Hls.iter Pp.newline Pp.comma Pretty.print_ls pp_caisar_op fmt caisar_env.caisar_op_of_ls + [@@warning "-32"] + +let const_real_of_float value = + let neg = Float.is_negative value in + let value = Fmt.str "%.64f" (Float.abs value) in + (* Split into integer and fractional parts. *) + let int_frac = String.split value ~on:'.' in + let int = List.hd_exn int_frac in + let frac = + match List.tl_exn int_frac with + | [] -> "" + | [ s ] -> s + | _ -> assert false (* Since every float has one '.' at most. *) + in + Constant.ConstReal (Number.real_literal ~radix:10 ~neg ~int ~frac ~exp:None) -let compute_size_of_dataset ~cwd s = - let d = Caml.Filename.concat cwd s in - Array.length (Caml.Sys.readdir d) - -let builtin_caisar : caisar_env CRE.built_in_theories list = - let open_dataset : _ CRE.builtin = - fun engine _ l _ -> - match l with - | [ Term { t_node = Tconst (ConstStr dataset); _ } ] -> - Term (term_of_caisar_op engine (Dataset { dataset })) - | _ -> invalid_arg "We want a string! ;)" +let value_term t = CRE.Value (Term t) +let value_int i = CRE.Value (Int i) + +let caisar_builtins : caisar_env CRE.built_in_theories list = + let reconstruct () = + (* Force the engine to reconstruct the original term. 
*) + raise Caml.Not_found + in + let error_message ls = + Fmt.str "Invalid arguments for '%a'" Pretty.print_ls ls + in + (* Vector *) + let vget : _ CRE.builtin = + fun engine ls vl ty -> + match vl with + | [ + Term ({ t_node = Tapp (ls1, tl1); _ } as _t1); + Term ({ t_node = Tconst (ConstInt i); _ } as _t2); + ] -> ( + let i = Number.to_small_integer i in + match caisar_op_of_ls engine ls1 with + | Dataset (DS_csv csv) -> + let row = List.nth_exn csv i in + let label, features = + match row with + | [] | [ _ ] -> assert false + | label :: features -> (label, features) + in + let ty_features = + match ty with + | Some { ty_node = Tyapp (_, [ ty; _ ]); _ } -> Some ty + | _ -> assert false + in + let t_features, t_label = + ( term_of_caisar_op engine (Data (D_csv features)) ty_features, + Term.t_int_const (BigInt.of_int (Int.of_string label)) ) + in + value_term (Term.t_tuple [ t_features; t_label ]) + | Vector v -> + let n = Option.value_exn (Language.lookup_vector v) in + assert (List.length tl1 = n && i <= n); + value_term (List.nth_exn tl1 i) + | Data _ | NeuralNetwork _ -> assert false) + | [ Term _t1; Term _t2 ] -> reconstruct () + | _ -> invalid_arg (error_message ls) in - let size : _ CRE.builtin = - fun engine _ l _ -> - match l with + let vlength : _ CRE.builtin = + fun engine ls vl _ty -> + match vl with | [ Term { t_node = Tapp (ls, []); _ } ] -> ( match caisar_op_of_ls engine ls with - | Dataset { dataset } -> - let cwd = (CRE.user_env engine).cwd in - Int (BigInt.of_int (compute_size_of_dataset ~cwd dataset))) - | _ -> invalid_arg "We want a string! ;)" + | Dataset (DS_csv csv) -> value_int (BigInt.of_int (Csv.lines csv)) + | Vector v -> + value_int (BigInt.of_int (Option.value_exn (Language.lookup_vector v))) + | Data (D_csv data) -> value_int (BigInt.of_int (List.length data)) + | NeuralNetwork _ -> assert false) + | [ Term { t_node = Tapp (ls, tl); _ } ] -> ( + match caisar_op_of_ls engine ls with + | Vector v -> + let n = Option.value_exn (Language.lookup_vector v) in + assert (List.length tl = n); + value_int (BigInt.of_int n) + | Dataset _ | Data _ | NeuralNetwork _ -> assert false) + | [ Term _t ] -> reconstruct () + | _ -> invalid_arg (error_message ls) + in + let vminus : _ CRE.builtin = + fun engine ls vl ty -> + match vl with + | [ + Term ({ t_node = Tapp (ls1, tl1); _ } as _t1); + Term ({ t_node = Tapp (ls2, _); _ } as _t2); + ] -> ( + match (caisar_op_of_ls engine ls1, caisar_op_of_ls engine ls2) with + | Vector v, Data (D_csv data) -> + let n = Option.value_exn (Language.lookup_vector v) in + assert (n = List.length data); + let ty_cst = + match ty with + | Some { ty_node = Tyapp (_, [ ty ]); _ } -> ty + | _ -> assert false + in + let csts = + List.map data ~f:(fun d -> + let cst = const_real_of_float (Float.of_string d) in + Term.t_const cst ty_cst) + in + let { env; _ } = CRE.user_env engine in + let args = + let minus = + (* TODO: generalize wrt the type of constants [csts]. 
*) + let th = Env.read_theory env [ "ieee_float" ] "Float64" in + Theory.(ns_find_ls th.th_export [ Ident.op_infix ".-" ]) + in + List.map2_exn tl1 csts ~f:(fun tl c -> + (Term.t_app_infer minus [ tl; c ], ty_cst)) + in + let caisar_op = Vector (Language.create_vector env n) in + value_term (term_of_caisar_op ~args engine caisar_op ty) + | _ -> assert false) + | [ Term _t1; Term _t2 ] -> reconstruct () + | _ -> invalid_arg (error_message ls) + in + let vmapi : _ CRE.builtin = + fun engine ls vl ty -> + match vl with + | [ + Term ({ t_node = Tapp (ls1, tl1); _ } as _t1); + Term ({ t_node = Teps _tb; _ } as t2); + ] -> ( + assert (Term.t_is_lambda t2); + match caisar_op_of_ls engine ls1 with + | Vector v -> + let n = Option.value_exn (Language.lookup_vector v) in + assert (List.length tl1 = n); + let args = + List.mapi tl1 ~f:(fun idx t -> + let idx = Term.t_int_const (BigInt.of_int idx) in + (Term.t_func_app_beta_l t2 [ idx; t ], Option.value_exn t.t_ty)) + in + let caisar_op = + let { env; _ } = CRE.user_env engine in + Vector (Language.create_vector env n) + in + Eval (term_of_caisar_op ~args engine caisar_op ty) + | Dataset (DS_csv csv) -> value_int (BigInt.of_int (Csv.lines csv)) + | Data _ | NeuralNetwork _ -> assert false) + | [ Term _t1; Term _t2 ] -> reconstruct () + | _ -> invalid_arg (error_message ls) in + + (* Neural Network *) + let read_neural_network : _ CRE.builtin = + fun engine ls vl ty -> + match vl with + | [ + Term { t_node = Tconst (ConstStr neural_network); _ }; + Term { t_node = Tapp ({ ls_name = { id_string; _ }; _ }, []); _ }; + ] -> + let { env; cwd; _ } = CRE.user_env engine in + let caisar_op = + let filename = Caml.Filename.concat cwd neural_network in + let nn = + match id_string with + | "NNet" -> NNet (Language.create_nnet_nn env filename) + | "ONNX" -> ONNX (Language.create_onnx_nn env filename) + | _ -> + failwith (Fmt.str "Unrecognized neural network format %s" id_string) + in + NeuralNetwork nn + in + value_term (term_of_caisar_op engine caisar_op ty) + | [ Term _t1; Term _t2 ] -> reconstruct () + | _ -> invalid_arg (error_message ls) + in + let apply_neural_network : _ CRE.builtin = + fun _engine ls vl _ty -> + match vl with + | [ Term _t1; Term _t2 ] -> reconstruct () + | _ -> invalid_arg (error_message ls) + in + + (* Dataset *) + let read_dataset : _ CRE.builtin = + fun engine ls vl ty -> + match vl with + | [ + Term { t_node = Tconst (ConstStr dataset); _ }; + Term { t_node = Tapp ({ ls_name = { id_string = "CSV"; _ }; _ }, []); _ }; + ] -> + let { cwd; _ } = CRE.user_env engine in + let caisar_op = + let filename = Caml.Filename.concat cwd dataset in + let dataset = DS_csv (Csv.load filename) in + Dataset dataset + in + value_term (term_of_caisar_op engine caisar_op ty) + | [ Term _t1; Term _t2 ] -> reconstruct () + | _ -> invalid_arg (error_message ls) + in + [ - ( [ "caisar" ], - "Interpretation", + ( [ "interpretation" ], + "Vector", [], - [ ("open_dataset", None, open_dataset); ("size", None, size) ] ); + [ + ([ Ident.op_get "" ] (* ([]) *), None, vget); + ([ Ident.op_infix "-" ], None, vminus); + ([ "length" ], None, vlength); + ([ "mapi" ], None, vmapi); + ] ); + ( [ "interpretation" ], + "NeuralNetwork", + [], + [ + ([ "read_neural_network" ], None, read_neural_network); + ([ Ident.op_infix "@@" ], None, apply_neural_network); + ] ); + ( [ "interpretation" ], + "Dataset", + [], + [ ([ "read_dataset" ], None, read_dataset) ] ); ] +let bounded_quant engine vs ~cond : CRE.bounded_quant_result option = + match cond.Term.t_node with + | Tapp + ( 
{ ls_name = { id_string = "has_length"; _ }; _ }, + [ + ({ t_node = Tvar vs1; _ } as _t1); + ({ t_node = Tconst (ConstInt n); _ } as _t2); + ] ) -> + if not (Term.vs_equal vs vs1) + then None + else + let n = Number.to_small_integer n in + let ty = + match vs.vs_ty with + | { ty_node = Tyapp (_, ty :: _); _ } -> ty + | _ -> assert false + in + let new_quant = + List.init n ~f:(fun _ -> + let preid = Ident.id_fresh "caisar_v" in + Term.create_vsymbol preid ty) + in + let args = List.map new_quant ~f:(fun vs -> (Term.t_var vs, ty)) in + let caisar_op = + let { env; _ } = CRE.user_env engine in + Vector (Language.create_vector env n) + in + let substitutions = + [ term_of_caisar_op ~args engine caisar_op (Some vs.vs_ty) ] + in + Some { new_quant; substitutions } + | _ -> None + +let declare_language_lsymbols caisar_env task = + (* Declare [Language] logic symbols for neural networks and vectors only. *) + Term.Hls.fold + (fun ls _ task -> + if Language.mem_vector ls || Language.mem_nn ls + then + let decl = Decl.create_param_decl ls in + Task.add_decl task decl + else task) + caisar_env.caisar_op_of_ls task + let interpret_task ~cwd env task = let known_map = Task.task_known task in let caisar_env = caisar_env env cwd in @@ -105,24 +373,12 @@ let interpret_task ~cwd env task = compute_max_quantifier_domain = Int.max_value; } in - let engine = CRE.create params env known_map caisar_env builtin_caisar in - let f = Task.task_goal_fmla task in - let f = CRE.normalize ~limit:1000 engine Term.Mvs.empty f in - Fmt.pr "%a : %a@.%a@." Pretty.print_pr (Task.task_goal task) Pretty.print_term - f print_caisar_op caisar_env - -let interpret ?debug ?format ~loadpath file = - let cwd = - match file with - | Verification.File.Stdin -> Unix.getcwd () - | File s -> Caml.Filename.dirname s - | JSON _ -> Unix.getcwd () (*todo *) - in - let env, _config, mstr_theory = - Verification.open_file ?debug ?format ~loadpath file + let engine = + CRE.create ~bounded_quant params env known_map caisar_env caisar_builtins in - Wstdlib.Mstr.iter - (fun _ theory -> - List.iter (Task.split_theory theory None None) ~f:(fun task -> - interpret_task ~cwd env task)) - mstr_theory + let g, f = (Task.task_goal task, Task.task_goal_fmla task) in + let f = CRE.normalize ~limit:Int.max_value engine Term.Mvs.empty f in + let _, task = Task.task_separate_goal task in + let task = declare_language_lsymbols caisar_env task in + let task = Task.(add_prop_decl task Pgoal g f) in + task diff --git a/src/interpretation.mli b/src/interpretation.mli index 68649f274c7cb2c8bd730daad025bff68d8fb658..6252724c03c5cc06c5991f8837e27ae85ef02ccd 100644 --- a/src/interpretation.mli +++ b/src/interpretation.mli @@ -20,11 +20,6 @@ (* *) (**************************************************************************) -open Base +open Why3 -val interpret : - ?debug:bool -> - ?format:string -> - loadpath:string list -> - Verification.File.t -> - unit +val interpret_task : cwd:string -> Env.env -> Task.task -> Task.task diff --git a/src/language.ml b/src/language.ml index 88d81e50d98bdeeac4cb7c73a2a16dcde055753e..d525b1f8602c0e17e9327574c3a16d7e0fc244e3 100644 --- a/src/language.ml +++ b/src/language.ml @@ -155,3 +155,109 @@ let register_onnx_support () = let register_ovo_support () = Env.register_format ~desc:"OVO format" Pmodule.mlw_language "OVO" [ "ovo" ] (fun env _ filename _ -> ovo_parser env filename) + +(* -- Vector *) + +let vectors = Term.Hls.create 10 + +let vector_elt_ty env = + let th = Env.read_theory env [ "ieee_float" ] "Float64" in + Ty.ty_app 
(Theory.ns_find_ts th.th_export [ "t" ]) [] + +let create_vector = + Env.Wenv.memoize 13 (fun env -> + let h = Hashtbl.create (module Int) in + let ty_elt = vector_elt_ty env in + let ty = + let th = Env.read_theory env [ "interpretation" ] "Vector" in + Ty.ty_app (Theory.ns_find_ts th.th_export [ "vector" ]) [ ty_elt ] + in + Hashtbl.findi_or_add h ~default:(fun length -> + let ls = + let id = Ident.id_fresh "vector" in + Term.create_fsymbol id (List.init length ~f:(fun _ -> ty_elt)) ty + in + Term.Hls.add vectors ls length; + ls)) + +let lookup_vector = Term.Hls.find_opt vectors +let mem_vector = Term.Hls.mem vectors + +(* -- Classifier *) + +type nn = { + nn_inputs : int; + nn_outputs : int; + nn_ty_elt : Ty.ty; [@printer fun fmt ty -> Fmt.pf fmt "%a" Pretty.print_ty ty] + nn_filename : string; + nn_nier : Onnx.G.t option; [@opaque] +} +[@@deriving show] + +let nets = Term.Hls.create 10 + +let fresh_nn_ls env name = + let ty = + let th = Env.read_theory env [ "interpretation" ] "NeuralNetwork" in + Ty.ty_app (Theory.ns_find_ts th.th_export [ "nn" ]) [] + in + let id = Ident.id_fresh name in + Term.create_fsymbol id [] ty + +let create_nnet_nn = + Env.Wenv.memoize 13 (fun env -> + let h = Hashtbl.create (module String) in + let ty_elt = + let th = Env.read_theory env [ "ieee_float" ] "Float64" in + Ty.ty_app (Theory.ns_find_ts th.th_export [ "t" ]) [] + in + Hashtbl.findi_or_add h ~default:(fun filename -> + let ls = fresh_nn_ls env "nnet_nn" in + let nn = + let model = Nnet.parse ~permissive:true filename in + match model with + | Error s -> Loc.errorm "%s" s + | Ok { n_inputs; n_outputs; _ } -> + { + nn_inputs = n_inputs; + nn_outputs = n_outputs; + nn_ty_elt = ty_elt; + nn_filename = filename; + nn_nier = None; + } + in + Term.Hls.add nets ls nn; + ls)) + +let create_onnx_nn = + Env.Wenv.memoize 13 (fun env -> + let h = Hashtbl.create (module String) in + let ty_elt = vector_elt_ty env in + Hashtbl.findi_or_add h ~default:(fun filename -> + let ls = fresh_nn_ls env "onnx_nn" in + let onnx = + let model = Onnx.parse filename in + match model with + | Error s -> Loc.errorm "%s" s + | Ok { n_inputs; n_outputs; nier } -> + let nier = + match nier with + | Error msg -> + Logs.warn (fun m -> + m "Cannot build network intermediate representation:@ %s" msg); + None + | Ok nier -> Some nier + in + { + nn_inputs = n_inputs; + nn_outputs = n_outputs; + nn_ty_elt = ty_elt; + nn_filename = filename; + nn_nier = nier; + } + in + Term.Hls.add nets ls onnx; + ls)) + +let lookup_nn = Term.Hls.find_opt nets +let mem_nn = Term.Hls.mem nets diff --git a/src/language.mli b/src/language.mli index eb025a8a9322cfe5d50dc6dbf8663a92956fac8a..380465783b7c74eef609ec0f715de98c68dad43c 100644 --- a/src/language.mli +++ b/src/language.mli @@ -62,3 +62,25 @@ val onnx_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t val ovo_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t (* [ovo_parser env filename] parses and creates the theories corresponding to the given ovo [filename]. The result is memoized. 
*) + +(** -- Vector *) + +val create_vector : Env.env -> int -> Term.lsymbol +val lookup_vector : Term.lsymbol -> int option +val mem_vector : Term.lsymbol -> bool + +(** -- Neural Network *) + +type nn = private { + nn_inputs : int; + nn_outputs : int; + nn_ty_elt : Ty.ty; + nn_filename : string; + nn_nier : Onnx.G.t option; +} +[@@deriving show] + +val create_nnet_nn : Env.env -> string -> Term.lsymbol +val create_onnx_nn : Env.env -> string -> Term.lsymbol +val lookup_nn : Term.lsymbol -> nn option +val mem_nn : Term.lsymbol -> bool diff --git a/src/main.ml b/src/main.ml index 8dbb9e47281face73e4237a5d9c0b13fcb9e3c5e..7498e10547a2fbc93ed902020d09bbdadc09a669 100644 --- a/src/main.ml +++ b/src/main.ml @@ -25,13 +25,15 @@ open Cmdliner let caisar = "caisar" +let () = + Simplify_rel.init (); + Vars_on_lhs.init () + let () = Pyrat.init (); Marabou.init (); Vnnlib.init () -let () = Vars_on_lhs.init () - (* -- Logs. *) let pp_header = @@ -202,10 +204,6 @@ let verify_json ?memlimit ?timelimit ?outfile json = ~f:(record_verification_result jin.id verification_result) | _ -> failwith "Unexpected more than one theory from a JSON file" -let interpret format loadpath files = - let debug = log_level_is_debug () in - List.iter ~f:(Interpretation.interpret ~debug ?format ~loadpath) files - let verify_xgboost ?memlimit ?timelimit xgboost dataset prover = let memlimit = Option.map memlimit ~f:memlimit_of_string in let timelimit = Option.map timelimit ~f:timelimit_of_string in @@ -362,24 +360,6 @@ let verify_json_cmd = in Cmd.v info term -let interpret_cmd = - let cmdname = "interpret" in - let doc = - "Interpret the goal and print the strategy that will be executed." - in - let info = - Cmd.info cmdname ~sdocs:Manpage.s_common_options ~exits:Cmd.Exit.defaults - ~doc - ~man:[ `S Manpage.s_description; `P doc ] - in - let term = - Term.( - const (fun format loadpath files _ -> - exec_cmd cmdname (fun () -> interpret format loadpath files)) - $ format $ loadpath $ files $ setup_logs) - in - Cmd.v info term - let verify_xgboost_cmd = let cmdname = "verify-xgboost" in let info = @@ -450,18 +430,13 @@ let () = match exn with | Invalid_argument msg -> Fmt.pf fmt "Invalid argument: %s" msg | Failure msg -> Fmt.pf fmt "Failure: %s" msg + | Sys_error msg -> Fmt.pf fmt "%s" msg | _ -> raise exn) let () = try Cmd.group ~default:default_cmd default_info - [ - config_cmd; - verify_cmd; - verify_json_cmd; - interpret_cmd; - verify_xgboost_cmd; - ] + [ config_cmd; verify_cmd; verify_json_cmd; verify_xgboost_cmd ] |> Cmd.eval ~catch:false |> Caml.exit with exn when not (log_level_is_debug ()) -> Logs.err (fun m -> m "@[%a@]" Why3.Exn_printer.exn_printer exn) diff --git a/src/printers/marabou.ml b/src/printers/marabou.ml index 9b70d2e59f94afc1ec7fe22bddd20117458f0293..823782a481186b8ed05bad147c565068ef0cc091 100644 --- a/src/printers/marabou.ml +++ b/src/printers/marabou.ml @@ -76,10 +76,14 @@ let rec print_term info fmt t = with | Some s1, Some s2 -> if Term.ls_equal ls info.ls_rel_float.le + || Term.ls_equal ls info.ls_rel_float.lt || Term.ls_equal ls info.ls_rel_real.le + || Term.ls_equal ls info.ls_rel_real.lt then Fmt.pf fmt "+%s -%s <= 0" s1 s2 else if Term.ls_equal ls info.ls_rel_float.ge + || Term.ls_equal ls info.ls_rel_float.gt || Term.ls_equal ls info.ls_rel_real.ge + || Term.ls_equal ls info.ls_rel_real.gt then Fmt.pf fmt "+%s -%s >= 0" s1 s2 else Printer.unsupportedTerm t "Marabou: unknown relational operator" | _ -> Printer.unsupportedTerm t "Marabou: unknown variable(s)") @@ -116,6 +120,7 @@ let rec 
negate_term info t = (* Assumption: conjunctions have been split beforehand, hence cannot appear at this stage. *) match t.Term.t_node with + | Tnot t -> t | Tbinop (Tor, t1, t2) -> Term.t_and (negate_term info t1) (negate_term info t2) | Tapp (ls, [ t1; t2 ]) -> diff --git a/src/printers/pyrat.ml b/src/printers/pyrat.ml index 2b2d6aca3c473084cc6da62b56499231f93a4bfd..fd6cf22dfc35a8c8b8f3b1c3762ab362dfacc9f7 100644 --- a/src/printers/pyrat.ml +++ b/src/printers/pyrat.ml @@ -22,7 +22,16 @@ open Why3 +type relops = { + le : Term.lsymbol; + ge : Term.lsymbol; + lt : Term.lsymbol; + gt : Term.lsymbol; +} + type info = { + ls_rel_real : relops; + ls_rel_float : relops; info_syn : Printer.syntax_map; variables : string Term.Hls.t; } @@ -75,9 +84,44 @@ let rec print_premise_term info fmt t = t2 | _ -> if t_is_known info t then Fmt.pf fmt "%a@." (print_base_term info) t +let rec negate_term info t = + match t.Term.t_node with + | Tnot t -> t + | Tbinop (Tand, t1, t2) -> + Term.t_or (negate_term info t1) (negate_term info t2) + | Tbinop (Tor, t1, t2) -> + Term.t_and (negate_term info t1) (negate_term info t2) + | Tapp (ls, [ t1; t2 ]) -> + let tt = [ t1; t2 ] in + (* Negate float relational symbols. *) + let ls_neg = + if Term.ls_equal ls info.ls_rel_float.le + || Term.ls_equal ls info.ls_rel_float.lt + then info.ls_rel_float.ge + else if Term.ls_equal ls info.ls_rel_float.ge + || Term.ls_equal ls info.ls_rel_float.gt + then info.ls_rel_float.le + else ls + in + (* Negate real relational symbols. *) + let ls_neg = + if Term.ls_equal ls info.ls_rel_real.le + || Term.ls_equal ls info.ls_rel_real.lt + then info.ls_rel_real.ge + else if Term.ls_equal ls info.ls_rel_real.ge + || Term.ls_equal ls info.ls_rel_real.gt + then info.ls_rel_real.le + else ls_neg + in + if Term.ls_equal ls_neg ls + then Printer.unsupportedTerm t "PyRAT: cannot negate term" + else Term.ps_app ls_neg tt + | _ -> Printer.unsupportedTerm t "PyRAT: cannot negate term" + let rec print_goal_term info fmt t = match t.Term.t_node with | Tquant _ -> () + | Tnot t -> print_goal_term info fmt (negate_term info t) | Tbinop (((Tand | Tor) as lop), t1, t2) -> if t_is_known info t1 && t_is_known info t2 then @@ -89,7 +133,7 @@ let rec print_goal_term info fmt t = in Fmt.pf fmt "%s%a %s %a%s" psx (print_goal_term info) t1 lop (print_goal_term info) t2 pdx - | _ -> if t_is_known info t then Fmt.pf fmt "%a" (print_base_term info) t + | _ -> if t_is_known info t then Fmt.pf fmt "%a@." 
(print_base_term info) t let print_decl info fmt d = match d.Decl.d_node with @@ -123,8 +167,26 @@ let rec print_tdecl info fmt task = | Decl d -> print_decl info fmt d) let print_task args ?old:_ fmt task = + let ls_rel_real = + let th = Env.read_theory args.Printer.env [ "real" ] "Real" in + let le = Theory.ns_find_ls th.th_export [ Ident.op_infix "<=" ] in + let lt = Theory.ns_find_ls th.th_export [ Ident.op_infix "<" ] in + let ge = Theory.ns_find_ls th.th_export [ Ident.op_infix ">=" ] in + let gt = Theory.ns_find_ls th.th_export [ Ident.op_infix ">" ] in + { le; ge; lt; gt } + in + let ls_rel_float = + let th = Env.read_theory args.Printer.env [ "ieee_float" ] "Float64" in + let le = Theory.ns_find_ls th.th_export [ "le" ] in + let lt = Theory.ns_find_ls th.th_export [ "lt" ] in + let ge = Theory.ns_find_ls th.th_export [ "ge" ] in + let gt = Theory.ns_find_ls th.th_export [ "gt" ] in + { le; ge; lt; gt } + in let info = { + ls_rel_real; + ls_rel_float; info_syn = Discriminate.get_syntax_map task; variables = Term.Hls.create 10; } diff --git a/src/printers/vnnlib.ml b/src/printers/vnnlib.ml index fe451527ff2cc1b6becc37497b9a7e329d400686..29758ee6ce16a920db6930c6c3e6fd5092e28258 100644 --- a/src/printers/vnnlib.ml +++ b/src/printers/vnnlib.ml @@ -435,6 +435,7 @@ let print_goal info fmt (pr, t) = let rec negate_term info t = match t.Term.t_node with + | Tnot t -> t | Tbinop (Tand, t1, t2) -> Term.t_or (negate_term info t1) (negate_term info t2) | Tbinop (Tor, t1, t2) -> diff --git a/src/proof_strategy.ml b/src/proof_strategy.ml index 041bee125873a94a975107134cd04203e1826139..ee690421141a1a2f910ef1772862bd122737ece2 100644 --- a/src/proof_strategy.ml +++ b/src/proof_strategy.ml @@ -20,19 +20,44 @@ (* *) (**************************************************************************) +open Base open Why3 -let do_apply_prover trans task = - let nb = Trans.apply Utils.count_nn_apply task in - match nb with - | 0 -> task - | 1 -> Trans.apply trans task +let set_of_nn_ls ~lookup sls = + let rec aux acc term = + let acc = Term.t_fold aux acc term in + match term.t_node with + | Term.Tapp (ls, _) -> ( + match lookup ls with None -> acc | Some _ -> Term.Sls.add ls acc) + | _ -> acc + in + Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) sls + +let do_apply_prover ~lookup ~trans tasks = + let set_nn_ls = + List.fold tasks ~init:Term.Sls.empty ~f:(fun accum task -> + Trans.apply (set_of_nn_ls ~lookup accum) task) + in + let count_nn_ls = Term.Sls.cardinal set_nn_ls in + match count_nn_ls with + | 0 -> tasks + | 1 -> List.map tasks ~f:(Trans.apply trans) | _ -> invalid_arg "Two or more neural network applications are not supported yet" -let apply_classic_prover env task = do_apply_prover (Nn2smt.trans env) task +let apply_classic_prover env task = + let lookup = Language.lookup_loaded_nets in + let trans = Nn2smt.trans env in + do_apply_prover ~lookup ~trans [ task ] let apply_native_nn_prover env task = - do_apply_prover - (Trans.seq [ Introduction.introduce_premises; Native_nn_prover.trans env ]) - task + let lookup = Language.lookup_nn in + let trans = + Trans.seq + [ + Introduction.introduce_premises; + Native_nn_prover.trans_nn_application env; + ] + in + let tasks = Trans.apply Split_goal.split_goal_full task in + do_apply_prover ~lookup ~trans tasks diff --git a/src/proof_strategy.mli b/src/proof_strategy.mli index fb35660055f801966676cb3e5805d0a2a9ef94c0..1ec1f823061f8bf766216d777309cc68601a5888 100644 --- a/src/proof_strategy.mli +++ b/src/proof_strategy.mli @@ -22,8 +22,8 @@ 
open Why3 -val apply_classic_prover : Env.env -> Task.task -> Task.task +val apply_classic_prover : Env.env -> Task.task -> Task.task list (** Detect and translate applications of neural networks into SMT-LIB. *) -val apply_native_nn_prover : Env.env -> Task.task -> Task.task +val apply_native_nn_prover : Env.env -> Task.task -> Task.task list (** Detect and execute applications of neural networks. *) diff --git a/src/reduction_engine.ml b/src/reduction_engine.ml index cba302c08b0f0147e61d47b1772113209aca8192..9ba8d51b83cadd3fb376d257b2694120758e0776 100644 --- a/src/reduction_engine.ml +++ b/src/reduction_engine.ml @@ -19,9 +19,12 @@ let debug = Debug.register_info_flag ~desc:"" "Reduction_engine" (* {2 Values} *) type value = - | Term of term (* invariant: is in normal form *) - | Int of BigInt.t - | Real of Number.real_value + | Term of term + (* invariant: is in normal form *) + [@printer fun fmt t -> Fmt.pf fmt "%a" Pretty.print_term t] + | Int of BigInt.t [@printer fun fmt t -> Fmt.pf fmt "%i" (BigInt.to_int t)] + | Real of Number.real_value [@printer fun fmt _ -> Fmt.pf fmt "<real>"] +[@@deriving show] (** {2 Environment} *) @@ -34,7 +37,22 @@ type params = { compute_max_quantifier_domain : int; } -type 'a builtin = 'a engine -> lsymbol -> value list -> Ty.ty option -> value +type bounded_quant_result = { + new_quant : vsymbol list; + substitutions : term list; +} + +type builtin_value = + | Eval of Why3.Term.term + [@printer fun fmt t -> Fmt.pf fmt "%a" Pretty.print_term t] + | Value of value +[@@deriving show] + +type 'a builtin = + 'a engine -> lsymbol -> value list -> Ty.ty option -> builtin_value + +and 'a bounded_quant = + 'a engine -> vsymbol -> cond:term -> bounded_quant_result option and 'a engine = { env : Env.env; @@ -44,6 +62,7 @@ and 'a engine = { ls_lt : lsymbol; (* The lsymbol for [int.Int.(<)] *) user_env : 'a; builtins : 'a builtin Hls.t; + bounded_quant : 'a bounded_quant; } let user_env x = x.user_env @@ -107,41 +126,50 @@ let eval_int_op op simpl _ ls l ty = try let n1 = big_int_of_value t1 in let n2 = big_int_of_value t2 in - Int (op n1 n2) + Value (Int (op n1 n2)) with NotNum | Division_by_zero -> simpl ls t1 t2 ty) | _ -> assert false let simpl_add _ls t1 t2 _ty = - if is_zero t1 then t2 else if is_zero t2 then t1 else raise Undetermined + if is_zero t1 + then Value t2 + else if is_zero t2 + then Value t1 + else raise Undetermined -let simpl_sub _ls t1 t2 _ty = if is_zero t2 then t1 else raise Undetermined +let simpl_sub _ls t1 t2 _ty = + if is_zero t2 then Value t1 else raise Undetermined let simpl_mul _ls t1 t2 _ty = if is_zero t1 - then t1 + then Value t1 else if is_zero t2 - then t2 + then Value t2 else if is_one t1 - then t2 + then Value t2 else if is_one t2 - then t1 + then Value t1 else raise Undetermined let simpl_div _ls t1 t2 _ty = if is_zero t2 then raise Undetermined; - if is_zero t1 then t1 else if is_one t2 then t1 else raise Undetermined + if is_zero t1 + then Value t1 + else if is_one t2 + then Value t1 + else raise Undetermined let simpl_mod _ls t1 t2 _ty = if is_zero t2 then raise Undetermined; if is_zero t1 - then t1 + then Value t1 else if is_one t2 - then Int BigInt.zero + then Value (Int BigInt.zero) else raise Undetermined let simpl_minmax _ls v1 v2 _ty = match (v1, v2) with - | Term t1, Term t2 -> if t_equal t1 t2 then v1 else raise Undetermined + | Term t1, Term t2 -> if t_equal t1 t2 then Value v1 else raise Undetermined | _ -> raise Undetermined let eval_int_rel op _ _ls l _ty = @@ -150,7 +178,7 @@ let eval_int_rel op _ _ls l _ty = 
try let n1 = big_int_of_value t1 in let n2 = big_int_of_value t2 in - Term (to_bool (op n1 n2)) + Value (Term (to_bool (op n1 n2))) with NotNum | Division_by_zero -> raise Undetermined (* t_app_value ls l ty *)) | _ -> assert false @@ -160,7 +188,7 @@ let eval_int_uop op _ _ls l _ty = | [ t1 ] -> ( try let n1 = big_int_of_value t1 in - Int (op n1) + Value (Int (op n1)) with NotNum | Division_by_zero -> raise Undetermined (* t_app_value ls l ty *)) | _ -> assert false @@ -178,7 +206,7 @@ let eval_real_op op simpl _ ls l ty = try let n1 = real_of_value t1 in let n2 = real_of_value t2 in - Real (op n1 n2) + Value (Real (op n1 n2)) with NotNum -> simpl ls t1 t2 ty) | _ -> assert false @@ -187,7 +215,7 @@ let eval_real_uop op _ _ls l _ty = | [ t1 ] -> ( try let n1 = real_of_value t1 in - Real (op n1) + Value (Real (op n1)) with NotNum -> raise Undetermined) | _ -> assert false @@ -210,7 +238,7 @@ let eval_real_rel op _ _ls l _ty = let s2 = real_align ~pow2 ~pow5 n2 in (s1, s2) in - Term (to_bool (op s1 s2)) + Value (Term (to_bool (op s1 s2))) with NotNum -> raise Undetermined (* t_app_value ls l ty *)) | _ -> assert false @@ -266,23 +294,23 @@ let real_mul r1 r2 = let simpl_real_add _ls t1 t2 _ty = if is_real t1 real_0 - then t2 + then Value t2 else if is_real t2 real_0 - then t1 + then Value t1 else raise Undetermined let simpl_real_sub _ls t1 t2 _ty = - if is_real t2 real_0 then t1 else raise Undetermined + if is_real t2 real_0 then Value t1 else raise Undetermined let simpl_real_mul _ls t1 t2 _ty = if is_real t1 real_0 - then t1 + then Value t1 else if is_real t2 real_0 - then t2 + then Value t2 else if is_real t1 real_1 - then t2 + then Value t2 else if is_real t2 real_1 - then t1 + then Value t1 else raise Undetermined let real_pow r1 r2 = @@ -309,14 +337,14 @@ let real_pow r1 r2 = Number.real_value ~pow2 ~pow5 s let simpl_real_pow _ls t1 _t2 _ty = - if is_real t1 real_1 then t1 else raise Undetermined + if is_real t1 real_1 then Value t1 else raise Undetermined let real_from_int _ _ls l _ty = match l with | [ t ] -> ( try let n = big_int_of_value t in - Real (Number.real_value n) + Value (Real (Number.real_value n)) with NotNum -> raise Undetermined) | _ -> assert false @@ -324,7 +352,7 @@ type 'a built_in_theories = Env.pathname * string * (string * (Ty.tysymbol -> unit)) list - * (string * lsymbol ref option * 'a builtin) list + * (string list * lsymbol ref option * 'a builtin) list let built_in_theories : unit -> 'a built_in_theories list = fun () -> @@ -335,54 +363,54 @@ let built_in_theories : unit -> 'a built_in_theories list = "Int", [], [ - (Ident.op_infix "+", None, eval_int_op BigInt.add simpl_add); - (Ident.op_infix "-", None, eval_int_op BigInt.sub simpl_sub); - (Ident.op_infix "*", None, eval_int_op BigInt.mul simpl_mul); - (Ident.op_prefix "-", None, eval_int_uop BigInt.minus); - (Ident.op_infix "<", None, eval_int_rel BigInt.lt); - (Ident.op_infix "<=", None, eval_int_rel BigInt.le); - (Ident.op_infix ">", None, eval_int_rel BigInt.gt); - (Ident.op_infix ">=", None, eval_int_rel BigInt.ge); + ([ Ident.op_infix "+" ], None, eval_int_op BigInt.add simpl_add); + ([ Ident.op_infix "-" ], None, eval_int_op BigInt.sub simpl_sub); + ([ Ident.op_infix "*" ], None, eval_int_op BigInt.mul simpl_mul); + ([ Ident.op_prefix "-" ], None, eval_int_uop BigInt.minus); + ([ Ident.op_infix "<" ], None, eval_int_rel BigInt.lt); + ([ Ident.op_infix "<=" ], None, eval_int_rel BigInt.le); + ([ Ident.op_infix ">" ], None, eval_int_rel BigInt.gt); + ([ Ident.op_infix ">=" ], None, eval_int_rel 
BigInt.ge); ] ); ( [ "int" ], "MinMax", [], [ - ("min", None, eval_int_op BigInt.min simpl_minmax); - ("max", None, eval_int_op BigInt.max simpl_minmax); + ([ "min" ], None, eval_int_op BigInt.min simpl_minmax); + ([ "max" ], None, eval_int_op BigInt.max simpl_minmax); ] ); ( [ "int" ], "ComputerDivision", [], [ - ("div", None, eval_int_op BigInt.computer_div simpl_div); - ("mod", None, eval_int_op BigInt.computer_mod simpl_mod); + ([ "div" ], None, eval_int_op BigInt.computer_div simpl_div); + ([ "mod" ], None, eval_int_op BigInt.computer_mod simpl_mod); ] ); ( [ "int" ], "EuclideanDivision", [], [ - ("div", None, eval_int_op BigInt.euclidean_div simpl_div); - ("mod", None, eval_int_op BigInt.euclidean_mod simpl_mod); + ([ "div" ], None, eval_int_op BigInt.euclidean_div simpl_div); + ([ "mod" ], None, eval_int_op BigInt.euclidean_mod simpl_mod); ] ); ( [ "real" ], "Real", [], [ - (Ident.op_prefix "-", None, eval_real_uop real_opp); - (Ident.op_infix "+", None, eval_real_op real_add simpl_real_add); - (Ident.op_infix "-", None, eval_real_op real_sub simpl_real_sub); - (Ident.op_infix "*", None, eval_real_op real_mul simpl_real_mul); - (Ident.op_infix "<", None, eval_real_rel BigInt.lt); - (Ident.op_infix "<=", None, eval_real_rel BigInt.le); - (Ident.op_infix ">", None, eval_real_rel BigInt.gt); - (Ident.op_infix ">=", None, eval_real_rel BigInt.ge); + ([ Ident.op_prefix "-" ], None, eval_real_uop real_opp); + ([ Ident.op_infix "+" ], None, eval_real_op real_add simpl_real_add); + ([ Ident.op_infix "-" ], None, eval_real_op real_sub simpl_real_sub); + ([ Ident.op_infix "*" ], None, eval_real_op real_mul simpl_real_mul); + ([ Ident.op_infix "<" ], None, eval_real_rel BigInt.lt); + ([ Ident.op_infix "<=" ], None, eval_real_rel BigInt.le); + ([ Ident.op_infix ">" ], None, eval_real_rel BigInt.gt); + ([ Ident.op_infix ">=" ], None, eval_real_rel BigInt.ge); ] ); - ([ "real" ], "FromInt", [], [ ("from_int", None, real_from_int) ]); + ([ "real" ], "FromInt", [], [ ([ "from_int" ], None, real_from_int) ]); ( [ "real" ], "PowerReal", [], - [ ("pow", None, eval_real_op real_pow simpl_real_pow) ] ); + [ ([ "pow" ], None, eval_real_op real_pow simpl_real_pow) ] ); (* ["map"],"Map", ["map", builtin_map_type], [ "const", Some ls_map_const, eval_map_const; "get", Some ls_map_get, eval_map_get; "set", Some ls_map_set, eval_map_set; ] ; *) @@ -397,7 +425,7 @@ let add_builtin_th env ((l, n, t, d) : 'a built_in_theories) = t; List.iter (fun (id, r, f) -> - let ls = Theory.ns_find_ls th.Theory.th_export [ id ] in + let ls = Theory.ns_find_ls th.Theory.th_export id in Hls.add env.builtins ls f; match r with None -> () | Some r -> r := ls) d @@ -717,7 +745,7 @@ let bounds ls_lt t1 t2 = - expand to bounded quantifications on range values - compatiblity with reverse direction (forall i. b > i > a -> t) - detect SPARK-style [forall i. 
if a < i /\ i < b then t else true] *) -let reduce_bounded_quant ls_lt limit t sigma st rem = +let reduce_bounded_quant engine ls_lt limit t sigma st rem = match (st, rem) with (* st = a < vs < b :: _; rem = -> :: forall vs :: _ *) | ( Term { t_node = Tbinop (Tand, t1, t2) } :: st, @@ -750,6 +778,37 @@ let reduce_bounded_quant ls_lt limit t sigma st rem = loop rem (BigInt.pred i) in { value_stack = st; cont_stack = loop rem b } + | ( Term cond :: st, + (Kbinop Timplies, _) + :: (Kquant ((Tforall as quant), [ vs ], _), t_orig) + :: rem ) -> ( + match engine.bounded_quant engine vs ~cond with + | None -> raise Exit + | Some res -> ( + let t_empty, binop = + match quant with Tforall -> (t_true, Tand) | Texists -> (t_false, Tor) + in + let rem = + match res.new_quant with + | [] -> rem + | _ -> (Kquant (quant, res.new_quant, []), t_orig) :: rem + in + let rec loop rem = function + | [] -> rem + | t_i :: l -> + let rem = + match l with + | [] -> rem + | _ -> + (* conjunction *) + (Kbinop binop, t_true) :: rem + in + let rem = (Keval (t, Mvs.add vs t_i sigma), t_true) :: rem in + loop rem l + in + match res.substitutions with + | [] -> { value_stack = Term t_empty :: st; cont_stack = rem } + | _ -> { value_stack = st; cont_stack = loop rem res.substitutions })) | _ -> raise Exit let rec reduce engine c = @@ -757,7 +816,7 @@ let rec reduce engine c = | _, [] -> assert false | st, (Keval (t, sigma), orig) :: rem -> ( let limit = engine.params.compute_max_quantifier_domain in - try reduce_bounded_quant engine.ls_lt limit t sigma st rem + try reduce_bounded_quant engine engine.ls_lt limit t sigma st rem with Exit -> reduce_eval engine st t ~orig sigma rem) | [], (Kif _, _) :: _ -> assert false | v :: st, (Kif (t2, t3, sigma), orig) :: rem -> ( @@ -1071,7 +1130,14 @@ and reduce_app_no_equ engine st ls ~orig ty rem_cont = try let f = Hls.find engine.builtins ls in let v = f engine ls args ty in - { value_stack = v_attr_copy orig v :: rem_st; cont_stack = rem_cont } + match v with + | Value v -> + { value_stack = v_attr_copy orig v :: rem_st; cont_stack = rem_cont } + | Eval t -> + { + value_stack = rem_st; + cont_stack = (Keval (t, Mvs.empty), t) :: rem_cont; + } with Not_found | Undetermined -> ( let args = List.map term_of_value args in match Ident.Mid.find ls.ls_name engine.known_map with @@ -1325,7 +1391,8 @@ let normalize ?step_limit ~limit engine sigma t0 = (* the rewrite engine *) -let create p env km user_env built_in_user = +let create ?(bounded_quant = fun _ _ ~cond:_ -> None) p env km user_env + built_in_user = let th = Env.read_theory env [ "int" ] "Int" in let ls_lt = Theory.ns_find_ls th.Theory.th_export [ Ident.op_infix "<" ] in let env = @@ -1337,6 +1404,7 @@ let create p env km user_env built_in_user = ls_lt; builtins = Hls.create 17; user_env; + bounded_quant; } in if p.compute_builtin then get_builtins env built_in_user; diff --git a/src/reduction_engine.mli b/src/reduction_engine.mli index 53407b21dc4ffbc5d16c4fe31d0ae29f2c474ded..1d8e3cf386ca0a3b56c952460d92997915988553 100644 --- a/src/reduction_engine.mli +++ b/src/reduction_engine.mli @@ -89,17 +89,32 @@ type value = | Term of Why3.Term.term (* invariant: is in normal form *) | Int of BigInt.t | Real of Number.real_value +[@@deriving show] + +type builtin_value = + | Eval of Why3.Term.term + | Value of value +[@@deriving show] type 'a builtin = - 'a engine -> Why3.Term.lsymbol -> value list -> Ty.ty option -> value + 'a engine -> Why3.Term.lsymbol -> value list -> Ty.ty option -> builtin_value type 'a built_in_theories = 
Env.pathname * string * (string * (Ty.tysymbol -> unit)) list - * (string * Why3.Term.lsymbol ref option * 'a builtin) list + * (string list * Why3.Term.lsymbol ref option * 'a builtin) list + +type bounded_quant_result = { + new_quant : Term.vsymbol list; + substitutions : Term.term list; +} + +type 'a bounded_quant = + 'a engine -> Term.vsymbol -> cond:Term.term -> bounded_quant_result option val create : + ?bounded_quant:'a bounded_quant -> params -> Env.env -> Decl.decl Ident.Mid.t -> diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 3edcb00a2093e13e255339d129d018387157ccea..8f7218086b62382be44522e3348664f661cda166 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -23,26 +23,81 @@ open Why3 open Base +let get_input_variables = + let add i acc = function + | { Term.t_node = Tapp (ls, []); _ } -> Term.Mls.add ls i acc + | arg -> + invalid_arg + (Fmt.str "No direct variable in application: %a" Pretty.print_term arg) + in + let rec aux acc (term : Term.term) = + match term.t_node with + | Term.Tapp + ( { ls_name; _ }, + [ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] ) + when String.equal ls_name.id_string (Ident.op_infix "@@") -> ( + match (Language.lookup_nn ls1, Language.lookup_vector ls2) with + | Some { nn_inputs; _ }, Some n -> + assert (nn_inputs = n && n = List.length args); + List.foldi ~init:acc ~f:add args + | _ -> acc) + | _ -> Term.t_fold aux acc term + in + Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) Term.Mls.empty + (* Create logic symbols for output variables and simplify the formula. *) -(* TODO: [Reduction_engine] is probably an overkill and should be replaced. *) -let simplify_goal env input_variables = - let rec aux meta hls (term : Term.term) = +let simplify_goal _env input_variables = + let rec aux hls (term : Term.term) = match term.t_node with - | Term.Tapp (ls, _) -> ( - match Language.lookup_loaded_nets ls with - | Some nn -> - meta := nn.filename :: !meta; - let outputs = - List.init nn.nb_outputs ~f:(fun i -> - let id = Ident.id_fresh "y" in - let ls = Term.create_fsymbol id [] nn.ty_data in - hls := (Decl.create_param_decl ls, ls, i) :: !hls; - Term.fs_app ls [] nn.ty_data) + | Term.Tapp + ( ls_vget, + [ + ({ + t_node = + Tapp + ( ls_apply_nn, + [ + { t_node = Tapp (ls_nn, _); _ }; + { t_node = Tapp (ls_vector, _); _ }; + ] ); + _; + } as _t1); + ({ t_node = Tconst (ConstInt i); _ } as _t2); + ] ) + when String.equal ls_vget.ls_name.id_string (Ident.op_get "") + && String.equal ls_apply_nn.ls_name.id_string (Ident.op_infix "@@") + -> ( + match (Language.lookup_nn ls_nn, Language.lookup_vector ls_vector) with + | Some nn, Some _ -> + let index = Number.to_small_integer i in + let hout = + Hashtbl.update_and_return hls nn.nn_filename ~f:(fun hout -> + let create_ls_output () = + let id = Ident.id_fresh "y" in + Term.create_fsymbol id [] nn.nn_ty_elt + in + match hout with + | None -> + let hout = Hashtbl.create (module Int) in + let ls = create_ls_output () in + Hashtbl.add_exn hout ~key:index ~data:ls; + hout + | Some hout -> + Hashtbl.update hout index ~f:(fun lsout -> + match lsout with + | None -> + let ls = create_ls_output () in + Hashtbl.add_exn hout ~key:index ~data:ls; + ls + | Some ls -> ls); + hout) in - Term.t_tuple outputs - | _ -> Term.t_map (aux meta hls) term) - | _ -> Term.t_map (aux meta hls) term + let ls_output = Hashtbl.find_exn hout index in + Term.fs_app ls_output [] nn.nn_ty_elt + | _ -> Term.t_map 
(aux hls) term) + | _ -> Term.t_map (aux hls) term in + let htbl = Hashtbl.create (module String) in Trans.fold (fun task_hd acc -> match task_hd.task_decl.td_node with @@ -54,43 +109,19 @@ let simplify_goal env input_variables = | Some pos -> Task.add_meta task Utils.meta_input [ MAls ls; MAint pos ] ) | Decl decl -> - let meta = ref [] in - let hls = ref [] in - let decl = - Decl.decl_map - (fun term -> - let term = aux meta hls term in - if List.is_empty !hls - then term - else - let known = - List.fold !hls ~init:task_hd.task_known - ~f:(fun acc (d, _, _) -> Decl.known_add_decl acc d) - in - let engine = - Reduction_engine.create - { - compute_defs = false; - compute_builtin = true; - compute_def_set = Term.Sls.empty; - compute_max_quantifier_domain = 0; - } - env known - in - Reduction_engine.normalize ~limit:100 engine Term.Mvs.empty term) - decl - in - let acc = - List.fold !hls ~init:acc ~f:(fun acc (d, ls, i) -> - let task = Task.add_decl acc d in - Task.add_meta task Utils.meta_output [ MAls ls; MAint i ]) - in + let decl = Decl.decl_map (fun term -> aux htbl term) decl in let acc = - List.fold !meta ~init:acc ~f:(fun acc s -> - Task.add_meta acc Utils.meta_nn_filename [ MAstr s ]) + Hashtbl.fold htbl ~init:acc ~f:(fun ~key ~data acc -> + let acc = Task.add_meta acc Utils.meta_nn_filename [ MAstr key ] in + Hashtbl.fold data ~init:acc ~f:(fun ~key ~data acc -> + let acc = + let decl = Decl.create_param_decl data in + Task.add_decl acc decl + in + Task.add_meta acc Utils.meta_output [ MAls data; MAint key ])) in Task.add_decl acc decl) None -let trans env = - Trans.seq [ Trans.bind Utils.get_input_variables (simplify_goal env) ] +let trans_nn_application env = + Trans.bind get_input_variables (simplify_goal env) diff --git a/src/transformations/native_nn_prover.mli b/src/transformations/native_nn_prover.mli index 936ff492295cfd2682bca8c47d65365add8fd67e..82cc9c71c748345fbdaa1ec8f40c6f9be708fc3c 100644 --- a/src/transformations/native_nn_prover.mli +++ b/src/transformations/native_nn_prover.mli @@ -20,6 +20,4 @@ (* *) (**************************************************************************) -open Why3 - -val trans : Env.env -> Task.task Trans.trans +val trans_nn_application : Why3.Env.env -> Why3.Task.task Why3.Trans.trans diff --git a/src/transformations/simplify_rel.ml b/src/transformations/simplify_rel.ml new file mode 100644 index 0000000000000000000000000000000000000000..231676a619012c2e15d4a0a7b9b1cb3c64659757 --- /dev/null +++ b/src/transformations/simplify_rel.ml @@ -0,0 +1,145 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). 
*) +(* *) +(**************************************************************************) + +open Why3 +open Base + +let float_of_real_constant rc = + let is_neg, rc = + (BigInt.lt rc.Number.rl_real.rv_sig BigInt.zero, Number.abs_real rc) + in + let rc_str = Fmt.str "%a" Number.(print_real_constant full_support) rc in + let f = Float.of_string rc_str in + if is_neg then Float.neg f else f + +let real_constant_of_float value = + let neg = Float.is_negative value in + let value = Fmt.str "%.64f" (Float.abs value) in + (* Split into integer and fractional parts. *) + let int_frac = String.split value ~on:'.' in + let int = List.hd_exn int_frac in + let frac = + match List.tl_exn int_frac with + | [] -> "" + | [ s ] -> s + | _ -> assert false (* Since every float has one '.' at most. *) + in + Constant.ConstReal (Number.real_literal ~radix:10 ~neg ~int ~frac ~exp:None) + +let term_of_float env f = + let th = Env.read_theory env [ "ieee_float" ] "Float64" in + let ty = Ty.ty_app (Theory.ns_find_ts th.th_export [ "t" ]) [] in + Term.t_const (real_constant_of_float f) ty + +let make_rt env = + let th = Env.read_theory env [ "ieee_float" ] "Float64" in + let ty = Ty.ty_app (Theory.ns_find_ts th.th_export [ "t" ]) [] in + let le_fp = Theory.ns_find_ls th.th_export [ "le" ] in + let lt_fp = Theory.ns_find_ls th.th_export [ "lt" ] in + let ge_fp = Theory.ns_find_ls th.th_export [ "ge" ] in + let gt_fp = Theory.ns_find_ls th.th_export [ "gt" ] in + let rel_fp = [ le_fp; lt_fp; ge_fp; gt_fp ] in + let add_fp = Theory.ns_find_ls th.th_export [ "add" ] in + let sub_fp = Theory.ns_find_ls th.th_export [ "sub" ] in + let mul_fp = Theory.ns_find_ls th.th_export [ "mul" ] in + let div_fp = Theory.ns_find_ls th.th_export [ "div" ] in + let neg_fp = Theory.ns_find_ls th.th_export [ "neg" ] in + let op_fp = [ add_fp; sub_fp; mul_fp; div_fp; neg_fp ] in + let rec rt t = + let t = Term.t_map rt t in + match t.t_node with + | Tapp (ls, [ { t_node = Tconst (ConstReal rc); _ } ]) + when Term.ls_equal ls neg_fp -> + let rc = Number.neg_real rc in + Term.t_const (ConstReal rc) ty + | Tapp + ( ls_rel, + [ + { t_node = Tconst (ConstReal rc1); _ }; + { + t_node = + Tapp + (ls_op, [ _mode; t'; { t_node = Tconst (ConstReal rc2); _ } ]); + _; + }; + ] ) + when List.exists ~f:(Term.ls_equal ls_rel) rel_fp + && List.exists ~f:(Term.ls_equal ls_op) op_fp -> + let rc1_float = float_of_real_constant rc1 in + let rc2_float = float_of_real_constant rc2 in + let op_float = + if Term.ls_equal ls_op sub_fp + then Float.( + ) + else if Term.ls_equal ls_op add_fp + then Float.( - ) + else if Term.ls_equal ls_op mul_fp + then Float.( / ) + else if Term.ls_equal ls_op div_fp + then Float.( * ) + else assert false + in + let rc_float = op_float rc1_float rc2_float in + let rc_t = term_of_float env rc_float in + let t = Term.t_app_infer ls_rel [ rc_t; t' ] in + rt t + | Tapp + ( ls_rel, + [ + { + t_node = + Tapp + (ls_op, [ _mode; t'; { t_node = Tconst (ConstReal rc2); _ } ]); + _; + }; + { t_node = Tconst (ConstReal rc1); _ }; + ] ) + when List.exists ~f:(Term.ls_equal ls_rel) rel_fp + && List.exists ~f:(Term.ls_equal ls_op) op_fp -> + let rc1_float = float_of_real_constant rc1 in + let rc2_float = float_of_real_constant rc2 in + let op_float = + if Term.ls_equal ls_op sub_fp + then Float.( + ) + else if Term.ls_equal ls_op add_fp + then Float.( - ) + else if Term.ls_equal ls_op mul_fp + then Float.( / ) + else if Term.ls_equal ls_op div_fp + then Float.( * ) + else assert false + in + let rc_float = op_float rc1_float rc2_float in + let 
rc_t = term_of_float env rc_float in + let t = Term.t_app_infer ls_rel [ t'; rc_t ] in + rt t + | _ -> t + in + rt + +let simplify_rel env = + let rt = make_rt env in + Trans.rewrite rt None + +let init () = + Trans.register_env_transform + ~desc:"Simplify linear inequalities with float values." "simplify_rel" + simplify_rel diff --git a/src/transformations/simplify_rel.mli b/src/transformations/simplify_rel.mli new file mode 100644 index 0000000000000000000000000000000000000000..694097b3b0dcc9bef16af63c6f16994dd897e63e --- /dev/null +++ b/src/transformations/simplify_rel.mli @@ -0,0 +1,24 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +val init : unit -> unit +(** Register the transformation. *) diff --git a/src/transformations/utils.ml b/src/transformations/utils.ml index 1cfbdda2bbcec84da22d9b757b02f1d0106cb395..fcb09d678c3eb7a1960ffe54c8034ad40bdcf74e 100644 --- a/src/transformations/utils.ml +++ b/src/transformations/utils.ml @@ -21,38 +21,6 @@ (**************************************************************************) open Why3 -open Base - -let count_nn_apply = - let rec aux acc (term : Term.term) = - let acc = Term.t_fold aux acc term in - match term.t_node with - | Term.Tapp (ls, _) -> ( - match Language.lookup_loaded_nets ls with - | None -> acc - | Some _ -> acc + 1) - | _ -> acc - in - Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) 0 - -let get_input_variables = - let rec aux acc (term : Term.term) = - match term.t_node with - | Term.Tapp (ls, args) -> ( - match Language.lookup_loaded_nets ls with - | None -> acc - | Some _ -> - let add i acc = function - | { Term.t_node = Tapp (vs, []); _ } -> Term.Mls.add vs i acc - | arg -> - invalid_arg - (Fmt.str "No direct variable in application: %a" Pretty.print_term - arg) - in - List.foldi ~init:acc ~f:add args) - | _ -> Term.t_fold aux acc term - in - Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) Term.Mls.empty let meta_input = Theory.( diff --git a/src/transformations/utils.mli b/src/transformations/utils.mli index 12b8c1348d0a3763d44c6ef16280c03fc8408dc9..9a48140ab2b7ff77da6580c6acde8437e9c4787b 100644 --- a/src/transformations/utils.mli +++ b/src/transformations/utils.mli @@ -22,12 +22,6 @@ open Why3 -val count_nn_apply : int Trans.trans -(** Count the number of applications of [nn_apply]. *) - -val get_input_variables : int Term.Mls.t Trans.trans -(** Retrieve the input variables appearing as arguments of [nn_apply]. *) - val meta_input : Theory.meta (** Indicate the input position. 
*) diff --git a/src/verification.ml b/src/verification.ml index 280c22d7d2ff094147a2a1b936bbb260ffaa6e62..8488a8742715fad4844811df1d2c2511ce98ecfb 100644 --- a/src/verification.ml +++ b/src/verification.ml @@ -223,31 +223,38 @@ let answer_dataset limit config env prover config_prover driver dataset task = in (prover_answer, additional_info) -let answer_generic limit config prover config_prover driver task = - let task = Driver.prepare_task driver task in - let nn_file = - match Task.on_meta_excl Utils.meta_nn_filename task with - | Some [ MAstr nn_file ] -> Unix.realpath nn_file - | Some _ -> assert false (* By construction of the meta. *) - | None -> invalid_arg "No neural network model found in task" - in - let tasks = - (* We turn [task] into a list (ie, conjunction) of disjunctions of tasks. *) - match prover with - | Prover.Marabou -> Trans.apply Split.split_all task - | Pyrat | Nnenum -> Trans.apply Split.split_premises task - | _ -> [ task ] - in - let command = Whyconf.get_complete_command ~with_steps:false config_prover in - let command = Re__Core.replace_string nnet_or_onnx ~by:nn_file command in +let answer_generic limit config env prover config_prover driver ~proof_strategy + task = + let tasks = proof_strategy env task in let answers = - List.map tasks ~f:(call_prover_on_task limit config command driver) + List.concat_map tasks ~f:(fun task -> + let task = Driver.prepare_task driver task in + let nn_file = + match Task.on_meta_excl Utils.meta_nn_filename task with + | Some [ MAstr nn_file ] -> Unix.realpath nn_file + | Some _ -> assert false (* By construction of the meta. *) + | None -> invalid_arg "No neural network model found in task" + in + let tasks = + (* Turn [task] into a list (ie, conjunction) of disjunctions of + tasks. *) + match prover with + | Prover.Marabou -> Trans.apply Split.split_all task + | Pyrat | Nnenum -> Trans.apply Split.split_premises task + | _ -> [ task ] + in + let command = + Whyconf.get_complete_command ~with_steps:false config_prover + in + let command = Re__Core.replace_string nnet_or_onnx ~by:nn_file command in + List.map tasks ~f:(call_prover_on_task limit config command driver)) in let prover_answer = combine_prover_answers answers in let additional_info = Generic None in (prover_answer, additional_info) -let call_prover ?dataset ~limit config env prover config_prover driver task = +let call_prover ~cwd ~limit config env prover config_prover driver ?dataset task + = let prover_answer, additional_info = match prover with | Prover.Saver -> answer_saver limit config env config_prover dataset task @@ -256,34 +263,35 @@ let call_prover ?dataset ~limit config env prover config_prover driver task = let dataset = Unix.realpath (Option.value_exn dataset) in answer_dataset limit config env prover config_prover driver dataset task | Marabou | Pyrat | Nnenum -> - let task = Proof_strategy.apply_native_nn_prover env task in - answer_generic limit config prover config_prover driver task + let task = Interpretation.interpret_task ~cwd env task in + let proof_strategy = Proof_strategy.apply_native_nn_prover in + answer_generic limit config env prover config_prover driver + ~proof_strategy task | CVC5 -> - let task = Proof_strategy.apply_classic_prover env task in - answer_generic limit config prover config_prover driver task + let task = Interpretation.interpret_task ~cwd env task in + let proof_strategy = Proof_strategy.apply_classic_prover in + answer_generic limit config env prover config_prover driver + ~proof_strategy task in let id = Task.task_goal 
task in { id; prover_answer; additional_info } -let open_file ?(debug = false) ?format ~loadpath file = +let open_file ?format env file = + match file with + | File.Stdin -> + ( Unix.getcwd (), + Env.(read_channel ?format base_language env "stdin" Caml.stdin) ) + | File file -> + let mlw_files, _ = Env.(read_file ?format base_language env file) in + (Caml.Filename.dirname file, mlw_files) + | JSON jin -> + let th = Json.theory_of_input env jin in + (Unix.getcwd () (* TODO *), Wstdlib.Mstr.singleton th.th_name.id_string th) + +let verify ?(debug = false) ?format ~loadpath ?memlimit ?timelimit ?dataset + prover ?prover_altern file = if debug then Debug.set_flag (Debug.lookup_flag "call_prover"); let env, config = create_env ~debug loadpath in - let mstr_theory = - match file with - | File.Stdin -> - Env.(read_channel ?format base_language env "stdin" Caml.stdin) - | File file -> - let mlw_files, _ = Env.(read_file ?format base_language env file) in - mlw_files - | JSON jin -> - let th = Json.theory_of_input env jin in - Wstdlib.Mstr.singleton th.th_name.id_string th - in - (env, config, mstr_theory) - -let verify ?debug ?format ~loadpath ?memlimit ?timelimit ?dataset prover - ?prover_altern file = - let env, config, mstr_theory = open_file ?debug ?format ~loadpath file in let main = Whyconf.get_main config in let limit = let memlimit = Option.value memlimit ~default:(Whyconf.memlimit main) in @@ -338,10 +346,12 @@ let verify ?debug ?format ~loadpath ?memlimit ?timelimit ?dataset prover Driver.load_driver_file_and_extras main env file config_prover.extra_drivers in + let cwd, mstr_theory = open_file ?format env file in Wstdlib.Mstr.map (fun theory -> let tasks = Task.split_theory theory None None in List.map - ~f:(call_prover ?dataset ~limit main env prover config_prover driver) + ~f: + (call_prover ~cwd ~limit main env prover config_prover driver ?dataset) tasks) mstr_theory diff --git a/src/verification.mli b/src/verification.mli index 82a575f331515156abf0d1a7e3445557ce42c6fc..645834243a14eae1befe979ff7e0083f3fc8656a 100644 --- a/src/verification.mli +++ b/src/verification.mli @@ -75,17 +75,5 @@ val verify : for each theory, an [answer] for each goal of the theory, respecting the order of appearance. *) -val open_file : - ?debug:bool -> - ?format:string -> - loadpath:string list -> - File.t -> - Env.env * Whyconf.config * Theory.theory Wstdlib.Mstr.t -(** [open_file ?debug ?format ~loadpath file] just opens the given file. - - @param debug when set, enables debug information. - @param format is the [file] format. - @param loadpath is the additional loadpath. *) - val create_env : ?debug:bool -> string list -> Why3.Env.env * Why3.Whyconf.config diff --git a/stdlib/dune b/stdlib/dune index f1d9fa74d11b355d516bcd7fd324e1b30e3242b1..101848f1f63dc2119534b1b5593450ea8a141f2b 100644 --- a/stdlib/dune +++ b/stdlib/dune @@ -2,5 +2,5 @@ (section (site (caisar stdlib))) - (files caisar.mlw) + (files caisar.mlw interpretation.mlw) (package caisar)) diff --git a/stdlib/interpretation.mlw b/stdlib/interpretation.mlw new file mode 100644 index 0000000000000000000000000000000000000000..68a0d7073228d708da194a9710e25d3ffc1548d8 --- /dev/null +++ b/stdlib/interpretation.mlw @@ -0,0 +1,76 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. 
*) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +theory Vector + use int.Int + + type vector 'a + type index = int + + function ([]) (v: vector 'a) (i: index) : 'a + function length (v: vector 'a) : int + function (-) (v1: vector 'a) (v2: vector 'a) : vector 'a + + predicate has_length (v: vector 'a) (i: int) + predicate valid_index (v: vector 'a) (i: index) = 0 <= i < length v + + function mapi (v: vector 'a) (f: index -> 'a -> 'b) : vector 'b + function map (v: vector 'a) (f: 'a -> 'b) : vector 'b + function map2 (v1: vector 'a) (v2: vector 'b) (f: 'a -> 'b -> 'c) : vector 'c + + predicate forall_ (v: vector 'a) (f: 'a -> bool) = + forall i: index. valid_index v i -> f v[i] + + predicate forall2 (v1: vector 'a) (v2: vector 'b) (f: 'a -> 'b -> bool) = + length(v1) = length(v2) -> forall i: index. valid_index v1 i -> f v1[i] v2[i] + + function foreach (v: vector 'a) (f: 'a -> 'b) : vector 'b = + map v f + + function foreach2 (v1: vector 'a) (v2: vector 'b) (f: 'a -> 'b -> 'c) : vector 'c = + map2 v1 v2 f +end + +theory NeuralNetwork + use Vector + + type nn + type format = ONNX | NNet + + function read_neural_network (n: string) (f: format) : nn + function (@@) (n: nn) (v: vector 'a) : vector 'a +end + +theory Dataset + use Vector + + type dataset 'a 'b = vector ('a, 'b) + type format = CSV + + function read_dataset (f: string) (k: format) : dataset 'a 'b + + predicate forall_ (d: dataset 'a 'b) (f: 'a -> 'b -> bool) = + Vector.forall_ d (fun e -> let a, b = e in f a b) + + function foreach (d: dataset 'a 'b) (f: 'a -> 'b -> 'c) : vector 'c = + Vector.foreach d (fun e -> let a, b = e in f a b) +end diff --git a/tests/datasets/a/a001.png b/tests/datasets/a/a001.png deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/tests/datasets/a/a002.png b/tests/datasets/a/a002.png deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/tests/dune b/tests/dune index 9429b9b6d4f1f63cafbd57233c7af4501df7db0a..2aed42f48c9297aa4e64908c24103d863736b8af 100644 --- a/tests/dune +++ b/tests/dune @@ -11,8 +11,6 @@ bin/cvc5 bin/nnenum.sh filter_tmpdir.sh - (glob_files "datasets/a/*") - filter_tmpdir.sh ../lib/xgboost/example/california.csv ../lib/xgboost/example/california.json) (package caisar)) diff --git a/tests/interpretation.t b/tests/interpretation.t deleted file mode 100644 index 31dcd1a89e47f5b01d3aa3242820e98199857992..0000000000000000000000000000000000000000 --- a/tests/interpretation.t +++ /dev/null @@ -1,43 +0,0 @@ -Test interpret - $ caisar interpret -L . 
--format whyml - 2>&1 <<EOF | ./filter_tmpdir.sh - > theory T - > use caisar.Interpretation - > use int.Int - > - > goal G1: 1+1=2 - > - > goal G2: 1+1=3 - > - > goal G3: size (open_dataset "datasets/a") = 2 - > - > goal G4: - > let dataset = open_dataset "datasets/a" in - > size dataset = 2 - > - > predicate robust (i: input) - > - > goal G5: - > let dataset = open_dataset "datasets/a" in - > forall_ dataset (fun i -> robust i) - > - > goal G6: - > let dataset = open_dataset "datasets/a" in - > forall i:int. i=1+(size dataset) -> i < 4 - > end - > EOF - G1 : true - - G2 : false - - G3 : true - caisar_op, - (Interpretation.Dataset { Interpretation.dataset = "datasets/a" }) - G4 : true - caisar_op1, - (Interpretation.Dataset { Interpretation.dataset = "datasets/a" }) - G5 : robust (get 0 caisar_op2) /\ robust (get 1 caisar_op2) - caisar_op2, - (Interpretation.Dataset { Interpretation.dataset = "datasets/a" }) - G6 : forall i:int. i = 3 -> i < 4 - caisar_op3, - (Interpretation.Dataset { Interpretation.dataset = "datasets/a" }) diff --git a/tests/interpretation_acasxu.t b/tests/interpretation_acasxu.t new file mode 100644 index 0000000000000000000000000000000000000000..8591192094f156eba26c4a530922313d44a4557f --- /dev/null +++ b/tests/interpretation_acasxu.t @@ -0,0 +1,99 @@ +Test interpret on acasxu + + $ chmod u+x bin/pyrat.py + + $ bin/pyrat.py --version + PyRAT 1.1 + + $ PATH=$(pwd)/bin:$PATH + + $ caisar verify -L . --format whyml --prover PyRAT - 2>&1 <<EOF | ./filter_tmpdir.sh + > theory T + > use ieee_float.Float64 + > use bool.Bool + > use int.Int + > use interpretation.Vector + > use interpretation.NeuralNetwork + > + > constant nn: nn = read_neural_network "TestNetwork.nnet" NNet + > + > type input = vector t + > + > constant distance_to_intruder: int = 0 + > constant angle_to_intruder: int = 1 + > constant intruder_heading: int = 2 + > constant speed: int = 3 + > constant intruder_speed: int = 4 + > + > type action = int + > + > constant clear_of_conflict: action = 0 + > constant weak_left: action = 1 + > constant weak_right: action = 2 + > constant strong_left: action = 3 + > constant strong_right: action = 4 + > + > constant pi: t = 3.141592653589793115997963468544185161590576171875000 + > + > predicate valid_input (i: input) = + > (0.0:t) .<= i[distance_to_intruder] .<= (60760.0:t) + > /\ .- pi .<= i[angle_to_intruder] .<= pi + > /\ .- pi .<= i[intruder_heading] .<= pi + > /\ (100.0:t) .<= i[speed] .<= (1200.0:t) + > /\ (0.0:t) .<= i[intruder_speed] .<= (1200.0:t) + > + > predicate valid_action (a: action) = clear_of_conflict <= a <= strong_right + > + > predicate advises (n: nn) (i: input) (a: action) = + > valid_action a -> + > forall b: action. 
+ > valid_action b -> a <> b -> (n@@i)[a] .< (n@@i)[b] + > + > predicate intruder_distant_and_slow (i: input) = + > i[distance_to_intruder] .>= (55947.6909999999988940544426441192626953125:t) + > /\ i[speed] .>= (1145.0:t) + > /\ i[intruder_speed] .<= (60.0:t) + > + > function denormalize_t (i: t) (mean: t) (range: t) : t = (i .* range) .+ mean + > + > function denormalize_by_index (idx: int) (t: t) : t = + > if idx = distance_to_intruder then denormalize_t t (19791.0:t) (60261.0:t) + > else if idx = angle_to_intruder then denormalize_t t (0.0:t) (6.2831853071800001231395071954466402530670166015625:t) + > else if idx = intruder_heading then denormalize_t t (0.0:t) (6.2831853071800001231395071954466402530670166015625:t) + > else if idx = speed then denormalize_t t (650.0:t) (1100.0:t) + > else if idx = intruder_speed then denormalize_t t (600.0:t) (1200.0:t) + > else t + > + > function denormalize_input (i:input) : input = + > Vector.mapi i denormalize_by_index + > + > function denormalize_output (o: t) : t = + > denormalize_t o (7.51888402010059753166615337249822914600372314453125:t) (373.9499200000000200816430151462554931640625:t) + > + > goal P1: + > forall i: input. + > has_length i 5 -> + > let j = denormalize_input i in + > valid_input j /\ intruder_distant_and_slow j -> + > let o = (nn@@i)[clear_of_conflict] in + > (denormalize_output o) .<= (1500.0:t) + > + > predicate directly_ahead (i: input) = + > (1500.0:t) .<= i[distance_to_intruder] .<= (1800.0:t) + > /\ .- (0.059999999999999997779553950749686919152736663818359375:t) .<= i[angle_to_intruder] .<= (0.059999999999999997779553950749686919152736663818359375:t) + > + > predicate moving_towards (i: input) = + > i[intruder_heading] .>= (3.100000000000000088817841970012523233890533447265625:t) + > /\ i[speed] .>= (980.0:t) + > /\ i[intruder_speed] .>= (960.0:t) + > + > goal P3: + > forall i: input. + > has_length i 5 -> + > let j = denormalize_input i in + > valid_input j /\ directly_ahead j /\ moving_towards j -> + > not (advises nn i clear_of_conflict) + > end + > EOF + [caisar] Goal P1: Unknown () + [caisar] Goal P3: Unknown () diff --git a/tests/interpretation_dataset.t b/tests/interpretation_dataset.t new file mode 100644 index 0000000000000000000000000000000000000000..9996b4544ebb90ec070a0730f800f908427f9a1e --- /dev/null +++ b/tests/interpretation_dataset.t @@ -0,0 +1,56 @@ +Test interpret on dataset + $ cat - > dataset.csv << EOF + > 1,0.0,1.0,0.784313725,0.019607843,0.776470588 + > 0,1.0,0.0,0.019607843,0.776470588,0.784313725 + > EOF + + $ chmod u+x bin/Marabou + + $ bin/Marabou --version + 1.0.+ + + $ PATH=$(pwd)/bin:$PATH + + $ caisar verify -L . --format whyml --prover Marabou - 2>&1 <<EOF | ./filter_tmpdir.sh + > theory T + > use ieee_float.Float64 + > use bool.Bool + > use int.Int + > use interpretation.Vector + > use interpretation.NeuralNetwork + > use interpretation.Dataset + > + > type image = vector t + > type label_ = int + > + > predicate valid_image (i: image) = + > forall v: index. valid_index i v -> (0.0: t) .<= i[v] .<= (1.0: t) + > + > predicate valid_label (l: label_) = 0 <= l <= 2 + > + > predicate advises (n: nn) (i: image) (l: label_) = + > valid_label l -> + > forall j: int. valid_label j -> j <> l -> (n@@i)[l] .> (n@@i)[j] + > + > predicate bounded_by_epsilon (i: image) (eps: t) = + > forall v: index. valid_index i v -> .- eps .<= i[v] .<= eps + > + > predicate robust_around (n: nn) (eps: t) (i: image) (l: label_) = + > forall perturbed_image: image. 
+ > has_length perturbed_image (length i) -> + > valid_image perturbed_image -> + > let perturbation = perturbed_image - i in + > bounded_by_epsilon perturbation eps -> + > advises n perturbed_image l + > + > predicate robust (n: nn) (d: dataset image label_) (eps: t) = + > Dataset.forall_ d (robust_around n eps) + > + > goal G: + > let nn = read_neural_network "TestNetwork.nnet" NNet in + > let dataset = read_dataset "dataset.csv" CSV in + > let eps = (0.375:t) in + > robust nn dataset eps + > end + > EOF + [caisar] Goal G: Unknown () diff --git a/tests/marabou.t b/tests/marabou.t index e8aa1fb53a7b7a4e319a08989fbc537d65858cb3..77ba163cb9b56bbb617b773ef0e6d43e230fd398 100644 --- a/tests/marabou.t +++ b/tests/marabou.t @@ -8,28 +8,37 @@ Test verify $ caisar verify -L . --format whyml --prover=Marabou - 2>&1 <<EOF | ./filter_tmpdir.sh > theory T - > use TestNetwork.AsTuple > use ieee_float.Float64 + > use bool.Bool + > use int.Int + > use interpretation.Vector + > use interpretation.NeuralNetwork > - > goal G: forall x1 x2 x3 x4 x5. - > (0.0:t) .< x1 .< (0.5:t) -> - > let (y1,_,_,_,_) = nn_apply x1 x2 x3 x4 x5 in - > (0.0:t) .< y1 .< (0.5:t) + > constant nn: nn = read_neural_network "TestNetwork.nnet" NNet > - > goal H: forall x1 x2 x3 x4 x5. - > (0.0:t) .< x1 .< (0.5:t) /\ (0.5:t) .< x2 .< (1.0:t) -> - > let (y1,y2,_,_,_) = nn_apply x1 x2 x3 x4 x5 in - > ((0.0:t) .< y1 \/ (0.5:t) .< y1) /\ (0.0:t) .< y2 .< (0.5:t) + > goal G: + > forall i: vector t. + > has_length i 5 -> + > (0.0:t) .<= i[0] .<= (0.5:t) -> + > (0.0:t) .< (nn@@i)[0] .< (0.5:t) > - > goal I: forall x1 x2 x3 x4 x5. - > (0.0:t) .< x1 .< (0.5:t) /\ (0.5:t) .< x2 .< (1.0:t) -> - > let (y1,y2,_,_,_) = nn_apply x1 x2 x3 x4 x5 in - > y2 .< y1 \/ y1 .< y2 + > goal H: + > forall i: vector t. + > has_length i 5 -> + > (0.0:t) .<= i[0] .<= (0.5:t) /\ (0.5:t) .<= i[1] .<= (1.0:t) -> + > ((0.0:t) .< (nn@@i)[0] \/ (0.5:t) .< (nn@@i)[0]) /\ (0.0:t) .< (nn@@i)[1] .< (0.5:t) > - > goal J: forall x1 x2 x3 x4 x5. - > ((0.0:t) .< x1 .< (0.5:t) \/ (0.75:t) .< x1 .< (1.0:t)) /\ (0.5:t) .< x2 .< (1.0:t) -> - > let (y1,y2,_,_,_) = nn_apply x1 x2 x3 x4 x5 in - > y2 .< y1 \/ y1 .< y2 + > goal I: + > forall i: vector t. + > has_length i 5 -> + > (0.0:t) .<= i[0] .<= (0.5:t) /\ (0.5:t) .<= i[1] .<= (1.0:t) -> + > (nn@@i)[1] .< (nn@@i)[0] \/ (nn@@i)[0] .< (nn@@i)[1] + > + > goal J: + > forall i: vector t. + > has_length i 5 -> + > ((0.0:t) .<= i[0] .<= (0.5:t) \/ (0.75:t) .<= i[0] .<= (1.0:t)) /\ (0.5:t) .<= i[1] .<= (1.0:t) -> + > (nn@@i)[1] .< (nn@@i)[0] \/ (nn@@i)[0] .< (nn@@i)[1] > end > EOF [caisar] Goal G: Unknown () diff --git a/tests/pyrat.t b/tests/pyrat.t index fa45b4af04c5c377ec4cab0c96b2f622de407906..a59926ecb618a534ee79a8641c9b2308895ea5d7 100644 --- a/tests/pyrat.t +++ b/tests/pyrat.t @@ -8,18 +8,25 @@ Test verify $ caisar verify -L . --format whyml --prover=PyRAT - 2>&1 <<EOF | ./filter_tmpdir.sh > theory T - > use TestNetwork.AsTuple > use ieee_float.Float64 + > use bool.Bool + > use int.Int + > use interpretation.Vector + > use interpretation.NeuralNetwork > - > goal G: forall x1 x2 x3 x4 x5. - > (0.0:t) .< x1 .< (0.5:t) -> - > let (y1,_,_,_,_) = nn_apply x1 x2 x3 x4 x5 in - > (0.0:t) .< y1 .< (0.5:t) + > constant nn: nn = read_neural_network "TestNetwork.nnet" NNet > - > goal H: forall x1 x2 x3 x4 x5. 
- > ((0.0:t) .< x1 .< (0.5:t) \/ (0.375:t) .< x1 .< (0.75:t)) /\ (0.5:t) .< x2 .< (1.0:t) -> - > let (y1,y2,_,_,_) = nn_apply x1 x2 x3 x4 x5 in - > ((0.0:t) .< y1 \/ (0.5:t) .< y1) /\ (0.0:t) .< y2 .< (0.5:t) + > goal G: + > forall i: vector t. + > has_length i 5 -> + > (0.0:t) .<= i[0] .<= (0.5:t) -> + > (0.0:t) .< (nn@@i)[0] .< (0.5:t) + > + > goal H: + > forall i: vector t. + > has_length i 5 -> + > ((0.0:t) .<= i[0] .<= (0.5:t) \/ (0.375:t) .<= i[0] .<= (0.75:t)) /\ (0.5:t) .<= i[1] .<= (1.0:t) -> + > ((0.0:t) .< (nn@@i)[0] \/ (0.5:t) .< (nn@@i)[0]) /\ (0.0:t) .< (nn@@i)[1] .< (0.5:t) > end > EOF [caisar] Goal G: Unknown () diff --git a/tests/pyrat_onnx.t b/tests/pyrat_onnx.t index 919545d027d65cb394160c5f762372024b371dc4..168624e46fc7bc5a52b255cd871d286a4bbe9f1b 100644 --- a/tests/pyrat_onnx.t +++ b/tests/pyrat_onnx.t @@ -8,13 +8,19 @@ Test verify $ caisar verify -L . --format whyml --prover=PyRAT - 2>&1 <<EOF | ./filter_tmpdir.sh > theory T - > use TestNetworkONNX.AsTuple > use ieee_float.Float64 + > use bool.Bool + > use int.Int + > use interpretation.Vector + > use interpretation.NeuralNetwork > - > goal G: forall x1 x2 x3. - > (0.0:t) .< x1 .< (0.5:t) -> - > let (y1,_) = nn_apply x1 x2 x3 in - > (0.0:t) .< y1 .< (0.5:t) + > constant nn: nn = read_neural_network "TestNetworkONNX.onnx" ONNX + > + > goal G: + > forall i: vector t. + > has_length i 3 -> + > (0.0:t) .<= i[0] .<= (0.5:t) -> + > (0.0:t) .< (nn@@i)[0] .< (0.5:t) > end > EOF [caisar] Goal G: Unknown ()
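
Note on the rewriting introduced by `simplify_rel`: the inverse-operator mapping in `make_rt` (e.g. `sub_fp` folded with `Float.( + )`) is intentional, since moving a constant across the relation applies the inverse of the arithmetic operation. Below is a minimal sketch of that arithmetic on plain OCaml floats; it is illustrative only, and the names `op` and `fold_left_constant` are hypothetical, not part of this patch.

(* Illustrative sketch only -- not part of this patch. It mirrors, on plain
   OCaml floats, the constant folding that simplify_rel performs on Why3
   terms: a relation [c1 REL (t OP c2)] is rewritten into [c REL t] by
   applying the inverse of OP to the two constants. *)
type op = Add | Sub | Mul | Div

let fold_left_constant (c1 : float) (o : op) (c2 : float) : float =
  match o with
  | Sub -> c1 +. c2 (* c1 <= t - c2  <->  c1 + c2 <= t *)
  | Add -> c1 -. c2 (* c1 <= t + c2  <->  c1 - c2 <= t *)
  | Mul -> c1 /. c2 (* c1 <= t * c2  <->  c1 / c2 <= t, assuming c2 > 0 *)
  | Div -> c1 *. c2 (* c1 <= t / c2  <->  c1 * c2 <= t, assuming c2 > 0 *)

let () =
  (* e.g. the bound 0.375 <= x -. 0.5 is folded into 0.875 <= x *)
  assert (Float.equal (fold_left_constant 0.375 Sub 0.5) 0.875)

For multiplication and division the sketch assumes a positive constant on the right-hand side; the actual transformation in `simplify_rel.ml` folds the constants in the same way, but operates on `ieee_float.Float64` terms and rebuilds the folded constant through `term_of_float` before recursing.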