Skip to content
Snippets Groups Projects
Commit 45581f16 authored by Michele Alberti's avatar Michele Alberti
Browse files

Merge branch 'feature/michele/interpret-svm' into 'master'

Interpretation of specs involving SVMs

See merge request laiser/caisar!122
parents c05c8169 da751720
No related branches found
No related tags found
No related merge requests found
Showing
with 1128 additions and 142 deletions
......@@ -125,7 +125,7 @@ exec = "saver"
version_switch = "--version 2>&1 | head -n1 && (which saver > /dev/null 2>&1)"
version_regexp = "\\(v[0-9.]+\\)"
version_ok = "v1.0"
command = "%e %{svm} %{dataset} %{abstraction} %{distance} %{epsilon}"
command = "%e %{svm} %{dataset} %{abstraction} %{perturbation} %{perturbation_param}"
driver = "%{config}/drivers/saver.drv"
use_at_auto_level = 1
......
......@@ -34,8 +34,8 @@ transformation "inline_trivial"
transformation "introduce_premises"
transformation "eliminate_builtin"
transformation "simplify_formula"
transformation "simplify_rel"
transformation "vars_on_lhs"
transformation "simplify_relop"
transformation "vars_on_lhs_of_relop"
theory BuiltIn
syntax type int "int"
......
......@@ -34,7 +34,7 @@ transformation "inline_trivial"
transformation "introduce_premises"
transformation "eliminate_builtin"
transformation "simplify_formula"
transformation "simplify_rel"
transformation "simplify_relop"
theory BuiltIn
syntax type int "int"
......
......@@ -32,8 +32,8 @@ transformation "inline_trivial"
transformation "introduce_premises"
transformation "eliminate_builtin"
transformation "simplify_formula"
transformation "simplify_rel"
transformation "vars_on_lhs"
transformation "simplify_relop"
transformation "vars_on_lhs_of_relop"
theory BuiltIn
syntax type int "Int"
......
......@@ -45,9 +45,8 @@ type eps = float [@@deriving yojson, show]
let string_of_eps eps = Float.to_string eps
let term_of_eps env eps =
let th = Env.read_theory env [ "ieee_float" ] "Float64" in
let ty = Ty.ty_app (Theory.ns_find_ts th.th_export [ "t" ]) [] in
Term.t_const (real_constant_of_float eps) ty
let th = Symbols.Float64.create env in
Term.t_const (real_constant_of_float eps) th.ty
type threshold = float [@@deriving yojson, show]
......@@ -270,8 +269,8 @@ let add_output_decl ~id predicate_kind ~outputs ~record th task =
let add_decls ~kind task =
let n, fid, meta =
match kind with
| `In n -> (n, Fmt.str "x%d", Utils.meta_input)
| `Out n -> (n, Fmt.str "y%d", Utils.meta_output)
| `In n -> (n, Fmt.str "x%d", Meta.meta_input)
| `Out n -> (n, Fmt.str "y%d", Meta.meta_output)
in
let id_lls =
List.init n ~f:(fun id ->
......
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(* Copyright (C) 2023 *)
(* CEA (Commissariat à l'énergie atomique et aux énergies *)
(* alternatives) *)
(* *)
(* You can redistribute it and/or modify it under the terms of the GNU *)
(* Lesser General Public License as published by the Free Software *)
(* Foundation, version 2.1. *)
(* *)
(* It is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
(* GNU Lesser General Public License for more details. *)
(* *)
(* See the GNU Lesser General Public License version 2.1 *)
(* for more details (enclosed in the file licenses/LGPLv2.1). *)
(* *)
(**************************************************************************)
module IRE = Interpreter_reduction_engine
module ITypes = Interpreter_types
open Base
let builtin_of_constant env known_map (name, value) =
let decls =
Why3.Ident.Mid.fold_left
(fun acc id ls ->
if String.equal id.Why3.Ident.id_string name then ls :: acc else acc)
[] known_map
in
match decls with
| [] ->
Logging.user_error (fun m ->
m "'%s' is not a declared toplevel constant" name)
| _ :: _ :: _ ->
Logging.user_error (fun m ->
m "'%s' corresponds to multiple declared toplevel constants" name)
| [ { Why3.Decl.d_node = Dparam ls; _ } ] ->
let cst =
match ls.Why3.Term.ls_value with
| None -> (
match value with
| "true" -> Why3.Term.t_true
| "false" -> Why3.Term.t_false
| _ ->
Logging.user_error (fun m ->
m "'%s' expects 'true' or 'false', got '%s' instead" name value))
| Some ty when Why3.Ty.ty_equal ty Why3.Ty.ty_bool -> (
match value with
| "true" -> Why3.Term.t_bool_true
| "false" -> Why3.Term.t_bool_false
| _ ->
Logging.user_error (fun m ->
m "'%s' expects 'true' or 'false', got '%s' instead" name value))
| Some ty
when Why3.Ty.ty_equal ty Why3.Ty.ty_int
|| Why3.Ty.ty_equal ty Why3.Ty.ty_real
|| Why3.Ty.ty_equal ty (Symbols.Float64.create env).ty ->
let lb = Lexing.from_string value in
Why3.Loc.set_file
(Fmt.str "constant '%s' to be bound to '%s'" value name)
lb;
let parsed = Why3.Lexer.parse_term lb in
let cst =
match parsed.term_desc with
| Why3.Ptree.Tconst cst -> cst
| _ ->
Logging.user_error (fun m ->
m "'%s' expects a numerical constant, got '%s' instead" name value)
in
Why3.Term.t_const cst ty
| Some ty when Why3.Ty.ty_equal ty Why3.Ty.ty_str ->
let cst = Why3.Constant.ConstStr value in
Why3.Term.t_const cst ty
| Some ty ->
Logging.not_implemented_yet (fun m ->
m
"'%s' has type '%a' but only toplevel constants of type bool, int, \
real and string can be defined"
name Why3.Pretty.print_ty ty)
in
(ls, fun _ _ _ _ -> IRE.eval_term cst)
| _ ->
Logging.user_error (fun m ->
m "'%s' does not appear to be a declared toplevel constant" name)
let bounded_quant engine vs ~cond : IRE.bounded_quant_result option =
match vs.Why3.Term.vs_ty with
| {
ty_node = Tyapp ({ ts_name = { id_string = "vector"; _ }; _ }, ty :: _);
_;
} -> (
match cond.Why3.Term.t_node with
| Tapp
( { ls_name = { id_string = "has_length"; _ }; _ },
[
({ t_node = Tvar vs1; _ } as _t1);
({ t_node = Tconst (ConstInt n); _ } as _t2);
] ) ->
if not (Why3.Term.vs_equal vs vs1)
then None
else
let n = Why3.Number.to_small_integer n in
let new_quant =
List.init n ~f:(fun _ ->
let preid = Why3.Ident.id_fresh "x" in
Why3.Term.create_vsymbol preid ty)
in
let args = List.map new_quant ~f:(fun vs -> (Why3.Term.t_var vs, ty)) in
let op =
let { ITypes.env; _ } = IRE.user_env engine in
ITypes.Vector (Language.create_vector env n)
in
let substitutions =
[ ITypes.term_of_op ~args engine op (Some vs.vs_ty) ]
in
Some { new_quant; substitutions }
| Tapp ({ ls_name = { id_string = "has_length"; _ }; _ }, _) -> None
| _ ->
Logging.user_error ?loc:vs.vs_name.id_loc (fun m ->
m
"Expecting 'has_length' predicate after universal quantifier on \
vector '%a'"
Why3.Pretty.print_vs vs))
| _ -> None
let declare_language_lsymbols interpreter_env task =
(* Declare [Language] logic symbols only. *)
Why3.Term.Hls.fold
(fun ls _ task ->
(* Add meta corresponding to logic symbol. *)
let task = Language.add_meta_nn task ls in
let task = Language.add_meta_svm task ls in
let task = Language.add_meta_dataset_csv task ls in
(* Add actual logic symbol declaration. *)
let decl = Why3.Decl.create_param_decl ls in
Why3.Task.add_decl task decl)
interpreter_env.ITypes.op_of_ls task
let interpret_task ~cwd ?(definitions = []) env task =
let known_map = Why3.Task.task_known task in
let interpreter_env = ITypes.interpreter_env ~cwd env in
let params =
{
IRE.compute_defs = true;
compute_builtin = true;
compute_def_set = Why3.Term.Sls.empty;
compute_max_quantifier_domain = Int.max_value;
}
in
let builtins = List.map ~f:(builtin_of_constant env known_map) definitions in
let engine =
IRE.create ~bounded_quant ~builtins params env known_map interpreter_env
Interpreter_theory.builtins
in
let g, f = (Why3.Task.task_goal task, Why3.Task.task_goal_fmla task) in
let f = IRE.normalize ~limit:Int.max_value engine Why3.Term.Mvs.empty f in
Logs.debug ~src:Logging.src_interpret_goal (fun m ->
m "Interpreted formula for goal '%a':@.%a@.%a" Why3.Pretty.print_pr g
Why3.Pretty.print_term f ITypes.pp_interpreter_op_hls
interpreter_env.op_of_ls);
let _, task = Why3.Task.task_separate_goal task in
let task = declare_language_lsymbols interpreter_env task in
let task = Why3.Task.(add_prop_decl task Pgoal g f) in
task
......@@ -20,11 +20,9 @@
(* *)
(**************************************************************************)
open Why3
val interpret_task :
cwd:string ->
?def_constants:(string * string) list ->
Env.env ->
Task.task ->
Task.task
?definitions:(string * string) list ->
Why3.Env.env ->
Why3.Task.task ->
Why3.Task.task
......@@ -70,6 +70,11 @@ let user_env x = x.user_env
let v_attr_copy orig v =
match v with Int _ -> v | Real _ -> v | Term t -> Term (t_attr_copy orig t)
let value_term t = Value (Term t)
let value_int i = Value (Int i)
let value_real r = Value (Real r)
let eval_term t = Eval t
let term_of_value v =
let open Number in
match v with
......@@ -106,6 +111,10 @@ let real_of_value v =
| Term { t_node = Tconst c } -> real_of_const c
| Term _ -> raise NotNum
let reconstruct_term () =
(* Force the engine to reconstruct the original term. *)
raise Stdlib.Not_found
(* {2 Builtin symbols} *)
(* all builtin functions *)
......
......@@ -71,19 +71,16 @@ open Why3
type 'a engine
(** abstract type for reduction engines *)
val user_env : 'a engine -> 'a
type params = {
compute_defs : bool;
compute_builtin : bool;
(** When set to true, automatically compute symbols using known definitions.
Otherwise, only symbols in [compute_def_set] will be computed. *)
compute_builtin : bool; (** When set to true, compute builtin functions. *)
compute_def_set : Term.Sls.t;
compute_max_quantifier_domain : int;
(** Maximum domain size for the reduction of bounded quantifications. *)
}
(** Configuration of the engine. . [compute_defs]: if set to true, automatically
compute symbols using known definitions. Otherwise, only symbols in
[compute_def_set] will be computed. . [compute_builtin]: if set to true,
compute builtin functions. . [compute_max_quantifier_domain]: maximum domain
size for the reduction of bounded quantifications *)
(** Configuration of the engine. *)
type value =
| Term of Why3.Term.term (* invariant: is in normal form *)
......@@ -113,6 +110,11 @@ type bounded_quant_result = {
type 'a bounded_quant =
'a engine -> Term.vsymbol -> cond:Term.term -> bounded_quant_result option
val value_term : Term.term -> builtin_value
val value_int : BigInt.t -> builtin_value
val value_real : Number.real_value -> builtin_value
val eval_term : Term.term -> builtin_value
val create :
?bounded_quant:'a bounded_quant ->
?builtins:(Why3.Term.lsymbol * 'a builtin) list ->
......@@ -122,17 +124,23 @@ val create :
'a ->
'a built_in_theories list ->
'a engine
(** [create env known_map] creates a reduction engine with . builtins theories
(int.Int, etc.) extracted from [env] . known declarations from [known_map] .
empty set of rewrite rules *)
(** [create env known_map] creates a reduction engine with: builtins theories
(int.Int, etc.) extracted from [env], known declarations from [known_map],
and empty set of rewrite rules. *)
val user_env : 'a engine -> 'a
val reconstruct_term : unit -> 'a
(** Force the engine to reconstruct the original term. *)
exception NotARewriteRule of string
val add_rule : Term.term -> 'a engine -> 'a engine
(** [add_rule t e] turns [t] into a new rewrite rule and returns the new engine.
raise NotARewriteRule if [t] cannot be seen as a rewrite rule according to
the general rules given above. *)
@raise [NotARewriteRule]
if [t] cannot be seen as a rewrite rule according to the general rules
given above. *)
val normalize :
?step_limit:int ->
......@@ -144,20 +152,25 @@ val normalize :
(** [normalize e sigma t] normalizes the term [t] with respect to the engine [e]
with an initial variable environment [sigma].
parameter [limit] provides a maximum number of steps for execution. When
limit is reached, the partially reduced term is returned. parameter
[step_limit] provides a maximum number of steps on reductions that would
change the term even after reconstruction. *)
@param [limit]
provides a maximum number of steps for execution. When limit is reached,
the partially reduced term is returned.
@param [step_limit]
provides a maximum number of steps on reductions that would change the
term even after reconstruction. *)
open Term
exception NoMatch of (Term.term * Term.term * Term.term option) option
(** [NoMatch (t1, t2, t3)]
exception NoMatch of (term * term * term option) option
(** [NoMatch (t1, t2, t3)] Cannot match [t1] with [t2]. If [t3] exists then [t1]
is already matched with [t3]. *)
Cannot match [t1] with [t2]. If [t3] exists then [t1] is already matched
with [t3]. *)
exception NoMatchpat of (pattern * pattern) option
exception NoMatchpat of (Term.pattern * Term.pattern) option
type substitution = term Mvs.t
type substitution = Term.term Term.Mvs.t
val first_order_matching :
Svs.t -> term list -> term list -> Ty.ty Ty.Mtv.t * substitution
Term.Svs.t ->
Term.term list ->
Term.term list ->
Ty.ty Ty.Mtv.t * substitution
......@@ -20,5 +20,7 @@
(* *)
(**************************************************************************)
val init : unit -> unit
(** Register the transformation. *)
module IRE = Interpreter_reduction_engine
module ITypes = Interpreter_types
val builtins : ITypes.interpreter_env IRE.built_in_theories list
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(* Copyright (C) 2023 *)
(* CEA (Commissariat à l'énergie atomique et aux énergies *)
(* alternatives) *)
(* *)
(* You can redistribute it and/or modify it under the terms of the GNU *)
(* Lesser General Public License as published by the Free Software *)
(* Foundation, version 2.1. *)
(* *)
(* It is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
(* GNU Lesser General Public License for more details. *)
(* *)
(* See the GNU Lesser General Public License version 2.1 *)
(* for more details (enclosed in the file licenses/LGPLv2.1). *)
(* *)
(**************************************************************************)
module IRE = Interpreter_reduction_engine
open Base
type nn =
| NNet
| ONNX
[@@deriving show]
type model =
| NN of Why3.Term.lsymbol * nn
[@printer
fun fmt (ls, nn) ->
Fmt.pf fmt "(%a, %a)" pp_nn nn
Fmt.(option Language.pp_nn)
(Language.lookup_nn ls)]
| SVM of Why3.Term.lsymbol
[@printer
fun fmt ls ->
Fmt.pf fmt "%a" Fmt.(option Language.pp_svm) (Language.lookup_svm ls)]
[@@deriving show]
type dataset = DS_csv of Csv.t [@printer fun fmt _ -> Fmt.pf fmt "<csv>"]
[@@deriving show]
type data = D_csv of string list [@@deriving show]
type interpreter_op =
| Model of model
| Dataset of Why3.Term.lsymbol * dataset
[@printer
fun fmt (ls, ds) ->
Fmt.pf fmt "(%a, %a)"
Fmt.(option Language.pp_dataset)
(Language.lookup_dataset_csv ls)
pp_dataset ds]
| Data of data
| Vector of Why3.Term.lsymbol
[@printer
fun fmt v -> Fmt.pf fmt "%a" Fmt.(option int) (Language.lookup_vector v)]
[@@deriving show]
type interpreter_op_hls =
(interpreter_op Why3.Term.Hls.t
[@printer
fun fmt hls ->
Why3.(
Pp.print_iter2 Term.Hls.iter Pp.newline Pp.comma Pretty.print_ls
pp_interpreter_op fmt hls)])
[@@deriving show]
type ls_htbl_interpreter_op = (interpreter_op, Why3.Term.lsymbol) Base.Hashtbl.t
type interpreter_env = {
op_of_ls : interpreter_op_hls;
ls_of_op : ls_htbl_interpreter_op;
env : Why3.Env.env;
cwd : string;
}
let ls_of_op engine interpreter_op ty_args ty =
let interpreter_env = IRE.user_env engine in
Hashtbl.find_or_add interpreter_env.ls_of_op interpreter_op
~default:(fun () ->
let id = Why3.Ident.id_fresh "interpreter_op" in
let ls =
match interpreter_op with
| Model (NN (m, _) | SVM m) -> m
| Vector v -> v
| Dataset (d, _) -> d
| _ -> Why3.Term.create_lsymbol id ty_args ty
in
Hashtbl.Poly.add_exn interpreter_env.ls_of_op ~key:interpreter_op ~data:ls;
Why3.Term.Hls.add interpreter_env.op_of_ls ls interpreter_op;
ls)
let op_of_ls engine ls =
let interpreter_env = IRE.user_env engine in
Why3.Term.Hls.find interpreter_env.op_of_ls ls
let term_of_op ?(args = []) engine interpreter_op ty =
let t_args, ty_args = List.unzip args in
Why3.Term.t_app_infer (ls_of_op engine interpreter_op ty_args ty) t_args
let interpreter_env ~cwd env =
{
ls_of_op = Hashtbl.Poly.create ();
op_of_ls = Why3.Term.Hls.create 10;
env;
cwd;
}
......@@ -20,76 +20,45 @@
(* *)
(**************************************************************************)
open Why3
open Base
module IRE = Interpreter_reduction_engine
let make_rt env =
let th = Env.read_theory env [ "ieee_float" ] "Float64" in
let t = Theory.ns_find_ts th.th_export [ "t" ] in
let le_float = Theory.ns_find_ls th.th_export [ "le" ] in
let lt_float = Theory.ns_find_ls th.th_export [ "lt" ] in
let ge_float = Theory.ns_find_ls th.th_export [ "ge" ] in
let gt_float = Theory.ns_find_ls th.th_export [ "gt" ] in
let th = Env.read_theory env [ "real" ] "Real" in
let le_real = Theory.ns_find_ls th.th_export [ Ident.op_infix "<=" ] in
let lt_real = Theory.ns_find_ls th.th_export [ Ident.op_infix "<" ] in
let ge_real = Theory.ns_find_ls th.th_export [ Ident.op_infix ">=" ] in
let gt_real = Theory.ns_find_ls th.th_export [ Ident.op_infix ">" ] in
let rec rt t =
let t = Term.t_map rt t in
match t.t_node with
| Tapp
( ls,
[
({ t_node = Tconst _; _ } as const);
({ t_node = Tapp (_, []); _ } as var);
] ) ->
let tt = [ var; const ] in
let ls_rel =
if Term.ls_equal ls le_float
then ge_float
else if Term.ls_equal ls ge_float
then le_float
else if Term.ls_equal ls lt_float
then gt_float
else if Term.ls_equal ls gt_float
then lt_float
else ls
in
let ls_rel =
if Term.ls_equal ls le_real
then ge_real
else if Term.ls_equal ls ge_real
then le_real
else if Term.ls_equal ls lt_real
then gt_real
else if Term.ls_equal ls gt_real
then lt_real
else ls_rel
in
if Term.ls_equal ls_rel ls then t else Term.ps_app ls_rel tt
| _ -> t
in
let task =
List.fold
[ le_float; lt_float; ge_float; gt_float ]
~init:(Task.add_ty_decl None t) ~f:Task.add_param_decl
in
let task =
List.fold
[ le_real; lt_real; ge_real; gt_real ]
~init:(Task.add_ty_decl task Ty.ts_real)
~f:Task.add_param_decl
in
(rt, task)
type nn =
| NNet
| ONNX
[@@deriving show]
let vars_on_lhs env =
let rt, task = make_rt env in
Trans.rewrite rt task
type model =
| NN of Why3.Term.lsymbol * nn
| SVM of Why3.Term.lsymbol
[@@deriving show]
let init () =
Trans.register_env_transform
~desc:
"Transformation for provers that need variables on the left-hand-side of \
logic symbols."
"vars_on_lhs" vars_on_lhs
type dataset = DS_csv of Csv.t [@@deriving show]
type data = D_csv of string list [@@deriving show]
type interpreter_op =
| Model of model
| Dataset of Why3.Term.lsymbol * dataset
| Data of data
| Vector of Why3.Term.lsymbol
[@@deriving show]
type interpreter_op_hls = interpreter_op Why3.Term.Hls.t [@@deriving show]
type ls_htbl_interpreter_op = (interpreter_op, Why3.Term.lsymbol) Base.Hashtbl.t
type interpreter_env = private {
op_of_ls : interpreter_op_hls;
ls_of_op : ls_htbl_interpreter_op;
env : Why3.Env.env;
cwd : string;
}
val op_of_ls : interpreter_env IRE.engine -> Why3.Term.lsymbol -> interpreter_op
val term_of_op :
?args:(Why3.Term.term * Why3.Ty.ty) list ->
interpreter_env IRE.engine ->
interpreter_op ->
Why3.Ty.ty option ->
Why3.Term.term
val interpreter_env : cwd:string -> Why3.Env.env -> interpreter_env
......@@ -160,13 +160,15 @@ let register_ovo_support () =
Env.register_format ~desc:"OVO format" Pmodule.mlw_language "OVO" [ "ovo" ]
(fun env _ filename _ -> ovo_parser env filename)
(* -- Vectors *)
(* -------------------------------------------------------------------------- *)
(* --- Vectors *)
(* -------------------------------------------------------------------------- *)
let vectors = Term.Hls.create 10
let ty_float64_t env =
let th = Env.read_theory env [ "ieee_float" ] "Float64" in
Ty.ty_app (Theory.ns_find_ts th.th_export [ "t" ]) []
let th = Symbols.Float64.create env in
th.ty
let ty_vector env ty_elt =
let th = Env.read_theory env [ "caisar"; "types" ] "Vector" in
......@@ -188,17 +190,24 @@ let create_vector =
let lookup_vector = Term.Hls.find_opt vectors
let mem_vector = Term.Hls.mem vectors
(* -- Neural Networks *)
(* -------------------------------------------------------------------------- *)
(* --- Neural Networks *)
(* -------------------------------------------------------------------------- *)
type nn = {
nn_nb_inputs : int;
nn_nb_outputs : int;
nn_ty_elt : Ty.ty; [@printer fun fmt ty -> Fmt.pf fmt "%a" Pretty.print_ty ty]
nn_filename : string;
nn_nier : Onnx.G.t option; [@opaque]
nn_format : nn_format;
}
[@@deriving show]
and nn_format =
| NNet
| ONNX of Onnx.G.t option [@printer fun fmt _ -> Fmt.pf fmt "<nier>"]
[@@deriving show]
let nets = Term.Hls.create 10
let fresh_nn_ls env name =
......@@ -213,63 +222,146 @@ let fresh_nn_ls env name =
let id = Ident.id_fresh name in
Term.create_fsymbol id [] ty_model
let create_nn_nnet =
let create_nn_nnet env filename =
let model = Nnet.parse ~permissive:true filename in
match model with
| Error s -> Loc.errorm "%s" s
| Ok { n_inputs; n_outputs; _ } ->
{
nn_nb_inputs = n_inputs;
nn_nb_outputs = n_outputs;
nn_ty_elt = ty_float64_t env;
nn_filename = filename;
nn_format = NNet;
}
let create_nn_onnx env filename =
let model = Onnx.parse filename in
match model with
| Error s -> Loc.errorm "%s" s
| Ok { n_inputs; n_outputs; nier } ->
let nier =
match nier with
| Error msg ->
Logs.warn (fun m ->
m "Cannot build network intermediate representation:@ %s" msg);
None
| Ok nier -> Some nier
in
{
nn_nb_inputs = n_inputs;
nn_nb_outputs = n_outputs;
nn_ty_elt = ty_float64_t env;
nn_filename = filename;
nn_format = ONNX nier;
}
let create_nn =
Env.Wenv.memoize 13 (fun env ->
let h = Hashtbl.create (module String) in
let ty_elt = ty_float64_t env in
Hashtbl.findi_or_add h ~default:(fun filename ->
let ls = fresh_nn_ls env "nnet_nn" in
let nn =
let model = Nnet.parse ~permissive:true filename in
match model with
| Error s -> Loc.errorm "%s" s
| Ok { n_inputs; n_outputs; _ } ->
{
nn_nb_inputs = n_inputs;
nn_nb_outputs = n_outputs;
nn_ty_elt = ty_elt;
nn_filename = filename;
nn_nier = None;
}
let h = Hashtbl.Poly.create () in
Hashtbl.Poly.findi_or_add h ~default:(fun format ->
let format_nn, create_nn =
match format with
| `NNet -> ("nnet", create_nn_nnet)
| `ONNX -> ("onnx", create_nn_onnx)
in
Term.Hls.add nets ls nn;
ls))
let name_nn = "nn_" ^ format_nn in
let h = Hashtbl.create (module String) in
Hashtbl.findi_or_add h ~default:(fun filename ->
let ls = fresh_nn_ls env name_nn in
let nn = create_nn env filename in
Term.Hls.add nets ls nn;
ls)))
let create_nn_onnx =
Env.Wenv.memoize 13 (fun env ->
let lookup_nn = Term.Hls.find_opt nets
let mem_nn = Term.Hls.mem nets
let iter_nn f = Term.Hls.iter f nets
let add_meta_nn task ls =
match lookup_nn ls with
| None -> task
| Some { nn_filename; _ } ->
Task.add_meta task Meta.meta_nn_filename [ MAstr nn_filename ]
(* -------------------------------------------------------------------------- *)
(* --- Support Vector Machines (SVM) *)
(* -------------------------------------------------------------------------- *)
type svm = {
svm_nb_inputs : int;
svm_nb_outputs : int;
svm_abstraction : svm_abstraction;
svm_filename : string;
}
[@@deriving show]
and svm_abstraction =
| Interval
| Raf
| Hybrid
[@@deriving show]
let string_of_svm_abstraction = function
| Interval -> "interval"
| Raf -> "raf"
| Hybrid -> "hybrid"
let svm_abstraction_of_string s =
match String.lowercase s with
| "interval" -> Some Interval
| "raf" -> Some Raf
| "hybrid" -> Some Hybrid
| _ -> None
let svms = Term.Hls.create 10
let fresh_svm_ls env name =
let ty_kind =
let th = Env.read_theory env [ "caisar"; "model" ] "SVM" in
Ty.ty_app (Theory.ns_find_ts th.th_export [ "svm" ]) []
in
let ty_model =
let th = Env.read_theory env [ "caisar"; "model" ] "Model" in
Ty.ty_app (Theory.ns_find_ts th.th_export [ "model" ]) [ ty_kind ]
in
let id = Ident.id_fresh name in
Term.create_fsymbol id [] ty_model
let create_svm =
Env.Wenv.memoize 13 (fun env ?(abstraction = Hybrid) ->
let h = Hashtbl.create (module String) in
let ty_elt = ty_float64_t env in
Hashtbl.findi_or_add h ~default:(fun filename ->
let ls = fresh_nn_ls env "onnx_nn" in
let onnx =
let model = Onnx.parse filename in
let name = "svm_ovo_" ^ string_of_svm_abstraction abstraction in
let ls = fresh_svm_ls env name in
let svm =
let model = Ovo.parse filename in
match model with
| Error s -> Loc.errorm "%s" s
| Ok { n_inputs; n_outputs; nier } ->
let nier =
match nier with
| Error msg ->
Logs.warn (fun m ->
m "Cannot build network intermediate representation:@ %s" msg);
None
| Ok nier -> Some nier
in
| Ok { n_inputs; n_outputs } ->
{
nn_nb_inputs = n_inputs;
nn_nb_outputs = n_outputs;
nn_ty_elt = ty_elt;
nn_filename = filename;
nn_nier = nier;
svm_nb_inputs = n_inputs;
svm_nb_outputs = n_outputs;
svm_abstraction = abstraction;
svm_filename = filename;
}
in
Term.Hls.add nets ls onnx;
Term.Hls.add svms ls svm;
ls))
let lookup_nn = Term.Hls.find_opt nets
let mem_nn = Term.Hls.mem nets
let iter_nn f = Term.Hls.iter f nets
let lookup_svm = Term.Hls.find_opt svms
let mem_svm = Term.Hls.mem svms
let iter_svm f = Term.Hls.iter f svms
(* -- Datasets *)
let add_meta_svm task ls =
match lookup_svm ls with
| None -> task
| Some { svm_filename; svm_abstraction; _ } ->
Task.add_meta task Meta.meta_svm_filename
[ MAstr svm_filename; MAstr (string_of_svm_abstraction svm_abstraction) ]
(* -------------------------------------------------------------------------- *)
(* --- Datasets *)
(* -------------------------------------------------------------------------- *)
type dataset = CSV of string [@@deriving show]
......@@ -294,3 +386,9 @@ let create_dataset_csv =
let lookup_dataset_csv = Term.Hls.find_opt datasets
let mem_dataset_csv = Term.Hls.mem datasets
let add_meta_dataset_csv task ls =
match lookup_dataset_csv ls with
| None -> task
| Some (CSV filename) ->
Task.add_meta task Meta.meta_dataset_filename [ MAstr filename ]
......@@ -76,15 +76,47 @@ type nn = private {
nn_nb_outputs : int;
nn_ty_elt : Ty.ty;
nn_filename : string;
nn_nier : Onnx.G.t option;
nn_format : nn_format;
}
[@@deriving show]
val create_nn_nnet : Env.env -> string -> Term.lsymbol
val create_nn_onnx : Env.env -> string -> Term.lsymbol
and nn_format =
| NNet
| ONNX of Onnx.G.t option
[@@deriving show]
val create_nn : Env.env -> [ `NNet | `ONNX ] -> string -> Term.lsymbol
val lookup_nn : Term.lsymbol -> nn option
val mem_nn : Term.lsymbol -> bool
val iter_nn : (Term.lsymbol -> nn -> unit) -> unit
val add_meta_nn : Task.task -> Term.lsymbol -> Task.task
(** {2 Support Vector Machines (SVM)} *)
type svm = private {
svm_nb_inputs : int;
svm_nb_outputs : int;
svm_abstraction : svm_abstraction;
svm_filename : string;
}
[@@deriving show]
and svm_abstraction =
| Interval
| Raf
| Hybrid
[@@deriving show]
val string_of_svm_abstraction : svm_abstraction -> string
val svm_abstraction_of_string : string -> svm_abstraction option
val create_svm :
Env.env -> ?abstraction:svm_abstraction -> string -> Term.lsymbol
val lookup_svm : Term.lsymbol -> svm option
val mem_svm : Term.lsymbol -> bool
val iter_svm : (Term.lsymbol -> svm -> unit) -> unit
val add_meta_svm : Task.task -> Term.lsymbol -> Task.task
(** {2 Datasets} *)
......@@ -93,3 +125,4 @@ type dataset = private CSV of string [@@deriving show]
val create_dataset_csv : Env.env -> string -> Term.lsymbol
val lookup_dataset_csv : Term.lsymbol -> dataset option
val mem_dataset_csv : Term.lsymbol -> bool
val add_meta_dataset_csv : Task.task -> Term.lsymbol -> Task.task
......@@ -26,15 +26,17 @@ open Cmdliner
let caisar = "caisar"
let () =
Simplify_rel.init ();
Vars_on_lhs.init ()
Relop.register_simplify_relop ();
Relop.register_vars_on_lhs_of_relop ()
let () =
Pyrat.init ();
Marabou.init ();
Vnnlib.init ()
(* -- Logs *)
(* -------------------------------------------------------------------------- *)
(* --- Logs *)
(* -------------------------------------------------------------------------- *)
let log_tags =
let all_tags = Logging.all_srcs () in
......@@ -58,7 +60,9 @@ let setup_logs =
const Logging.setup $ Fmt_cli.style_renderer () $ Logs_cli.level ()
$ log_tags)
(* -- Commands *)
(* -------------------------------------------------------------------------- *)
(* --- Commands *)
(* -------------------------------------------------------------------------- *)
let config detect () =
if detect
......@@ -89,7 +93,7 @@ let memlimit_of_string s =
| [ v ], ([] | [ "M" ] | [ "MB" ]) -> Int.of_string v
| [ v ], ([ "G" ] | [ "GB" ]) -> Int.of_string v * 1000
| [ v ], ([ "T" ] | [ "TB" ]) -> Int.of_string v * 1000000
| _ -> invalid_arg "Unrecognized memory limit"
| _ -> Logging.user_error (fun m -> m "Unrecognized memory limit")
let timelimit_of_string s =
let s = String.strip s in
......@@ -105,7 +109,7 @@ let timelimit_of_string s =
| [ v ], ([] | [ "s" ]) -> Int.of_string v
| [ v ], [ "m" ] -> Int.of_string v * 60
| [ v ], [ "h" ] -> Int.of_string v * 3600
| _ -> invalid_arg "Unrecognized time limit"
| _ -> Logging.user_error (fun m -> m "Unrecognized time limit")
let log_theory_answer =
Why3.Wstdlib.Mstr.iter (fun theory_name task_answers ->
......@@ -123,14 +127,14 @@ let log_theory_answer =
additional_info)))
let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern
?def_constants ?theories ?goals ?onnx_out_dir files =
?definitions ?theories ?goals ?onnx_out_dir files =
let memlimit = Option.map memlimit ~f:memlimit_of_string in
let timelimit = Option.map timelimit ~f:timelimit_of_string in
let theory_answers =
List.map files
~f:
(Verification.verify ?format ~loadpath ?memlimit ?timelimit ?dataset
prover ?prover_altern ?def_constants ?theories ?goals ?onnx_out_dir)
prover ?prover_altern ?definitions ?theories ?goals ?onnx_out_dir)
in
List.iter theory_answers ~f:log_theory_answer;
theory_answers
......@@ -181,14 +185,16 @@ let verify_json ?memlimit ?timelimit ?outfile json =
let infile = Result.ok_or_failwith (Verification.File.of_json_input jin) in
let verification_results =
verify ~loadpath:[] ?memlimit ?timelimit ~dataset:jin.property.dataset
jin.prover ~def_constants:[] [ infile ]
jin.prover ~definitions:[] [ infile ]
in
match verification_results with
| [] -> assert false (* We always build one theory from the provided JSON. *)
| [ verification_result ] ->
Option.iter outfile
~f:(record_verification_result jin.id verification_result)
| _ -> failwith "Unexpected more than one theory from a JSON file"
| _ ->
Logging.user_error (fun m ->
m "Unexpected more than one theory from a JSON file")
let verify_xgboost ?memlimit ?timelimit xgboost dataset prover =
let memlimit = Option.map memlimit ~f:memlimit_of_string in
......@@ -199,7 +205,9 @@ let exec_cmd cmdname cmd =
Logs.debug (fun m -> m "Execution of command '%s'" cmdname);
cmd ()
(* -- Command line. *)
(* -------------------------------------------------------------------------- *)
(* --- Command line *)
(* -------------------------------------------------------------------------- *)
let memlimit =
let doc =
......@@ -263,8 +271,8 @@ let verify_cmd =
& opt (some string) None
& info [ "onnx-out-dir" ] ~doc ~docv:"DIRECTORY")
in
let define_constants =
let doc = "Define a declared constant with the given value." in
let definitions =
let doc = "Define a toplevel constant declaration with the given value." in
Arg.(
value
& opt_all (pair ~sep:':' string string) []
......@@ -302,17 +310,17 @@ let verify_cmd =
in
let verify_term =
let verify format loadpath memlimit timelimit prover prover_altern dataset
def_constants theories goals onnx_out_dir files () =
definitions theories goals onnx_out_dir files () =
ignore
(verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover
?prover_altern ~def_constants ~theories ~goals ?onnx_out_dir files)
?prover_altern ~definitions ~theories ~goals ?onnx_out_dir files)
in
Term.(
const (fun _ -> exec_cmd cmdname)
$ setup_logs
$ (const verify $ format $ loadpath $ memlimit $ timelimit $ prover
$ prover_altern $ dataset $ define_constants $ theories $ goals
$ onnx_out_dir $ files))
$ prover_altern $ dataset $ definitions $ theories $ goals $ onnx_out_dir
$ files))
in
Cmd.v info verify_term
......
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(* Copyright (C) 2023 *)
(* CEA (Commissariat à l'énergie atomique et aux énergies *)
(* alternatives) *)
(* *)
(* You can redistribute it and/or modify it under the terms of the GNU *)
(* Lesser General Public License as published by the Free Software *)
(* Foundation, version 2.1. *)
(* *)
(* It is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
(* GNU Lesser General Public License for more details. *)
(* *)
(* See the GNU Lesser General Public License version 2.1 *)
(* for more details (enclosed in the file licenses/LGPLv2.1). *)
(* *)
(**************************************************************************)
let meta_input =
Why3.Theory.(
register_meta "caisar_input"
~desc:"Indicates an input position among the inputs of the neural network"
[ MTlsymbol; MTint ])
let meta_output =
Why3.Theory.(
register_meta "caisar_output"
~desc:
"Indicates an output position among the outputs of the neural network"
[ MTlsymbol; MTint ])
let meta_nn_filename =
Why3.Theory.(
register_meta_excl "caisar_nnet_or_onnx"
~desc:"Indicates the filename of the neural network" [ MTstring ])
let meta_svm_filename =
Why3.Theory.(
register_meta_excl "caisar_svm"
~desc:
"Indicates the filename and abstraction of the support vector machine"
[ MTstring; MTstring ])
let meta_dataset_filename =
Why3.Theory.(
register_meta_excl "caisar_dataset"
~desc:"Indicates the filename of the dataset" [ MTstring ])
......@@ -20,5 +20,17 @@
(* *)
(**************************************************************************)
val init : unit -> unit
(** Register the transformation. *)
val meta_input : Why3.Theory.meta
(** Indicates an input position among the inputs of the neural network. *)
val meta_output : Why3.Theory.meta
(** Indicates an output position among the outputs of the neural network. *)
val meta_nn_filename : Why3.Theory.meta
(** The filename of the neural network. *)
val meta_svm_filename : Why3.Theory.meta
(** The filename and abstraction of the support vector machine. *)
val meta_dataset_filename : Why3.Theory.meta
(** The filename of the dataset. *)
......@@ -169,11 +169,11 @@ let rec print_tdecl info fmt task =
print_tdecl info fmt task_prev;
match task_decl.Theory.td_node with
| Use _ | Clone _ -> ()
| Meta (meta, l) when Theory.meta_equal meta Utils.meta_input -> (
| Meta (meta, l) when Theory.meta_equal meta Meta.meta_input -> (
match l with
| [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "x%i" i)
| _ -> assert false)
| Meta (meta, l) when Theory.meta_equal meta Utils.meta_output -> (
| Meta (meta, l) when Theory.meta_equal meta Meta.meta_output -> (
match l with
| [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "y%i" i)
| _ -> assert false)
......@@ -182,20 +182,12 @@ let rec print_tdecl info fmt task =
let print_task args ?old:_ fmt task =
let ls_rel_real =
let th = Env.read_theory args.Printer.env [ "real" ] "Real" in
let le = Theory.ns_find_ls th.th_export [ Ident.op_infix "<=" ] in
let lt = Theory.ns_find_ls th.th_export [ Ident.op_infix "<" ] in
let ge = Theory.ns_find_ls th.th_export [ Ident.op_infix ">=" ] in
let gt = Theory.ns_find_ls th.th_export [ Ident.op_infix ">" ] in
{ le; ge; lt; gt }
let th = Symbols.Real.create args.Printer.env in
{ le = th.le; ge = th.ge; lt = th.lt; gt = th.gt }
in
let ls_rel_float =
let th = Env.read_theory args.Printer.env [ "ieee_float" ] "Float64" in
let le = Theory.ns_find_ls th.th_export [ "le" ] in
let lt = Theory.ns_find_ls th.th_export [ "lt" ] in
let ge = Theory.ns_find_ls th.th_export [ "ge" ] in
let gt = Theory.ns_find_ls th.th_export [ "gt" ] in
{ le; ge; lt; gt }
let th = Symbols.Float64.create args.Printer.env in
{ le = th.le; ge = th.ge; lt = th.lt; gt = th.gt }
in
let info =
{
......
......@@ -155,11 +155,11 @@ let rec print_tdecl info fmt task =
print_tdecl info fmt task_prev;
match task_decl.Theory.td_node with
| Use _ | Clone _ -> ()
| Meta (meta, l) when Theory.meta_equal meta Utils.meta_input -> (
| Meta (meta, l) when Theory.meta_equal meta Meta.meta_input -> (
match l with
| [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "x%i" i)
| _ -> assert false)
| Meta (meta, l) when Theory.meta_equal meta Utils.meta_output -> (
| Meta (meta, l) when Theory.meta_equal meta Meta.meta_output -> (
match l with
| [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "y%i" i)
| _ -> assert false)
......@@ -168,20 +168,12 @@ let rec print_tdecl info fmt task =
let print_task args ?old:_ fmt task =
let ls_rel_real =
let th = Env.read_theory args.Printer.env [ "real" ] "Real" in
let le = Theory.ns_find_ls th.th_export [ Ident.op_infix "<=" ] in
let lt = Theory.ns_find_ls th.th_export [ Ident.op_infix "<" ] in
let ge = Theory.ns_find_ls th.th_export [ Ident.op_infix ">=" ] in
let gt = Theory.ns_find_ls th.th_export [ Ident.op_infix ">" ] in
{ le; ge; lt; gt }
let th = Symbols.Real.create args.Printer.env in
{ le = th.le; ge = th.ge; lt = th.lt; gt = th.gt }
in
let ls_rel_float =
let th = Env.read_theory args.Printer.env [ "ieee_float" ] "Float64" in
let le = Theory.ns_find_ls th.th_export [ "le" ] in
let lt = Theory.ns_find_ls th.th_export [ "lt" ] in
let ge = Theory.ns_find_ls th.th_export [ "ge" ] in
let gt = Theory.ns_find_ls th.th_export [ "gt" ] in
{ le; ge; lt; gt }
let th = Symbols.Float64.create args.Printer.env in
{ le = th.le; ge = th.ge; lt = th.lt; gt = th.gt }
in
let info =
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment