convert_xgboost.ml 5.79 KiB
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(* Copyright (C) 2022 *)
(* CEA (Commissariat à l'énergie atomique et aux énergies *)
(* alternatives) *)
(* *)
(* You can redistribute it and/or modify it under the terms of the GNU *)
(* Lesser General Public License as published by the Free Software *)
(* Foundation, version 2.1. *)
(* *)
(* It is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
(* GNU Lesser General Public License for more details. *)
(* *)
(* See the GNU Lesser General Public License version 2.1 *)
(* for more details (enclosed in the file licenses/LGPLv2.1). *)
(* *)
(**************************************************************************)
open Base
(** Why3 symbols standing for one input feature of the XGBoost model. *)
type var = {
  value : Why3.Term.lsymbol;
      (** Nullary function symbol of type [real] holding the feature value. *)
  missing : Why3.Term.lsymbol;
      (** Nullary predicate symbol stating that the feature has no value; a
          split on a missing feature takes its left branch. *)
}
(** [convert_model env model task] appends to [task] a Why3 encoding of the
    XGBoost [model].

    The theory [real.Real] is used and exported. The following declarations
    are then added, in this order:
    - for every feature [i] (there are [num_feature] of them), an abstract
      real constant named after [feature_names.(i)] (or ["f<i>"] when no
      name is given for [i]) and an abstract predicate ["missing_<name>"];
    - for every tree of the ensemble, a real constant ["tree<i>"] defined as
      the if-then-else term evaluating that tree: a split goes to its left
      child when the feature is missing or strictly below the split
      condition, and to its right child otherwise;
    - a real constant ["sum"] defined as the base score plus the value of
      every tree, accumulated left to right.

    Returns the extended task together with the feature symbols, indexed by
    feature position. *)
let convert_model env (model : Caisar_xgboost.Parser.t) task =
  let th_real = Why3.Env.read_theory env [ "real" ] "Real" in
  let task = Why3.Task.use_export task th_real in
  (* Look up a binary operator of [real.Real] by its symbol. *)
  let find_infix symbol =
    let name = Why3.Ident.op_infix symbol in
    Why3.Theory.ns_find_ls th_real.Why3.Theory.th_export [ name ]
  in
  let ls_lt = find_infix "<" in
  let ls_add = find_infix "+" in
  (* A real literal, and a reference to a nullary real constant. *)
  let real_const f =
    Why3.Term.t_const (Dataset.real_constant_of_float f) Why3.Ty.ty_real
  in
  let real_app ls = Why3.Term.fs_app ls [] Why3.Ty.ty_real in
  let forest = Caisar_xgboost.Tree.convert model in
  let feature_names = model.learner.feature_names in
  let num_features =
    Int.of_string model.learner.learner_model_param.num_feature
  in
  (* Symbols for feature [i]: its value first, then its "missing" flag. *)
  let declare_feature i =
    let name =
      if i >= Array.length feature_names then Fmt.str "f%i" i
      else feature_names.(i)
    in
    let value =
      Why3.Term.create_fsymbol (Why3.Ident.id_fresh name) [] Why3.Ty.ty_real
    in
    let missing =
      Why3.Term.create_psymbol (Why3.Ident.id_fresh ("missing_" ^ name)) []
    in
    { value; missing }
  in
  let variables = Array.init num_features ~f:declare_feature in
  let task =
    Array.fold variables ~init:task ~f:(fun acc feature ->
        Why3.Task.add_param_decl
          (Why3.Task.add_param_decl acc feature.value)
          feature.missing)
  in
  let trees =
    Array.mapi forest.trees ~f:(fun i t ->
        let ls =
          Why3.Term.create_fsymbol
            (Why3.Ident.id_fresh (Fmt.str "tree%i" i))
            [] Why3.Ty.ty_real
        in
        (ls, t))
  in
  (* Nested if-then-else evaluating one tree; missing values go left. *)
  let rec encode_tree : Caisar_xgboost.Tree.tree -> Why3.Term.term = function
    | Leaf { leaf_value } -> real_const leaf_value
    | Split { split_indice; split_condition; left; right; missing = `Left } ->
      let feature = variables.(split_indice) in
      let current = real_app feature.value in
      let is_missing = Why3.Term.ps_app feature.missing [] in
      let threshold = real_const split_condition in
      let below = Why3.Term.ps_app ls_lt [ current; threshold ] in
      let goes_left = Why3.Term.t_or is_missing below in
      let left_term = encode_tree left in
      let right_term = encode_tree right in
      Why3.Term.t_if goes_left left_term right_term
  in
  let task =
    Array.fold trees ~init:task ~f:(fun acc (ls, t) ->
        let definition = Why3.Decl.make_ls_defn ls [] (encode_tree t) in
        Why3.Task.add_logic_decl acc [ definition ])
  in
  let ls_sum =
    Why3.Term.create_fsymbol (Why3.Ident.id_fresh "sum") [] Why3.Ty.ty_real
  in
  let sum =
    Array.fold trees ~init:(real_const forest.base_score)
      ~f:(fun acc (ls, _) ->
        Why3.Term.fs_app ls_add [ acc; real_app ls ] Why3.Ty.ty_real)
  in
  let task =
    Why3.Task.add_logic_decl task [ Why3.Decl.make_ls_defn ls_sum [] sum ]
  in
  (task, variables)
(** [convert_dataset mapping data task] binds the feature symbols in
    [mapping] to the input [data] by adding to [task] one axiom (named
    ["data"]) per feature, in feature order.

    For the feature at position [i], the axiom is:
    - [missing] when [data] has no value at [i];
    - [not missing /\ value = v] when [data] holds [v] at [i].

    Returns the extended task. *)
let convert_dataset mapping (data : Caisar_xgboost.Input.t) task =
  let axiom_of_feature i { value; missing } =
    let feature = Why3.Term.fs_app value [] Why3.Ty.ty_real in
    let is_missing = Why3.Term.ps_app missing [] in
    match Caisar_xgboost.Input.get data i with
    | None -> is_missing
    | Some v ->
      let observed =
        Why3.Term.t_const (Dataset.real_constant_of_float v) Why3.Ty.ty_real
      in
      Why3.Term.t_and (Why3.Term.t_not is_missing)
        (Why3.Term.t_equ feature observed)
  in
  Array.foldi mapping ~init:task ~f:(fun i acc symbols ->
      let axiom = axiom_of_feature i symbols in
      let pr = Why3.Decl.create_prsymbol (Why3.Ident.id_fresh "data") in
      Why3.Task.add_prop_decl acc Why3.Decl.Paxiom pr axiom)
(** [verify ?memlimit ?timelimit ~xgboost ~dataset ()] loads the XGBoost
    model stored as JSON in the file [xgboost], encodes it as a Why3 task
    together with the first entry of the dataset read from [dataset], and
    pretty-prints that task on standard output.

    [memlimit] and [timelimit] are accepted but currently ignored: no prover
    is run.

    @raise Failure if [xgboost] cannot be parsed as an XGBoost model (an
    error message is also printed on standard error), or if [dataset]
    contains no entry. *)
let verify ?memlimit:_ ?timelimit:_ ~xgboost ~dataset () =
  let env, _ = Verification.create_env ~debug:false [] in
  let model =
    let yojson = Yojson.Safe.from_file xgboost in
    match Caisar_xgboost.Parser.of_yojson yojson with
    | Ok model -> model
    | Error msg ->
      Fmt.epr "Error: %s@." msg;
      (* Invalid user input, not a programming error: report it with a
         meaningful exception instead of [assert false]. *)
      failwith (Fmt.str "Invalid XGBoost model in %s: %s" xgboost msg)
  in
  let task, mapping = convert_model env model None in
  (* Only the first dataset entry is encoded for now. *)
  let data =
    match Caisar_xgboost.Input.of_filename model dataset with
    | first :: _ -> first
    | [] -> failwith (Fmt.str "Dataset %s contains no entry" dataset)
  in
  let task = convert_dataset mapping data task in
  Why3.Pretty.print_task Caml.Format.std_formatter task