(* Skip to content / Snippets Groups Projects / transformations.ml 3.84 KiB
   (web-page navigation text captured together with this source file) *)
(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(**************************************************************************)

open Base

(** Meta ["caisar_input"]: attaches to a logic symbol (the [MTlsymbol]
    argument) its position (the [MTint] argument) among the inputs of the
    neural network. *)
let meta_input =
  let open Why3.Theory in
  register_meta "caisar_input" [ MTlsymbol; MTint ]
    ~desc:"Indicates the position of the input in the neural network"

(** Meta ["caisar_output"]: attaches to a logic symbol (the [MTlsymbol]
    argument) its position (the [MTint] argument) among the outputs of the
    neural network. *)
let meta_output =
  let open Why3.Theory in
  register_meta "caisar_output" [ MTlsymbol; MTint ]
    ~desc:"Indicates the position of the output in the neural network"

(** Transformation that collects the input symbols of every neural network
    application occurring in the task.

    The result maps each logic symbol passed directly as an argument to a
    loaded neural network (as reported by [Language.lookup_loaded_nnets]) to
    its 0-based position in that application's argument list. If the same
    symbol occurs in several applications, a later occurrence overwrites the
    position recorded for an earlier one.

    @raise Invalid_argument if an argument of a neural network application is
    not a bare constant symbol. *)
let get_input_variables =
  let rec aux acc (term : Why3.Term.term) =
    match term.t_node with
    | Why3.Term.Tapp (ls, args) -> (
      match Language.lookup_loaded_nnets ls with
      | None ->
        (* Not a network application. Keep searching its arguments, since a
           network application may be nested under an ordinary function or
           predicate application (e.g. [nnet x1 x2 = t]). *)
        Why3.Term.t_fold aux acc term
      | Some _ ->
        let add i acc = function
          | { Why3.Term.t_node = Tapp (input_ls, []); _ } ->
            Why3.Term.Mls.add input_ls i acc
          | arg ->
            invalid_arg
              (Fmt.str "No direct variable in application: %a"
                 Why3.Pretty.print_term arg)
        in
        List.foldi ~init:acc ~f:add args)
    | _ -> Why3.Term.t_fold aux acc term
  in
  Why3.Trans.fold_decl
    (fun decl acc -> Why3.Decl.decl_fold aux acc decl)
    Why3.Term.Mls.empty

(** [simplify_goal env input_variables] is a transformation that rebuilds the
    task declaration by declaration:
    - [Use], [Clone] and [Meta] declarations are copied unchanged;
    - an abstract constant/function declaration ([Dparam ls]) is copied, and
      if [ls] is a key of [input_variables] a [meta_input] carrying its
      position is added right after it;
    - in every other declaration, each application of a loaded neural
      network is replaced by a tuple of fresh constants (one per network
      output). These constants are declared, and tagged with [meta_output]
      and their output index, before the rewritten declaration.
    @param env Why3 environment handed to the reduction engine.
    @param input_variables map from input symbol to input position, as
    computed by [get_input_variables]. *)
let simplify_goal env input_variables =
  (* Replaces network applications in [term] by tuples of fresh constants and
     records [(param_decl, fresh_symbol, output_index)] for each of them in
     [hls]. The arguments of the application are discarded here; the input
     symbols are presumably tracked through [meta_input] instead. *)
  let rec aux hls (term : Why3.Term.term) =
    match term.t_node with
    | Why3.Term.Tapp (ls, _) -> (
      match Language.lookup_loaded_nnets ls with
      | None -> Why3.Term.t_map (aux hls) term
      | Some nnet ->
        let outputs =
          List.init nnet.nb_outputs ~f:(fun i ->
            let open Why3 in
            let id = Ident.id_fresh "y" in
            let ls = Term.create_fsymbol id [] nnet.ty_data in
            (* Entries are prepended, so [!hls] is not in output-index
               order; each entry carries its own index [i]. *)
            hls := (Why3.Decl.create_param_decl ls, ls, i) :: !hls;
            Term.fs_app ls [] nnet.ty_data)
        in
        Why3.Term.t_tuple outputs)
    | _ -> Why3.Term.t_map (aux hls) term
  in
  Why3.Trans.fold
    (fun task_hd acc ->
      match task_hd.task_decl.td_node with
      | Use _ | Clone _ | Meta _ -> Why3.Task.add_tdecl acc task_hd.task_decl
      | Decl { d_node = Dparam ls; _ } -> (
        let task = Why3.Task.add_tdecl acc task_hd.task_decl in
        match Why3.Term.Mls.find_opt ls input_variables with
        | None -> task
        | Some pos -> Why3.Task.add_meta task meta_input [ MAls ls; MAint pos ])
      | Decl decl ->
        (* Shared by every term of [decl] visited by [decl_map] below. *)
        let hls = ref [] in
        let map term =
          let term = aux hls term in
          (* NOTE(review): [!hls] accumulates across all terms of [decl], so
             once one term has produced fresh outputs, the later terms of the
             same declaration are normalized too. *)
          if List.is_empty !hls
          then term
          else
            (* The fresh output declarations must be known to the reduction
               engine, so they are added on top of the current known map. *)
            let known =
              List.fold !hls ~init:task_hd.task_known ~f:(fun acc (d, _, _) ->
                Why3.Decl.known_add_decl acc d)
            in
            (* Builtins are computed; definitions are not unfolded
               ([compute_defs = false]). Presumably this reduces projections
               out of the new output tuples — TODO confirm. *)
            let engine =
              Why3.Reduction_engine.create
                {
                  compute_defs = false;
                  compute_builtin = true;
                  compute_def_set = Why3.Term.Sls.empty;
                }
                env known
            in
            Why3.Reduction_engine.normalize ~limit:100 engine
              Why3.Term.Mvs.empty term
        in
        let decl = Why3.Decl.decl_map map decl in
        (* Declare each fresh output constant, with its [meta_output] tag,
           before the declaration that uses it. *)
        let acc =
          List.fold !hls ~init:acc ~f:(fun acc (d, ls, i) ->
            let task = Why3.Task.add_decl acc d in
            Why3.Task.add_meta task meta_output [ MAls ls; MAint i ])
        in
        Why3.Task.add_decl acc decl)
    None

(** [caisar_native_prover env] is the transformation chain for provers that
    load neural networks themselves. It first collects the network input
    symbols with [get_input_variables], then rewrites the task with
    [simplify_goal env] using that map. *)
let caisar_native_prover env =
  let collect_then_simplify =
    Why3.Trans.bind get_input_variables (fun input_variables ->
      simplify_goal env input_variables)
  in
  (* Candidate extra step, currently disabled:
     Why3.Simplify_formula.simplify_ *)
  Why3.Trans.seq [ collect_then_simplify ]

(** Registers [caisar_native_prover] with Why3 as an environment-dependent
    transformation named ["caisar_native_prover"]. *)
let init () =
  let open Why3.Trans in
  register_env_transform "caisar_native_prover" caisar_native_prover
    ~desc:"Transformation for provers that support loading neural networks."