diff --git a/lib/xgboost/bin/parse_xgboost.ml b/lib/xgboost/bin/parse_xgboost.ml index 5533b8e5c915df021bc59f479c187a856ccb3e01..4a1b995e3e1a9b40df21cd78844e863193b5788d 100644 --- a/lib/xgboost/bin/parse_xgboost.ml +++ b/lib/xgboost/bin/parse_xgboost.ml @@ -4,8 +4,13 @@ let predict xg filename = let inputs = Caisar_xgboost.Input.of_filename xg filename in List.iter (fun features -> - let sum = Caisar_xgboost.Predict.predict xg features in - Format.printf "%0.6f\n" sum) + let sum1 = Caisar_xgboost.Predict.predict xg features in + let sum2 = + Caisar_xgboost.Tree.predict (Caisar_xgboost.Tree.convert xg) features + in + if sum1 = sum2 + then Format.printf "%0.6f\n" sum1 + else Format.printf "ERROR: %0.6f <> %0.6f\n" sum1 sum2) inputs let () = diff --git a/lib/xgboost/predict.mli b/lib/xgboost/predict.mli index 2b4f3908d5c4b5db624d9024fd24844e99150764..fc65e8dc23feeb4641fef306fb760d133976b5cc 100644 --- a/lib/xgboost/predict.mli +++ b/lib/xgboost/predict.mli @@ -20,4 +20,5 @@ (* *) (**************************************************************************) +val sigmoid : Float.t -> Float.t val predict : Parser.t -> Input.t -> float diff --git a/lib/xgboost/tree.ml b/lib/xgboost/tree.ml new file mode 100644 index 0000000000000000000000000000000000000000..88451ffa884624653f0f88fc8534401ae0ecdf18 --- /dev/null +++ b/lib/xgboost/tree.ml @@ -0,0 +1,104 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +type tree = + | Split of { + split_indice : int; + split_condition : float; + left : tree; + right : tree; + missing : [ `Left ]; + } + | Leaf of { leaf_value : float } + +type op = + | Identity + | Sigmoid + +type t = { + base_score : float; + trees : tree array; + after_sum : op; +} +(** the value is [op(base_score + sum(tree))] *) + +let predict t input = + let rec aux input = function + | Split s -> ( + match Input.get input s.split_indice with + | None -> + aux input s.left (* TODO: check if missing can be on the right *) + | Some v when v < s.split_condition -> aux input s.left + | _ -> aux input s.right) + | Leaf l -> l.leaf_value + in + let sum = + Array.fold_left (fun acc t -> acc +. aux input t) t.base_score t.trees + in + match t.after_sum with Identity -> sum | Sigmoid -> Predict.sigmoid sum + +let convert_tree (t : Parser.tree) : tree = + let rec aux node = + assert (-1 <= t.left_children.(node)); + if t.left_children.(node) = -1 + then Leaf { leaf_value = t.split_conditions.(node) } + else + Split + { + split_indice = t.split_indices.(node); + split_condition = t.split_conditions.(node); + left = aux t.left_children.(node); + right = aux t.right_children.(node); + missing = `Left; + } + in + aux 0 + +let convert_trees (t : Parser.t) (gb : Parser.gbtree) : t = + let base_score = + let base_score = Float.of_string t.learner.learner_model_param.base_score in + match t.learner.objective with + | Parser.Reg_squarederror _ -> base_score + | Parser.Reg_pseudohubererror _ -> invalid_arg "unimplemented" + | Parser.Reg_squaredlogerror _ -> base_score (* ? *) + | Parser.Reg_linear _ -> base_score (* ? *) + | Parser.Binary_logistic _ -> 0. + in + + let trees = Array.map convert_tree gb.trees in + (* From regression_loss.h PredTransform *) + let after_sum = + match t.learner.objective with + | Parser.Reg_squarederror _ -> Identity + | Parser.Reg_pseudohubererror _ -> invalid_arg "unimplemented" + | Parser.Reg_squaredlogerror _ -> Identity + | Parser.Reg_linear _ -> Identity + | Parser.Binary_logistic _ -> Sigmoid + in + (* Format.eprintf "%f -> %f@." sum pred; *) + { base_score; trees; after_sum } + +let convert (t : Parser.t) = + match t.learner.gradient_booster with + | Parser.Gbtree gbtree -> convert_trees t gbtree + | Parser.Gblinear _ -> assert false + | Parser.Dart _ -> assert false diff --git a/lib/xgboost/tree.mli b/lib/xgboost/tree.mli new file mode 100644 index 0000000000000000000000000000000000000000..7a07b8ee35de5d124da9ad5c562c3c9f0280c232 --- /dev/null +++ b/lib/xgboost/tree.mli @@ -0,0 +1,45 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +type tree = + | Split of { + split_indice : int; + split_condition : float; + left : tree; + right : tree; + missing : [ `Left ]; + } + | Leaf of { leaf_value : float } + +type op = + | Identity + | Sigmoid + +type t = { + base_score : float; + trees : tree array; + after_sum : op; +} +(** the value is [op(base_score + sum(tree))] *) + +val convert : Parser.t -> t +val predict : t -> Input.t -> float diff --git a/src/convert_xgboost.ml b/src/convert_xgboost.ml new file mode 100644 index 0000000000000000000000000000000000000000..c5f35bbfd537b7d47ab21b5d41a25d7de471d896 --- /dev/null +++ b/src/convert_xgboost.ml @@ -0,0 +1,145 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +open Base + +type var = { + value : Why3.Term.lsymbol; (** real *) + missing : Why3.Term.lsymbol; (** prop *) +} + +let convert_model env (model : Caisar_xgboost.Parser.t) task = + let th_real = Why3.Env.read_theory env [ "real" ] "Real" in + let task = Why3.Task.use_export task th_real in + let ls_lt = + Why3.Theory.(ns_find_ls th_real.th_export [ Why3.Ident.op_infix "<" ]) + in + let ls_add = + Why3.Theory.(ns_find_ls th_real.th_export [ Why3.Ident.op_infix "+" ]) + in + let tree = Caisar_xgboost.Tree.convert model in + let variables = + Array.init (Int.of_string model.learner.learner_model_param.num_feature) + ~f:(fun i -> + let name = + if i < Array.length model.learner.feature_names + then model.learner.feature_names.(i) + else Fmt.str "f%i" i + in + let id = Why3.Ident.id_fresh name in + let value = Why3.Term.create_fsymbol id [] Why3.Ty.ty_real in + let id = Why3.Ident.id_fresh ("missing_" ^ name) in + let missing = Why3.Term.create_psymbol id [] in + { value; missing }) + in + let task = + Array.fold variables ~init:task ~f:(fun task { value; missing } -> + let task = Why3.Task.add_param_decl task value in + let task = Why3.Task.add_param_decl task missing in + task) + in + let trees = + Array.mapi tree.trees ~f:(fun i tree -> + let id = Why3.Ident.id_fresh (Fmt.str "tree%i" i) in + let ls = Why3.Term.create_fsymbol id [] Why3.Ty.ty_real in + (ls, tree)) + in + let rec term_of_tree : Caisar_xgboost.Tree.tree -> Why3.Term.term = function + | Leaf { leaf_value } -> + Why3.Term.t_const + (Dataset.real_constant_of_float leaf_value) + Why3.Ty.ty_real + | Split { split_indice; split_condition; left; right; missing = `Left } -> + let var = + Why3.Term.fs_app variables.(split_indice).value [] Why3.Ty.ty_real + in + let missing = Why3.Term.ps_app variables.(split_indice).missing [] in + let value = + Why3.Term.t_const + (Dataset.real_constant_of_float split_condition) + Why3.Ty.ty_real + in + let cond = Why3.Term.ps_app ls_lt [ var; value ] in + let cond = Why3.Term.t_or missing cond in + let then_ = term_of_tree left in + let else_ = term_of_tree right in + Why3.Term.t_if cond then_ else_ + in + let task = + Array.fold trees ~init:task ~f:(fun task (ls, tree) -> + Why3.Task.add_logic_decl task + [ Why3.Decl.make_ls_defn ls [] (term_of_tree tree) ]) + in + let ls_sum = + let id = Why3.Ident.id_fresh "sum" in + Why3.Term.create_fsymbol id [] Why3.Ty.ty_real + in + let sum = + Array.fold trees + ~init: + (Why3.Term.t_const + (Dataset.real_constant_of_float tree.base_score) + Why3.Ty.ty_real) + ~f:(fun term (ls, _) -> + Why3.Term.fs_app ls_add + [ term; Why3.Term.fs_app ls [] Why3.Ty.ty_real ] + Why3.Ty.ty_real) + in + let task = + Why3.Task.add_logic_decl task [ Why3.Decl.make_ls_defn ls_sum [] sum ] + in + (task, variables) + +let convert_dataset mapping (data : Caisar_xgboost.Input.t) task = + let task = + Array.foldi mapping ~init:task ~f:(fun i task { value; missing } -> + let var = Why3.Term.fs_app value [] Why3.Ty.ty_real in + let missing = Why3.Term.ps_app missing [] in + let t = + match Caisar_xgboost.Input.get data i with + | None -> missing + | Some v -> + let value = + Why3.Term.t_const (Dataset.real_constant_of_float v) Why3.Ty.ty_real + in + Why3.Term.(t_and (t_not missing) (t_equ var value)) + in + let pr = Why3.Decl.create_prsymbol (Why3.Ident.id_fresh "data") in + Why3.Task.add_prop_decl task Paxiom pr t) + in + task + +let verify ?memlimit:_ ?timelimit:_ ~xgboost ~dataset () = + let env, _ = Verification.create_env ~debug:false [] in + let task = None in + let model = + let yojson = Yojson.Safe.from_file xgboost in + match Caisar_xgboost.Parser.of_yojson yojson with + | Error exn -> + Fmt.epr "Error: %s@." exn; + assert false + | Ok ok -> ok + in + let task, mapping = convert_model env model task in + let dataset = Caisar_xgboost.Input.of_filename model dataset in + let task = convert_dataset mapping (List.hd_exn dataset) task in + Why3.Pretty.print_task Caml.Format.std_formatter task diff --git a/src/convert_xgboost.mli b/src/convert_xgboost.mli new file mode 100644 index 0000000000000000000000000000000000000000..e3026599a6c6aa939a9412e4bf1d47e8eb7455b5 --- /dev/null +++ b/src/convert_xgboost.mli @@ -0,0 +1,30 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(* Copyright (C) 2022 *) +(* CEA (Commissariat à l'énergie atomique et aux énergies *) +(* alternatives) *) +(* *) +(* You can redistribute it and/or modify it under the terms of the GNU *) +(* Lesser General Public License as published by the Free Software *) +(* Foundation, version 2.1. *) +(* *) +(* It is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU Lesser General Public License for more details. *) +(* *) +(* See the GNU Lesser General Public License version 2.1 *) +(* for more details (enclosed in the file licenses/LGPLv2.1). *) +(* *) +(**************************************************************************) + +val verify : + ?memlimit:int -> + ?timelimit:int -> + xgboost:string -> + dataset:string -> + prover:Prover.t -> + unit -> + unit diff --git a/src/dataset.mli b/src/dataset.mli index a9fcef615ba5ea28ae15429713c022b4e6a5dffb..13efaec60f43ff152a5177a315fe93d3580f7d79 100644 --- a/src/dataset.mli +++ b/src/dataset.mli @@ -24,6 +24,7 @@ open Why3 type eps [@@deriving yojson, show] +val real_constant_of_float : float -> Constant.constant val string_of_eps : eps -> string val term_of_eps : Env.env -> eps -> Term.term diff --git a/src/dune b/src/dune index 58539f690189404eac7ee90665a3beb13508d1fb..6a71c0945ba7747d2e52fb9b8c977bfac548cd7a 100644 --- a/src/dune +++ b/src/dune @@ -19,7 +19,8 @@ why3 dune-site re - zarith) + zarith + caisar.xgboost) (preprocess (pps ppx_deriving_yojson diff --git a/src/main.ml b/src/main.ml index 77a4933e0692a1e07c800a77cbf25f74970dd5ff..3365059dadb23acfd441885c96146683773cf718 100644 --- a/src/main.ml +++ b/src/main.ml @@ -208,6 +208,11 @@ let interpret format loadpath files = let debug = log_level_is_debug () in List.iter ~f:(Interpretation.interpret ~debug ?format ~loadpath) files +let verify_xgboost ?memlimit ?timelimit xgboost dataset = + let memlimit = Option.map memlimit ~f:memlimit_of_string in + let timelimit = Option.map timelimit ~f:timelimit_of_string in + Convert_xgboost.verify ?memlimit ?timelimit ~xgboost ~dataset () + let exec_cmd cmdname cmd = Logs.debug (fun m -> m "Execution of command '%s'" cmdname); cmd () @@ -377,6 +382,31 @@ let interpret_cmd = in Cmd.v info term +let verify_xgboost_cmd = + let cmdname = "verify-xgboost" in + let info = + let doc = "EXPERIMENTAL: Property verification of xgboost file." in + Cmd.info cmdname ~sdocs:Manpage.s_common_options ~exits:Cmd.Exit.defaults + ~doc + ~man:[ `S Manpage.s_description; `P doc ] + in + let term = + let xgboost = + let doc = "xgboost json model file." in + Arg.(required & pos 0 (some file) None & info [] ~doc ~docv:"FILE") + in + let dataset = + let doc = "dataset file (csv, or svm)." in + Arg.(required & pos 1 (some file) None & info [] ~doc ~docv:"FILE") + in + Term.( + const (fun memlimit timelimit xgboost dataset _ -> + exec_cmd cmdname (fun () -> + verify_xgboost ?memlimit ?timelimit xgboost dataset)) + $ memlimit $ timelimit $ xgboost $ dataset $ setup_logs) + in + Cmd.v info term + let default_info = let doc = "A platform for characterizing the safety and robustness of artificial \ @@ -411,7 +441,13 @@ let () = let () = try Cmd.group ~default:default_cmd default_info - [ config_cmd; verify_cmd; verify_json_cmd; interpret_cmd ] + [ + config_cmd; + verify_cmd; + verify_json_cmd; + interpret_cmd; + verify_xgboost_cmd; + ] |> Cmd.eval ~catch:false |> Caml.exit with exn when not (log_level_is_debug ()) -> Logs.err (fun m -> m "@[%a@]" Why3.Exn_printer.exn_printer exn) diff --git a/src/verification.mli b/src/verification.mli index 0dabf9670ffe682990390fb261ee1926707ae4f4..95ebc1649cc095c1fbc245b405729a751a55fb73 100644 --- a/src/verification.mli +++ b/src/verification.mli @@ -86,3 +86,6 @@ val open_file : @param debug when set, enables debug information. @param format is the [file] format. @param loadpath is the additional loadpath. *) + +val create_env : + ?debug:bool -> string list -> Why3.Env.env * Why3.Whyconf.config diff --git a/tests/dune b/tests/dune index 8db19e46cb68c6ba9fc2e8eac42565dbcb5c0514..f9bb0f415d19f2d542b8c1a2c95ac8e605845eae 100644 --- a/tests/dune +++ b/tests/dune @@ -11,5 +11,8 @@ bin/cvc5 bin/nnenum.sh filter_tmpdir.sh - (glob_files "datasets/a/*")) + (glob_files "datasets/a/*") + filter_tmpdir.sh + ../lib/xgboost/example/agaricus.test.svm + ../lib/xgboost/example/model-0.json) (package caisar)) diff --git a/tests/xgboost.t b/tests/xgboost.t new file mode 100644 index 0000000000000000000000000000000000000000..42cec9f7be0dcc4ef59f90792f7f75b91f031a59 --- /dev/null +++ b/tests/xgboost.t @@ -0,0 +1,2 @@ +Test verify + $ caisar verify-xgboost ../lib/xgboost/example/model-0.json ../lib/xgboost/example/agaricus.test.svm 2>&1 <<EOF | ./filter_tmpdir.sh