-
Michele Alberti authored
transformations.ml 3.84 KiB
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(**************************************************************************)
open Base
(** Meta ["caisar_input"]: associates a logic symbol ([MTlsymbol]) with an
    integer ([MTint]) giving the position of the neural-network input that
    the symbol stands for. The meta is registered with Why3 when this module
    is initialised. *)
let meta_input =
  Why3.Theory.register_meta
    ~desc:"Indicates the position of the input in the neural network"
    "caisar_input"
    Why3.Theory.[ MTlsymbol; MTint ]
(** Meta ["caisar_output"]: associates a logic symbol ([MTlsymbol]) with an
    integer ([MTint]) giving the index of the neural-network output that the
    symbol stands for. The meta is registered with Why3 when this module is
    initialised. *)
let meta_output =
  let open Why3.Theory in
  let arg_types = [ MTlsymbol; MTint ] in
  register_meta
    ~desc:"Indicates the position of the output in the neural network"
    "caisar_output" arg_types
(** [get_input_variables] is a task transformation that collects the input
    variables of every loaded neural-network application in the task.

    The result maps each logic symbol passed as an argument to a network
    (a symbol for which [Language.lookup_loaded_nnets] returns [Some _]) to
    its 0-based argument position. Network applications are searched for
    anywhere in a term, including below other function or predicate
    applications (e.g. a comparison on the network result).

    @raise Invalid_argument if a network is applied to an argument that is
      not a bare constant symbol (a nullary application). *)
let get_input_variables =
  (* Record [arg] as the network input at position [i]; only nullary
     applications (constant symbols) are accepted as inputs. *)
  let add i acc = function
    | { Why3.Term.t_node = Tapp (ls, []); _ } -> Why3.Term.Mls.add ls i acc
    | arg ->
      invalid_arg
        (Fmt.str "No direct variable in application: %a"
           Why3.Pretty.print_term arg)
  in
  let rec aux acc (term : Why3.Term.term) =
    match term.t_node with
    | Why3.Term.Tapp (ls, args) -> (
      match Language.lookup_loaded_nnets ls with
      | Some _ -> List.foldi ~init:acc ~f:add args
      (* Not a network: its arguments may still contain network
         applications (e.g. [nnet x1 x2 < c]), so keep searching instead of
         stopping here. *)
      | None -> Why3.Term.t_fold aux acc term)
    | _ -> Why3.Term.t_fold aux acc term
  in
  Why3.Trans.fold_decl
    (fun decl acc -> Why3.Decl.decl_fold aux acc decl)
    Why3.Term.Mls.empty
(** [simplify_goal env input_variables] is a task transformation that
    replaces every application of a loaded neural network (a symbol for
    which [Language.lookup_loaded_nnets] returns [Some nnet]) by a tuple of
    [nnet.nb_outputs] fresh constant symbols of type [nnet.ty_data].

    For each task declaration:
    - [Use], [Clone] and [Meta] declarations are copied unchanged.
    - A [Dparam] declaration is copied. If its symbol is a key of
      [input_variables], it is also tagged with [meta_input] and the
      associated position.
    - Any other declaration has its terms rewritten as above. When at least
      one fresh output symbol was introduced, each rewritten term is
      normalized with [Why3.Reduction_engine], configured with
      [compute_defs = false], [compute_builtin = true], an empty
      [compute_def_set] and [~limit:100]. Each fresh symbol is then declared
      and tagged with [meta_output] and its output index. The rewritten
      declaration itself is added after them.

    @param env Why3 environment handed to the reduction engine.
    @param input_variables map from input symbols to their argument
      position. In [caisar_native_prover] it is the result of
      [get_input_variables]. *)
let simplify_goal env input_variables =
  (* Rewrite [term], replacing each network application by a tuple of fresh
     output constants. Every fresh symbol is pushed onto [hls] together with
     its declaration and its output index. *)
  let rec aux hls (term : Why3.Term.term) =
    match term.t_node with
    | Why3.Term.Tapp (ls, _) -> (
      match Language.lookup_loaded_nnets ls with
      | None -> Why3.Term.t_map (aux hls) term
      | Some nnet ->
        (* The arguments of the network application are not kept in the
           rewritten term (note the [_] in the pattern above). *)
        let outputs =
          List.init nnet.nb_outputs ~f:(fun i ->
              let open Why3 in
              let id = Ident.id_fresh "y" in
              let ls = Term.create_fsymbol id [] nnet.ty_data in
              hls := (Why3.Decl.create_param_decl ls, ls, i) :: !hls;
              Term.fs_app ls [] nnet.ty_data)
        in
        Why3.Term.t_tuple outputs)
    | _ -> Why3.Term.t_map (aux hls) term
  in
  Why3.Trans.fold
    (fun task_hd acc ->
      match task_hd.task_decl.td_node with
      | Use _ | Clone _ | Meta _ -> Why3.Task.add_tdecl acc task_hd.task_decl
      | Decl { d_node = Dparam ls; _ } -> (
        (* Keep the declaration; additionally tag it when it is a network
           input variable. *)
        let task = Why3.Task.add_tdecl acc task_hd.task_decl in
        match Why3.Term.Mls.find_opt ls input_variables with
        | None -> task
        | Some pos -> Why3.Task.add_meta task meta_input [ MAls ls; MAint pos ])
      | Decl decl ->
        (* Fresh output symbols introduced while rewriting this declaration.
           One ref is shared by every term that [decl_map] visits. Once one
           term has introduced outputs, later terms of the same declaration
           are normalized too. *)
        let hls = ref [] in
        let map term =
          let term = aux hls term in
          if List.is_empty !hls
          then term
          else
            (* Extend the known map of the current task position with the
               fresh parameter declarations before creating the engine. *)
            let known =
              List.fold !hls ~init:task_hd.task_known ~f:(fun acc (d, _, _) ->
                  Why3.Decl.known_add_decl acc d)
            in
            let engine =
              Why3.Reduction_engine.create
                {
                  compute_defs = false;
                  compute_builtin = true;
                  compute_def_set = Why3.Term.Sls.empty;
                }
                env known
            in
            (* NOTE(review): the reduction limit (100) is hard-coded. *)
            Why3.Reduction_engine.normalize ~limit:100 engine
              Why3.Term.Mvs.empty term
        in
        let decl = Why3.Decl.decl_map map decl in
        (* Declare each fresh output and tag it with its output index before
           adding the rewritten declaration that refers to it. [hls] was
           built by prepending, so outputs are added in reverse creation
           order. *)
        let acc =
          List.fold !hls ~init:acc ~f:(fun acc (d, ls, i) ->
              let task = Why3.Task.add_decl acc d in
              Why3.Task.add_meta task meta_output [ MAls ls; MAint i ])
        in
        Why3.Task.add_decl acc decl)
    (* Initial accumulator: the empty task. *)
    None
(** [caisar_native_prover env] is the transformation pipeline for provers
    that load neural networks themselves. It first collects the network
    input variables with [get_input_variables]. It then passes them to
    [simplify_goal env], which abstracts network applications and tags
    inputs/outputs with metas.

    TODO(review): a [Why3.Simplify_formula] pass was sketched as a further
    pipeline stage; confirm whether it is still wanted. *)
let caisar_native_prover env =
  let annotate_and_simplify =
    Why3.Trans.bind get_input_variables (simplify_goal env)
  in
  Why3.Trans.seq [ annotate_and_simplify ]
(** [init ()] registers [caisar_native_prover] with Why3 as an
    environment-dependent transformation named ["caisar_native_prover"].

    NOTE(review): Why3 presumably rejects registering the same name twice,
    so this is expected to be called once — confirm. *)
let init () =
  let transform_name = "caisar_native_prover" in
  Why3.Trans.register_env_transform
    ~desc:"Transformation for provers that support loading neural networks."
    transform_name caisar_native_prover