Skip to content
Snippets Groups Projects
convert_xgboost.ml 10.9 KiB
Newer Older
(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

open Base

(** Why3 symbols standing for one input feature of the model. *)
type var = {
  value : Why3.Term.lsymbol;  (** nullary [real] function symbol: the feature value *)
  missing : Why3.Term.lsymbol;  (** nullary predicate symbol: the feature is missing *)
}

(** Symbols declared in the task for a converted model: [variables.(i)] holds
    the symbols of the [i]-th input feature, and [sum] is the real term
    standing for the model output (base score plus the value of every tree). *)
type mapping = {
  variables : var array;
  sum : Why3.Term.term;
}

(** [convert_model env model task] encodes the XGBoost [model] into [task].

    For every input feature, it declares a nullary [real] function symbol (the
    feature value) and a nullary predicate symbol (the feature is missing).
    Features are named after [model.learner.feature_names], or [f<i>] past the
    end of that array.

    Every tree of the ensemble becomes a [real] constant [tree<i>]. Its
    definition is a nested if-then-else that follows the tree's splits. A
    constant [sum] is then defined as the base score plus every tree constant.

    Returns the extended task and the {!mapping} from features to their
    symbols. *)
let convert_model env (model : Caisar_xgboost.Parser.t) task =
  let th_real = Why3.Env.read_theory env [ "real" ] "Real" in
  let task = Why3.Task.use_export task th_real in
  let find_op infix_name =
    Why3.Theory.(ns_find_ls th_real.th_export [ Why3.Ident.op_infix infix_name ])
  in
  let ls_lt = find_op "<" in
  let ls_add = find_op "+" in
  (* Helpers building real-valued terms. *)
  let real_const f =
    Why3.Term.t_const (Dataset.real_constant_of_float f) Why3.Ty.ty_real
  in
  let real_of_ls ls = Why3.Term.fs_app ls [] Why3.Ty.ty_real in
  let ensemble = Caisar_xgboost.Tree.convert model in
  let nb_features =
    Int.of_string model.learner.learner_model_param.num_feature
  in
  let feature_names = model.learner.feature_names in
  let variables =
    Array.init nb_features ~f:(fun idx ->
      let name =
        if idx < Array.length feature_names
        then feature_names.(idx)
        else Fmt.str "f%i" idx
      in
      let value =
        Why3.Term.create_fsymbol (Why3.Ident.id_fresh name) [] Why3.Ty.ty_real
      in
      let missing =
        Why3.Term.create_psymbol (Why3.Ident.id_fresh ("missing_" ^ name)) []
      in
      { value; missing })
  in
  (* Declare the value symbol, then the missing symbol, of each feature. *)
  let task =
    Array.fold variables ~init:task ~f:(fun acc { value; missing } ->
      Why3.Task.add_param_decl (Why3.Task.add_param_decl acc value) missing)
  in
  let trees =
    Array.mapi ensemble.trees ~f:(fun idx t ->
      let ls =
        Why3.Term.create_fsymbol
          (Why3.Ident.id_fresh (Fmt.str "tree%i" idx))
          [] Why3.Ty.ty_real
      in
      (ls, t))
  in
  (* Only splits whose missing values go to the left child are matched here.
     NOTE(review): assumes the tree type has no other [missing] direction. *)
  let rec term_of_tree : Caisar_xgboost.Tree.tree -> Why3.Term.term = function
    | Leaf { leaf_value } -> real_const leaf_value
    | Split { split_indice; split_condition; left; right; missing = `Left } ->
      let feature = variables.(split_indice) in
      (* Take the left branch when the feature is missing or strictly below
         the split threshold. *)
      let goes_left =
        Why3.Term.t_or
          (Why3.Term.ps_app feature.missing [])
          (Why3.Term.ps_app ls_lt
             [ real_of_ls feature.value; real_const split_condition ])
      in
      let left_term = term_of_tree left in
      let right_term = term_of_tree right in
      Why3.Term.t_if goes_left left_term right_term
  in
  let task =
    Array.fold trees ~init:task ~f:(fun acc (ls, t) ->
      Why3.Task.add_logic_decl acc
        [ Why3.Decl.make_ls_defn ls [] (term_of_tree t) ])
  in
  let ls_sum =
    Why3.Term.create_fsymbol (Why3.Ident.id_fresh "sum") [] Why3.Ty.ty_real
  in
  (* base_score + tree0 + tree1 + ... *)
  let sum =
    Array.fold trees ~init:(real_const ensemble.base_score)
      ~f:(fun acc (ls, _) ->
        Why3.Term.fs_app ls_add [ acc; real_of_ls ls ] Why3.Ty.ty_real)
  in
  let task =
    Why3.Task.add_logic_decl task [ Why3.Decl.make_ls_defn ls_sum [] sum ]
  in
  (task, { variables; sum = real_of_ls ls_sum })
(** [_convert_dataset mapping data task] adds one axiom per feature of
    [mapping] that pins the feature to the matching entry of [data]. When
    [Caisar_xgboost.Input.get] returns [None], the feature is declared
    missing. Otherwise it is declared present and equal to the given value.

    Not referenced in this file; the leading underscore silences the
    unused-value warning. *)
let _convert_dataset mapping (data : Caisar_xgboost.Input.t) task =
  Array.foldi mapping.variables ~init:task
    ~f:(fun i task { value; missing } ->
      let var = Why3.Term.fs_app value [] Why3.Ty.ty_real in
      let missing = Why3.Term.ps_app missing [] in
      let t =
        match Caisar_xgboost.Input.get data i with
        | None -> missing
        | Some v ->
          let value =
            Why3.Term.t_const (Dataset.real_constant_of_float v) Why3.Ty.ty_real
          in
          Why3.Term.(t_and (t_not missing) (t_equ var value))
      in
      let pr = Why3.Decl.create_prsymbol (Why3.Ident.id_fresh "data") in
      Why3.Task.add_prop_decl task Why3.Decl.Paxiom pr t)

(** Over-approximation of a value: the closed interval [[low, high]], plus
    whether the value may be missing. *)
type over = { low : float; high : float; missing : bool }

(** [initial_over model ()] returns one bound per input feature of [model].
    Each bound is the empty interval ([low] above [high]) and marked as not
    missing, so the first value later folded in with [Float.min]/[Float.max]
    replaces it. *)
let initial_over (model : Caisar_xgboost.Parser.t) () =
  let num_feature =
    Int.of_string
      model.Caisar_xgboost.Parser.learner.learner_model_param.num_feature
  in
  let empty_interval =
    {
      low = Float.max_finite_value;
      high = -.Float.max_finite_value;
      missing = false;
    }
  in
  Array.create ~len:num_feature empty_interval

(** [compute_bounds_of_input over data] widens, in place, every entry of
    [over] so that it covers the input [data]. A feature absent from [data]
    marks its entry as possibly missing. A present value [v] extends the
    entry's [[low, high]] interval to include [v]. *)
let compute_bounds_of_input over (data : Caisar_xgboost.Input.t) =
  for i = 0 to Array.length over - 1 do
    let current = over.(i) in
    over.(i) <-
      (match Caisar_xgboost.Input.get data i with
       | None -> { current with missing = true }
       | Some v ->
         {
           current with
           low = Float.min current.low v;
           high = Float.max current.high v;
         })
  done

(** [bound_term ?epsilon env ~low ~high t] builds the real formula
    [low - epsilon <= t /\ t <= high + epsilon]. [epsilon] defaults to [0.0].
    The [real.Real] theory is read from [env] on each call to look up [<=]. *)
let bound_term ?(epsilon = 0.0) env ~low ~high t =
  let real_const f =
    Why3.Term.t_const (Dataset.real_constant_of_float f) Why3.Ty.ty_real
  in
  let lower_bound = real_const (low -. epsilon) in
  let upper_bound = real_const (high +. epsilon) in
  let th_real = Why3.Env.read_theory env [ "real" ] "Real" in
  let ls_le =
    Why3.Theory.(ns_find_ls th_real.th_export [ Why3.Ident.op_infix "<=" ])
  in
  let above = Why3.Term.ps_app ls_le [ lower_bound; t ] in
  let below = Why3.Term.ps_app ls_le [ t; upper_bound ] in
  Why3.Term.t_and above below

(** [convert_bounds env mapping over task] adds one axiom per feature of
    [mapping], constrained by the matching entry of [over]. The feature is
    either missing, which is only allowed when the entry's [missing] flag is
    set, or present with its value in [[low, high]]. *)
let convert_bounds env mapping over task =
  Array.foldi mapping.variables ~init:task ~f:(fun i acc { value; missing } ->
    let { low; high; missing = may_be_missing } = over.(i) in
    let is_missing = Why3.Term.ps_app missing [] in
    let x = Why3.Term.fs_app value [] Why3.Ty.ty_real in
    let present =
      Why3.Term.t_and (Why3.Term.t_not is_missing)
        (bound_term env ~low ~high x)
    in
    (* [false] removes the "missing" alternative when no input lacked it. *)
    let allowed_missing =
      if may_be_missing then is_missing else Why3.Term.t_false
    in
    let axiom = Why3.Term.t_or allowed_missing present in
    let pr = Why3.Decl.create_prsymbol (Why3.Ident.id_fresh "data") in
    Why3.Task.add_prop_decl acc Why3.Decl.Paxiom pr axiom)

(** [compute_bounds_of_result model data over] returns [over] widened to
    contain the prediction of [model] on [data]. The result is always marked
    as not missing. *)
let compute_bounds_of_result (model : Caisar_xgboost.Parser.t)
  (data : Caisar_xgboost.Input.t) over =
  let prediction = Caisar_xgboost.Predict.predict model data in
  let low = Float.min prediction over.low in
  let high = Float.max prediction over.high in
  { low; high; missing = false }

(** [_convert_result env mapping model data task] adds a goal to [task]
    stating that [mapping.sum] is within [0.1] of the prediction of [model]
    on [data]. Not referenced in this file. *)
let _convert_result env mapping (model : Caisar_xgboost.Parser.t)
  (data : Caisar_xgboost.Input.t) task =
  let expected = Caisar_xgboost.Predict.predict model data in
  let goal =
    bound_term ~epsilon:0.1 ~low:expected ~high:expected env mapping.sum
  in
  let decl =
    Why3.Decl.create_prop_decl Why3.Decl.Pgoal
      (Why3.Decl.create_prsymbol (Why3.Ident.id_fresh "data"))
      goal
  in
  Why3.Task.add_decl task decl

(** [call_prover_on_task limit config command driver task] starts the prover
    [command] on [task] through [driver] under the resource [limit]. It then
    waits on the call with [Why3.Call_provers.wait_on_call] and returns the
    prover result. *)
let call_prover_on_task limit config command driver task =
  Why3.Driver.prove_task ~command ~config ~limit driver task
  |> Why3.Call_provers.wait_on_call

(** [prove_and_print ?memlimit ?timelimit env config prover task] runs
    [prover] on [task] and prints ["result: <answer>"] on stdout.

    [memlimit] defaults to the memory limit of the Why3 [config]. [timelimit]
    is an integer, converted to a float, and defaults to the configured time
    limit. The step limit is taken from [Why3.Call_provers.empty_limit]. *)
let prove_and_print ?memlimit ?timelimit env config prover task =
  let main = Why3.Whyconf.get_main config in
  let default_limit = Why3.Call_provers.empty_limit in
  let limit =
    {
      Why3.Call_provers.limit_time =
        (match timelimit with
         | Some seconds -> Float.of_int seconds
         | None -> Why3.Whyconf.timelimit main);
      limit_steps = default_limit.limit_steps;
      limit_mem =
        (match memlimit with
         | Some mem -> mem
         | None -> Why3.Whyconf.memlimit main);
    }
  in
  let config_prover =
    Why3.Whyconf.filter_one_prover config
      (Why3.Whyconf.mk_filter_prover ~altern:"" (Prover.to_string prover))
  in
  let driver = Why3.Driver.load_driver_for_prover main env config_prover in
  (* Alternative driver loading for CAISAR-shipped drivers, kept for reference:
     let open Why3 in match String.chop_prefix ~prefix:"caisar_drivers/"
     config_prover.driver with | None -> Driver.load_driver_for_prover main env
     config_prover | Some file -> let file = Stdlib.Filename.concat
     (Stdlib.Filename.concat (List.hd_exn Dirs.Sites.config) "drivers") file in
     Driver.load_driver_file_and_extras main env file
     config_prover.extra_drivers *)
  let command =
    Why3.Whyconf.get_complete_command ~with_steps:false config_prover
  in
  let answer = call_prover_on_task limit main command driver task in
  Fmt.pr "result: %a@." Why3.Call_provers.print_prover_answer answer.pr_answer

(** [verify ?memlimit ?timelimit ~xgboost ~dataset ~prover ()] builds and
    proves a Why3 task about the XGBoost model in the JSON file [xgboost].

    - The model is encoded with [convert_model].
    - Per-feature bounds covering every input of [dataset] are assumed as
      axioms ([convert_bounds]).
    - The goal states that the model sum lies between the smallest and the
      largest prediction made on [dataset].
    - [prover] is run on the task and its answer is printed.

    If the model file cannot be parsed, the parser error is printed on stderr
    and [Assert_failure] is raised. *)
let verify ?memlimit ?timelimit ~xgboost ~dataset ~prover () =
  let env, config = Verification.create_env [] in
  let th_real = Why3.Env.read_theory env [ "real" ] "Real" in
  (* Start from the empty Why3 task ([None]) with the real theory in scope. *)
  let task = Why3.Task.use_export None th_real in
  let model =
    let yojson = Yojson.Safe.from_file xgboost in
    match Caisar_xgboost.Parser.of_yojson yojson with
    | Error msg ->
      Fmt.epr "Error: %s@." msg;
      assert false
    | Ok ok -> ok
  in
  let task, mapping = convert_model env model task in
  let dataset = Caisar_xgboost.Input.of_filename model dataset in
  (* Per-feature bounds covering every input of the dataset. *)
  let over = initial_over model () in
  List.iter dataset ~f:(compute_bounds_of_input over);
  let task = convert_bounds env mapping over task in
  (* Smallest interval containing the prediction on every input.
     NOTE: with an empty dataset this stays the empty interval
     ([low] > [high]), so the goal below cannot hold. *)
  let over_result =
    List.fold dataset
      ~init:
        {
          low = Float.max_finite_value;
          high = -.Float.max_finite_value;
          missing = false;
        } ~f:(fun over input -> compute_bounds_of_result model input over)
  in
  let task =
    let t =
      bound_term ~epsilon:0. ~low:over_result.low ~high:over_result.high env
        mapping.sum
    in
    let pr = Why3.Decl.create_prsymbol (Why3.Ident.id_fresh "data") in
    Why3.Task.add_decl task (Why3.Decl.create_prop_decl Why3.Decl.Pgoal pr t)
  in
  prove_and_print ?memlimit ?timelimit env config prover task