Skip to content
Snippets Groups Projects
convert_xgboost.ml 5.79 KiB
(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*  Copyright (C) 2022                                                    *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

open Base

(** Why3 symbols standing for one input feature of the XGBoost model. *)
type var = {
  value : Why3.Term.lsymbol;
      (** Nullary function symbol of type [real] carrying the feature value. *)
  missing : Why3.Term.lsymbol;
      (** Nullary predicate symbol; asserted when the feature is absent from
          the input sample. *)
}

(** [convert_model env model task] encodes the XGBoost [model] into [task].

    The returned task extends [task] with Why3's [real.Real] theory and
    declares:
    - for every input feature, a real constant holding its value and a
      propositional constant telling whether it is missing;
    - one real constant [tree<i>] per tree, defined as the value reached in
      that tree;
    - a real constant [sum], defined as [base_score] plus every tree value.

    Feature [i] is named after [model.learner.feature_names.(i)] when such a
    name exists, and [f<i>] otherwise.

    @return the extended task together with the per-feature symbols, indexed
      by feature number. *)
let convert_model env (model : Caisar_xgboost.Parser.t) task =
  let th_real = Why3.Env.read_theory env [ "real" ] "Real" in
  let task = Why3.Task.use_export task th_real in
  (* Look up a binary operator exported by [real.Real]. *)
  let find_infix infix_op =
    Why3.Theory.ns_find_ls th_real.Why3.Theory.th_export
      [ Why3.Ident.op_infix infix_op ]
  in
  let ls_lt = find_infix "<" in
  let ls_add = find_infix "+" in
  let forest = Caisar_xgboost.Tree.convert model in
  (* Small term builders shared by the encodings below. *)
  let real_const f =
    Why3.Term.t_const (Dataset.real_constant_of_float f) Why3.Ty.ty_real
  in
  let nullary_real ls = Why3.Term.fs_app ls [] Why3.Ty.ty_real in
  (* Fall back to a positional name when the model has no feature name. *)
  let feature_name i =
    let names = model.learner.feature_names in
    if i < Array.length names then names.(i) else Fmt.str "f%i" i
  in
  let make_var i =
    let name = feature_name i in
    let value =
      Why3.Term.create_fsymbol (Why3.Ident.id_fresh name) [] Why3.Ty.ty_real
    in
    let missing =
      Why3.Term.create_psymbol (Why3.Ident.id_fresh ("missing_" ^ name)) []
    in
    { value; missing }
  in
  let variables =
    Array.init
      (Int.of_string model.learner.learner_model_param.num_feature)
      ~f:make_var
  in
  let declare_var acc { value; missing } =
    let acc = Why3.Task.add_param_decl acc value in
    Why3.Task.add_param_decl acc missing
  in
  let task = Array.fold variables ~init:task ~f:declare_var in
  (* Pair every tree with the real constant that will name its value. *)
  let tree_symbols =
    Array.mapi forest.trees ~f:(fun i t ->
      let ls =
        Why3.Term.create_fsymbol
          (Why3.Ident.id_fresh (Fmt.str "tree%i" i))
          [] Why3.Ty.ty_real
      in
      (ls, t))
  in
  (* A split goes to [left] when the feature is missing or strictly below the
     split condition, and to [right] otherwise. Only splits whose default
     direction is [`Left] are handled. *)
  let rec term_of_tree : Caisar_xgboost.Tree.tree -> Why3.Term.term = function
    | Leaf { leaf_value } -> real_const leaf_value
    | Split { split_indice; split_condition; left; right; missing = `Left } ->
      let { value; missing } = variables.(split_indice) in
      let goes_left =
        Why3.Term.t_or
          (Why3.Term.ps_app missing [])
          (Why3.Term.ps_app ls_lt
             [ nullary_real value; real_const split_condition ])
      in
      let then_ = term_of_tree left in
      let else_ = term_of_tree right in
      Why3.Term.t_if goes_left then_ else_
  in
  let task =
    Array.fold tree_symbols ~init:task ~f:(fun acc (ls, t) ->
      let defn = Why3.Decl.make_ls_defn ls [] (term_of_tree t) in
      Why3.Task.add_logic_decl acc [ defn ])
  in
  let ls_sum =
    Why3.Term.create_fsymbol (Why3.Ident.id_fresh "sum") [] Why3.Ty.ty_real
  in
  (* Left-nested sum: ((base_score + tree0) + tree1) + ... *)
  let sum =
    Array.fold tree_symbols ~init:(real_const forest.base_score)
      ~f:(fun acc (ls, _) ->
        Why3.Term.fs_app ls_add [ acc; nullary_real ls ] Why3.Ty.ty_real)
  in
  let task =
    Why3.Task.add_logic_decl task [ Why3.Decl.make_ls_defn ls_sum [] sum ]
  in
  (task, variables)

(** [convert_dataset mapping data task] adds to [task] one axiom per entry of
    [mapping], describing the input sample [data]:
    - when feature [i] is absent from [data], the axiom states that its
      [missing] predicate holds;
    - otherwise, it states that [missing] does not hold and that the feature
      value equals the sample value.

    Every axiom is named [data]. *)
let convert_dataset mapping (data : Caisar_xgboost.Input.t) task =
  Array.foldi mapping ~init:task ~f:(fun i acc { value; missing } ->
    let feature = Why3.Term.fs_app value [] Why3.Ty.ty_real in
    let is_missing = Why3.Term.ps_app missing [] in
    let fact =
      Option.value_map (Caisar_xgboost.Input.get data i) ~default:is_missing
        ~f:(fun v ->
          let constant =
            Why3.Term.t_const (Dataset.real_constant_of_float v) Why3.Ty.ty_real
          in
          Why3.Term.t_and
            (Why3.Term.t_not is_missing)
            (Why3.Term.t_equ feature constant))
    in
    let pr = Why3.Decl.create_prsymbol (Why3.Ident.id_fresh "data") in
    Why3.Task.add_prop_decl acc Paxiom pr fact)

(** Encode an XGBoost model and the first sample of a dataset as a Why3 task,
    and print that task on standard output.

    [memlimit] and [timelimit] are accepted for interface compatibility but
    are currently ignored.

    @param xgboost path to the XGBoost model, in JSON format
    @param dataset path to the dataset, read with
      [Caisar_xgboost.Input.of_filename]; only its first sample is used
    @raise Failure if the model cannot be decoded or the dataset contains no
      sample *)
let verify ?memlimit:_ ?timelimit:_ ~xgboost ~dataset () =
  let env, _ = Verification.create_env ~debug:false [] in
  let model =
    match Caisar_xgboost.Parser.of_yojson (Yojson.Safe.from_file xgboost) with
    | Ok model -> model
    | Error msg ->
      (* A malformed model file is a user error, not a broken invariant:
         report it instead of raising [Assert_failure]. *)
      failwith (Fmt.str "Error while parsing XGBoost model %s: %s" xgboost msg)
  in
  let task, mapping = convert_model env model None in
  let sample =
    match Caisar_xgboost.Input.of_filename model dataset with
    | sample :: _ -> sample
    | [] -> failwith (Fmt.str "Dataset %s contains no sample" dataset)
  in
  let task = convert_dataset mapping sample task in
  Why3.Pretty.print_task Caml.Format.std_formatter task