Skip to content
Snippets Groups Projects
Commit b6a6049a authored by François Bobot's avatar François Bobot Committed by Michele Alberti
Browse files

[ir] Add new NIER AST.

parent 31c025cd
No related branches found
No related tags found
No related merge requests found
......@@ -20,36 +20,154 @@
(* *)
(**************************************************************************)
open Why3
open Base
type new_output = {
old_index : Why3.Number.int_constant;
old_nn : Language.nn;
new_index : int;
old_nn_args : Why3.Term.term list;
}
(** invariant:
- input_vars from 0 to its length - 1
- outputs from 0 to its length - 1 *)
(* let create_new_nn input_vars outputs = let id = ref (-1) in let mk ?name ~sh
~op ~op_p ?(pred=[]) ?(succ=[]) ?tensor = incr id; Ir.Nier_cfg.Node.create
~id:(!id) ~name:name ~pred ~succ ~tensor:tensor in let x =
Ir.Nier_cfg.NierCFGFloat.add_edge in outputs *)
(** Abstract terms that contains neural network application *)
let abstract_nn_term =
let rec do_simplify (new_index, output_vars) term =
match term.Why3.Term.t_node with
| Tapp
( get (* [ ] *),
[
{
t_node =
Why3.Term.Tapp
( ls1 (* @@ *),
[
{ t_node = Tapp (ls2 (* nn *), _); _ };
{
t_node =
Tapp (ls3 (* input vector *), tl (* input vars *));
_;
};
] );
_;
};
({ t_node = Tconst (ConstInt old_index); _ } as _t2);
] )
when String.equal get.ls_name.id_string (Why3.Ident.op_get "")
&& String.equal ls1.ls_name.id_string (Why3.Ident.op_infix "@@") -> (
match (Language.lookup_nn ls2, Language.lookup_vector ls3) with
| Some ({ nn_nb_inputs; nn_ty_elt; _ } as old_nn), Some vector_length ->
assert (nn_nb_inputs = vector_length && vector_length = List.length tl);
let ls = Why3.(Term.create_fsymbol (Ident.id_fresh "y") [] nn_ty_elt) in
let term = Why3.Term.fs_app ls [] nn_ty_elt in
( ( new_index + 1,
({ old_index; new_index; old_nn; old_nn_args = tl }, ls)
:: output_vars ),
term )
| _, _ ->
failwith
(Fmt.str "nn application without fixed NeuralNetwork or Arguments: %a"
Why3.Pretty.print_term term))
| _ -> Why3.Term.t_map_fold do_simplify (new_index, output_vars) term
in
Why3.Trans.fold_map
(fun task_hd (((index, input_vars) as acc), task) ->
let tdecl = task_hd.task_decl in
match tdecl.td_node with
| Decl
{
d_node =
Dparam ({ ls_args = []; ls_constr = 0; ls_proj = false; _ } as ls);
_;
}
when Language.mem_nn ls ->
(* All neural networks are removed *) (acc, task)
| Decl
{
d_node =
Dparam ({ ls_args = []; ls_constr = 0; ls_proj = false; _ } as ls);
_;
} ->
(* Add meta for input variable declarations. Note that each meta needs
to appear before the corresponding declaration in order to be
leveraged by prover printers. *)
let task =
Why3.Task.add_meta task Meta.meta_input [ MAls ls; MAint index ]
in
let task = Why3.Task.add_tdecl task tdecl in
let index = index + 1 in
let input_vars = Why3.Term.Mls.add ls index input_vars in
((index, input_vars), task)
| Decl { d_node = Dprop (Pgoal, pr, goal); _ } ->
let (_, output_vars), goal = do_simplify (0, []) goal in
let pr = Why3.Decl.create_prsymbol (Why3.Ident.id_clone pr.pr_name) in
let decl = Why3.Decl.create_prop_decl Pgoal pr goal in
(* Again, for each output variable, add the meta first, then its actual
declaration. *)
List.iter output_vars ~f:(fun (out, var) ->
ignore out.old_nn;
Fmt.epr "%a: %a -> %i: %a@." Why3.Pretty.print_ls var
Why3.(Number.print_int_constant Number.full_support)
out.old_index out.new_index
(Fmt.list ~sep:Fmt.comma Why3.Pretty.print_term)
out.old_nn_args);
let task =
List.fold output_vars ~init:task
~f:(fun task ({ new_index; _ }, output_var) ->
let task =
Why3.Task.add_meta task Meta.meta_output
[ MAls output_var; MAint new_index ]
in
let decl = Why3.Decl.create_param_decl output_var in
Why3.Task.add_decl task decl)
in
(acc, Why3.Task.add_decl task decl)
| Decl { d_node = Dprop (_, _, _); _ } ->
(* TODO: check no nn applications *)
(acc, Why3.Task.add_tdecl task tdecl)
| _ -> (acc, Why3.Task.add_tdecl task tdecl))
(0, Why3.Term.Mls.empty) None
(** {2 Old way} *)
(* Creates a list of pairs made of output variables and respective indices in
the list, for each neural network application to an input vector appearing in
a task. Such a list stands for the resulting output vector of a neural
network application to an input vector (ie, something of the form: nn @@ v).
The creation process is memoized wrt terms corresponding to neural network
applications to input vectors. *)
let create_output_vars =
let rec do_create mt (term : Term.term) =
let output_vars =
let rec create_output_var mt (term : Why3.Term.term) =
match term.t_node with
| Term.Tapp (ls1 (* @@ *), [ { t_node = Tapp (ls2 (* nn *), _); _ }; _ ])
when String.equal ls1.ls_name.id_string (Ident.op_infix "@@") -> (
| Why3.Term.Tapp
(ls1 (* @@ *), [ { t_node = Tapp (ls2 (* nn *), _); _ }; _ ])
when String.equal ls1.ls_name.id_string (Why3.Ident.op_infix "@@") -> (
match Language.lookup_nn ls2 with
| Some { nn_nb_outputs; nn_ty_elt; _ } ->
if Term.Mterm.mem term mt
if Why3.Term.Mterm.mem term mt
then mt
else
let output_vars =
List.init nn_nb_outputs ~f:(fun index ->
(index, Term.create_fsymbol (Ident.id_fresh "y") [] nn_ty_elt))
( index,
Why3.Term.create_fsymbol (Why3.Ident.id_fresh "y") [] nn_ty_elt
))
in
Term.Mterm.add term output_vars mt
Why3.Term.Mterm.add term output_vars mt
| None -> mt)
| _ -> Term.t_fold do_create mt term
| _ -> Why3.Term.t_fold create_output_var mt term
in
Trans.fold_decl
(fun decl mt -> Decl.decl_fold do_create mt decl)
Term.Mterm.empty
Why3.Trans.fold_decl
(fun decl mt -> Why3.Decl.decl_fold create_output_var mt decl)
Why3.Term.Mterm.empty
(* Simplifies a task goal exhibiting a vector selection on a neural network
application to an input vector (ie, (nn @@ v)[_]) by the corresponding output
......@@ -58,26 +176,27 @@ let create_output_vars =
all declared, each with a meta that describes the respective index in the
output vector. *)
let simplify_nn_application input_vars output_vars =
let rec do_simplify (term : Term.term) =
let rec do_simplify (term : Why3.Term.term) =
match term.t_node with
| Term.Tapp
| Why3.Term.Tapp
( ls_get (* [_] *),
[
({ t_node = Tapp (ls_atat (* @@ *), _); _ } as t1);
({ t_node = Tconst (ConstInt index); _ } as _t2);
] )
when String.equal ls_get.ls_name.id_string (Ident.op_get "")
&& String.equal ls_atat.ls_name.id_string (Ident.op_infix "@@") -> (
match Term.Mterm.find_opt t1 output_vars with
| None -> Term.t_map do_simplify term
when String.equal ls_get.ls_name.id_string (Why3.Ident.op_get "")
&& String.equal ls_atat.ls_name.id_string (Why3.Ident.op_infix "@@")
-> (
match Why3.Term.Mterm.find_opt t1 output_vars with
| None -> Why3.Term.t_map do_simplify term
| Some output_vars ->
let index = Number.to_small_integer index in
let index = Why3.Number.to_small_integer index in
assert (index < List.length output_vars);
let ls = Stdlib.List.assoc index output_vars in
Term.t_app_infer ls [])
| _ -> Term.t_map do_simplify term
Why3.Term.t_app_infer ls [])
| _ -> Why3.Term.t_map do_simplify term
in
Trans.fold
Why3.Trans.fold
(fun task_hd task ->
match task_hd.task_decl.td_node with
| Decl { d_node = Dparam ls; _ } ->
......@@ -85,38 +204,40 @@ let simplify_nn_application input_vars output_vars =
to appear before the corresponding declaration in order to be
leveraged by prover printers. *)
let task =
match Term.Mls.find_opt ls input_vars with
match Why3.Term.Mls.find_opt ls input_vars with
| None -> task
| Some index ->
Task.add_meta task Meta.meta_input [ MAls ls; MAint index ]
Why3.Task.add_meta task Meta.meta_input [ MAls ls; MAint index ]
in
Task.add_tdecl task task_hd.task_decl
Why3.Task.add_tdecl task task_hd.task_decl
| Decl ({ d_node = Dprop (Pgoal, _, _); _ } as decl) ->
let decl = Decl.decl_map do_simplify decl in
let decl = Why3.Decl.decl_map do_simplify decl in
let task =
(* Output variables are not declared yet in the task as they are
created on the fly for each (different) neural network application
on an input vector. We add here their declarations in the task. *)
Term.Mterm.fold
Why3.Term.Mterm.fold
(fun _t output_vars task ->
(* Again, for each output variable, add the meta first, then its
actual declaration. *)
List.fold output_vars ~init:task
~f:(fun task (index, output_var) ->
let task =
Task.add_meta task Meta.meta_output
Why3.Task.add_meta task Meta.meta_output
[ MAls output_var; MAint index ]
in
let decl = Decl.create_param_decl output_var in
Task.add_decl task decl))
let decl = Why3.Decl.create_param_decl output_var in
Why3.Task.add_decl task decl))
output_vars task
in
Task.add_decl task decl
Why3.Task.add_decl task decl
| Use _ | Clone _ | Meta _ | Decl _ ->
Task.add_tdecl task task_hd.task_decl)
Why3.Task.add_tdecl task task_hd.task_decl)
None
let trans =
Trans.bind Utils.input_vars (fun input_vars ->
Trans.bind create_output_vars (fun output_vars ->
simplify_nn_application input_vars output_vars))
Why3.Trans.bind Utils.input_terms (function
| Utils.Others -> abstract_nn_term
| Vars input_vars ->
Why3.Trans.bind output_vars (fun output_vars ->
simplify_nn_application input_vars output_vars))
......@@ -22,8 +22,6 @@
open Base
let src = Logs.Src.create "TransformationsUtils" ~doc:"Transformation utils"
let float_of_real_constant rc =
let is_neg, rc =
( Why3.(BigInt.lt rc.Number.rl_real.rv_sig BigInt.zero),
......@@ -36,22 +34,26 @@ let float_of_real_constant rc =
(* Collects input variables and their position inside input vectors. Such
process is memoized wrt lsymbols corresponding to input vectors. *)
type position_input_vars =
(int Why3.Term.Mls.t
[@printer
fun fmt mls ->
Why3.(
Pp.print_iter2 Term.Mls.iter Pp.newline Pp.comma Pretty.print_ls Pp.int
fmt mls)])
(* Terms forming vectors in input to models. *)
type input_terms =
| Vars of
(int Why3.Term.Mls.t
[@printer
fun fmt mls ->
Why3.(
Pp.print_iter2 Term.Mls.iter Pp.newline Pp.comma Pretty.print_ls
Pp.int fmt mls)])
(* A map from input variable lsymbols to corresponding positions in input
vectors. *)
| Others (* Generic terms. *)
[@@deriving show]
let input_vars : position_input_vars Why3.Trans.trans =
let input_terms : input_terms Why3.Trans.trans =
let exception NotInputVariable in
let hls = Why3.Term.Hls.create 13 in
let add index mls = function
| { Why3.Term.t_node = Tapp (ls, []); _ } -> Why3.Term.Mls.add ls index mls
| t ->
Logging.code_error ~src (fun m ->
m "Not an input variable: %a" Why3.Pretty.print_term t)
| _ -> raise NotInputVariable
in
let rec do_collect mls t =
match t.Why3.Term.t_node with
......@@ -72,8 +74,13 @@ let input_vars : position_input_vars Why3.Trans.trans =
| _ -> Why3.Term.t_fold do_collect mls t
in
Why3.Trans.fold_decl
(fun decl mls -> Why3.Decl.decl_fold do_collect mls decl)
Why3.Term.Mls.empty
(fun decl mls ->
match mls with
| Vars mls -> (
try Vars (Why3.Decl.decl_fold do_collect mls decl)
with NotInputVariable -> Others)
| Others -> Others)
(Vars Why3.Term.Mls.empty)
(* Collects input feature values (these are typically coming from a data point
of a data set).
......
......@@ -20,8 +20,13 @@
(* *)
(**************************************************************************)
type position_input_vars = int Why3.Term.Mls.t [@@deriving show]
(** Map input variable lsymbols to corresponding position in input vectors. *)
(** Terms forming vectors in input to models. *)
type input_terms =
| Vars of int Why3.Term.Mls.t
(** A map from input variable lsymbols to corresponding positions in input
vectors. *)
| Others (** Generic terms. *)
[@@deriving show]
type features = Float.t Why3.Term.Mls.t [@@deriving show]
(** Input feature values. *)
......@@ -34,7 +39,7 @@ and interval = float option * float option [@@deriving show]
type label = int [@@deriving show]
(** Output label. *)
val input_vars : position_input_vars Why3.Trans.trans
val input_terms : input_terms Why3.Trans.trans
val input_features :
Why3.Env.env -> vars:Why3.Term.lsymbol list -> features Why3.Trans.trans
......
......@@ -273,7 +273,12 @@ let answer_saver limit config env config_prover ~proof_strategy task =
(svm_filename, Language.svm_abstraction_of_string svm_abstraction)
| Some _ -> assert false (* By construction of the meta. *)
in
let vars = Why3.Term.Mls.keys (Trans.apply Utils.input_vars task) in
let vars =
match Trans.apply Utils.input_terms task with
| Utils.Vars mls -> Why3.Term.Mls.keys mls
| Others ->
Logging.user_error (fun m -> m "Cannot determine input variables")
in
let dataset : Csv.t =
let features = Trans.apply (Utils.input_features env ~vars) task in
let features =
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment