Skip to content
Snippets Groups Projects
Commit b6a6049a authored by François Bobot's avatar François Bobot Committed by Michele Alberti
Browse files

[ir] Add new NIER AST.

parent 31c025cd
No related branches found
No related tags found
No related merge requests found
...@@ -20,36 +20,154 @@ ...@@ -20,36 +20,154 @@
(* *) (* *)
(**************************************************************************) (**************************************************************************)
open Why3
open Base open Base
type new_output = {
old_index : Why3.Number.int_constant;
old_nn : Language.nn;
new_index : int;
old_nn_args : Why3.Term.term list;
}
(** invariant:
- input_vars from 0 to its length - 1
- outputs from 0 to its length - 1 *)
(* let create_new_nn input_vars outputs = let id = ref (-1) in let mk ?name ~sh
~op ~op_p ?(pred=[]) ?(succ=[]) ?tensor = incr id; Ir.Nier_cfg.Node.create
~id:(!id) ~name:name ~pred ~succ ~tensor:tensor in let x =
Ir.Nier_cfg.NierCFGFloat.add_edge in outputs *)
(** Abstract terms that contains neural network application *)
let abstract_nn_term =
let rec do_simplify (new_index, output_vars) term =
match term.Why3.Term.t_node with
| Tapp
( get (* [ ] *),
[
{
t_node =
Why3.Term.Tapp
( ls1 (* @@ *),
[
{ t_node = Tapp (ls2 (* nn *), _); _ };
{
t_node =
Tapp (ls3 (* input vector *), tl (* input vars *));
_;
};
] );
_;
};
({ t_node = Tconst (ConstInt old_index); _ } as _t2);
] )
when String.equal get.ls_name.id_string (Why3.Ident.op_get "")
&& String.equal ls1.ls_name.id_string (Why3.Ident.op_infix "@@") -> (
match (Language.lookup_nn ls2, Language.lookup_vector ls3) with
| Some ({ nn_nb_inputs; nn_ty_elt; _ } as old_nn), Some vector_length ->
assert (nn_nb_inputs = vector_length && vector_length = List.length tl);
let ls = Why3.(Term.create_fsymbol (Ident.id_fresh "y") [] nn_ty_elt) in
let term = Why3.Term.fs_app ls [] nn_ty_elt in
( ( new_index + 1,
({ old_index; new_index; old_nn; old_nn_args = tl }, ls)
:: output_vars ),
term )
| _, _ ->
failwith
(Fmt.str "nn application without fixed NeuralNetwork or Arguments: %a"
Why3.Pretty.print_term term))
| _ -> Why3.Term.t_map_fold do_simplify (new_index, output_vars) term
in
Why3.Trans.fold_map
(fun task_hd (((index, input_vars) as acc), task) ->
let tdecl = task_hd.task_decl in
match tdecl.td_node with
| Decl
{
d_node =
Dparam ({ ls_args = []; ls_constr = 0; ls_proj = false; _ } as ls);
_;
}
when Language.mem_nn ls ->
(* All neural networks are removed *) (acc, task)
| Decl
{
d_node =
Dparam ({ ls_args = []; ls_constr = 0; ls_proj = false; _ } as ls);
_;
} ->
(* Add meta for input variable declarations. Note that each meta needs
to appear before the corresponding declaration in order to be
leveraged by prover printers. *)
let task =
Why3.Task.add_meta task Meta.meta_input [ MAls ls; MAint index ]
in
let task = Why3.Task.add_tdecl task tdecl in
let index = index + 1 in
let input_vars = Why3.Term.Mls.add ls index input_vars in
((index, input_vars), task)
| Decl { d_node = Dprop (Pgoal, pr, goal); _ } ->
let (_, output_vars), goal = do_simplify (0, []) goal in
let pr = Why3.Decl.create_prsymbol (Why3.Ident.id_clone pr.pr_name) in
let decl = Why3.Decl.create_prop_decl Pgoal pr goal in
(* Again, for each output variable, add the meta first, then its actual
declaration. *)
List.iter output_vars ~f:(fun (out, var) ->
ignore out.old_nn;
Fmt.epr "%a: %a -> %i: %a@." Why3.Pretty.print_ls var
Why3.(Number.print_int_constant Number.full_support)
out.old_index out.new_index
(Fmt.list ~sep:Fmt.comma Why3.Pretty.print_term)
out.old_nn_args);
let task =
List.fold output_vars ~init:task
~f:(fun task ({ new_index; _ }, output_var) ->
let task =
Why3.Task.add_meta task Meta.meta_output
[ MAls output_var; MAint new_index ]
in
let decl = Why3.Decl.create_param_decl output_var in
Why3.Task.add_decl task decl)
in
(acc, Why3.Task.add_decl task decl)
| Decl { d_node = Dprop (_, _, _); _ } ->
(* TODO: check no nn applications *)
(acc, Why3.Task.add_tdecl task tdecl)
| _ -> (acc, Why3.Task.add_tdecl task tdecl))
(0, Why3.Term.Mls.empty) None
(** {2 Old way} *)
(* Creates a list of pairs made of output variables and respective indices in (* Creates a list of pairs made of output variables and respective indices in
the list, for each neural network application to an input vector appearing in the list, for each neural network application to an input vector appearing in
a task. Such a list stands for the resulting output vector of a neural a task. Such a list stands for the resulting output vector of a neural
network application to an input vector (ie, something of the form: nn @@ v). network application to an input vector (ie, something of the form: nn @@ v).
The creation process is memoized wrt terms corresponding to neural network The creation process is memoized wrt terms corresponding to neural network
applications to input vectors. *) applications to input vectors. *)
let create_output_vars = let output_vars =
let rec do_create mt (term : Term.term) = let rec create_output_var mt (term : Why3.Term.term) =
match term.t_node with match term.t_node with
| Term.Tapp (ls1 (* @@ *), [ { t_node = Tapp (ls2 (* nn *), _); _ }; _ ]) | Why3.Term.Tapp
when String.equal ls1.ls_name.id_string (Ident.op_infix "@@") -> ( (ls1 (* @@ *), [ { t_node = Tapp (ls2 (* nn *), _); _ }; _ ])
when String.equal ls1.ls_name.id_string (Why3.Ident.op_infix "@@") -> (
match Language.lookup_nn ls2 with match Language.lookup_nn ls2 with
| Some { nn_nb_outputs; nn_ty_elt; _ } -> | Some { nn_nb_outputs; nn_ty_elt; _ } ->
if Term.Mterm.mem term mt if Why3.Term.Mterm.mem term mt
then mt then mt
else else
let output_vars = let output_vars =
List.init nn_nb_outputs ~f:(fun index -> List.init nn_nb_outputs ~f:(fun index ->
(index, Term.create_fsymbol (Ident.id_fresh "y") [] nn_ty_elt)) ( index,
Why3.Term.create_fsymbol (Why3.Ident.id_fresh "y") [] nn_ty_elt
))
in in
Term.Mterm.add term output_vars mt Why3.Term.Mterm.add term output_vars mt
| None -> mt) | None -> mt)
| _ -> Term.t_fold do_create mt term | _ -> Why3.Term.t_fold create_output_var mt term
in in
Trans.fold_decl Why3.Trans.fold_decl
(fun decl mt -> Decl.decl_fold do_create mt decl) (fun decl mt -> Why3.Decl.decl_fold create_output_var mt decl)
Term.Mterm.empty Why3.Term.Mterm.empty
(* Simplifies a task goal exhibiting a vector selection on a neural network (* Simplifies a task goal exhibiting a vector selection on a neural network
application to an input vector (ie, (nn @@ v)[_]) by the corresponding output application to an input vector (ie, (nn @@ v)[_]) by the corresponding output
...@@ -58,26 +176,27 @@ let create_output_vars = ...@@ -58,26 +176,27 @@ let create_output_vars =
all declared, each with a meta that describes the respective index in the all declared, each with a meta that describes the respective index in the
output vector. *) output vector. *)
let simplify_nn_application input_vars output_vars = let simplify_nn_application input_vars output_vars =
let rec do_simplify (term : Term.term) = let rec do_simplify (term : Why3.Term.term) =
match term.t_node with match term.t_node with
| Term.Tapp | Why3.Term.Tapp
( ls_get (* [_] *), ( ls_get (* [_] *),
[ [
({ t_node = Tapp (ls_atat (* @@ *), _); _ } as t1); ({ t_node = Tapp (ls_atat (* @@ *), _); _ } as t1);
({ t_node = Tconst (ConstInt index); _ } as _t2); ({ t_node = Tconst (ConstInt index); _ } as _t2);
] ) ] )
when String.equal ls_get.ls_name.id_string (Ident.op_get "") when String.equal ls_get.ls_name.id_string (Why3.Ident.op_get "")
&& String.equal ls_atat.ls_name.id_string (Ident.op_infix "@@") -> ( && String.equal ls_atat.ls_name.id_string (Why3.Ident.op_infix "@@")
match Term.Mterm.find_opt t1 output_vars with -> (
| None -> Term.t_map do_simplify term match Why3.Term.Mterm.find_opt t1 output_vars with
| None -> Why3.Term.t_map do_simplify term
| Some output_vars -> | Some output_vars ->
let index = Number.to_small_integer index in let index = Why3.Number.to_small_integer index in
assert (index < List.length output_vars); assert (index < List.length output_vars);
let ls = Stdlib.List.assoc index output_vars in let ls = Stdlib.List.assoc index output_vars in
Term.t_app_infer ls []) Why3.Term.t_app_infer ls [])
| _ -> Term.t_map do_simplify term | _ -> Why3.Term.t_map do_simplify term
in in
Trans.fold Why3.Trans.fold
(fun task_hd task -> (fun task_hd task ->
match task_hd.task_decl.td_node with match task_hd.task_decl.td_node with
| Decl { d_node = Dparam ls; _ } -> | Decl { d_node = Dparam ls; _ } ->
...@@ -85,38 +204,40 @@ let simplify_nn_application input_vars output_vars = ...@@ -85,38 +204,40 @@ let simplify_nn_application input_vars output_vars =
to appear before the corresponding declaration in order to be to appear before the corresponding declaration in order to be
leveraged by prover printers. *) leveraged by prover printers. *)
let task = let task =
match Term.Mls.find_opt ls input_vars with match Why3.Term.Mls.find_opt ls input_vars with
| None -> task | None -> task
| Some index -> | Some index ->
Task.add_meta task Meta.meta_input [ MAls ls; MAint index ] Why3.Task.add_meta task Meta.meta_input [ MAls ls; MAint index ]
in in
Task.add_tdecl task task_hd.task_decl Why3.Task.add_tdecl task task_hd.task_decl
| Decl ({ d_node = Dprop (Pgoal, _, _); _ } as decl) -> | Decl ({ d_node = Dprop (Pgoal, _, _); _ } as decl) ->
let decl = Decl.decl_map do_simplify decl in let decl = Why3.Decl.decl_map do_simplify decl in
let task = let task =
(* Output variables are not declared yet in the task as they are (* Output variables are not declared yet in the task as they are
created on the fly for each (different) neural network application created on the fly for each (different) neural network application
on an input vector. We add here their declarations in the task. *) on an input vector. We add here their declarations in the task. *)
Term.Mterm.fold Why3.Term.Mterm.fold
(fun _t output_vars task -> (fun _t output_vars task ->
(* Again, for each output variable, add the meta first, then its (* Again, for each output variable, add the meta first, then its
actual declaration. *) actual declaration. *)
List.fold output_vars ~init:task List.fold output_vars ~init:task
~f:(fun task (index, output_var) -> ~f:(fun task (index, output_var) ->
let task = let task =
Task.add_meta task Meta.meta_output Why3.Task.add_meta task Meta.meta_output
[ MAls output_var; MAint index ] [ MAls output_var; MAint index ]
in in
let decl = Decl.create_param_decl output_var in let decl = Why3.Decl.create_param_decl output_var in
Task.add_decl task decl)) Why3.Task.add_decl task decl))
output_vars task output_vars task
in in
Task.add_decl task decl Why3.Task.add_decl task decl
| Use _ | Clone _ | Meta _ | Decl _ -> | Use _ | Clone _ | Meta _ | Decl _ ->
Task.add_tdecl task task_hd.task_decl) Why3.Task.add_tdecl task task_hd.task_decl)
None None
let trans = let trans =
Trans.bind Utils.input_vars (fun input_vars -> Why3.Trans.bind Utils.input_terms (function
Trans.bind create_output_vars (fun output_vars -> | Utils.Others -> abstract_nn_term
simplify_nn_application input_vars output_vars)) | Vars input_vars ->
Why3.Trans.bind output_vars (fun output_vars ->
simplify_nn_application input_vars output_vars))
...@@ -22,8 +22,6 @@ ...@@ -22,8 +22,6 @@
open Base open Base
let src = Logs.Src.create "TransformationsUtils" ~doc:"Transformation utils"
let float_of_real_constant rc = let float_of_real_constant rc =
let is_neg, rc = let is_neg, rc =
( Why3.(BigInt.lt rc.Number.rl_real.rv_sig BigInt.zero), ( Why3.(BigInt.lt rc.Number.rl_real.rv_sig BigInt.zero),
...@@ -36,22 +34,26 @@ let float_of_real_constant rc = ...@@ -36,22 +34,26 @@ let float_of_real_constant rc =
(* Collects input variables and their position inside input vectors. Such (* Collects input variables and their position inside input vectors. Such
process is memoized wrt lsymbols corresponding to input vectors. *) process is memoized wrt lsymbols corresponding to input vectors. *)
type position_input_vars = (* Terms forming vectors in input to models. *)
(int Why3.Term.Mls.t type input_terms =
[@printer | Vars of
fun fmt mls -> (int Why3.Term.Mls.t
Why3.( [@printer
Pp.print_iter2 Term.Mls.iter Pp.newline Pp.comma Pretty.print_ls Pp.int fun fmt mls ->
fmt mls)]) Why3.(
Pp.print_iter2 Term.Mls.iter Pp.newline Pp.comma Pretty.print_ls
Pp.int fmt mls)])
(* A map from input variable lsymbols to corresponding positions in input
vectors. *)
| Others (* Generic terms. *)
[@@deriving show] [@@deriving show]
let input_vars : position_input_vars Why3.Trans.trans = let input_terms : input_terms Why3.Trans.trans =
let exception NotInputVariable in
let hls = Why3.Term.Hls.create 13 in let hls = Why3.Term.Hls.create 13 in
let add index mls = function let add index mls = function
| { Why3.Term.t_node = Tapp (ls, []); _ } -> Why3.Term.Mls.add ls index mls | { Why3.Term.t_node = Tapp (ls, []); _ } -> Why3.Term.Mls.add ls index mls
| t -> | _ -> raise NotInputVariable
Logging.code_error ~src (fun m ->
m "Not an input variable: %a" Why3.Pretty.print_term t)
in in
let rec do_collect mls t = let rec do_collect mls t =
match t.Why3.Term.t_node with match t.Why3.Term.t_node with
...@@ -72,8 +74,13 @@ let input_vars : position_input_vars Why3.Trans.trans = ...@@ -72,8 +74,13 @@ let input_vars : position_input_vars Why3.Trans.trans =
| _ -> Why3.Term.t_fold do_collect mls t | _ -> Why3.Term.t_fold do_collect mls t
in in
Why3.Trans.fold_decl Why3.Trans.fold_decl
(fun decl mls -> Why3.Decl.decl_fold do_collect mls decl) (fun decl mls ->
Why3.Term.Mls.empty match mls with
| Vars mls -> (
try Vars (Why3.Decl.decl_fold do_collect mls decl)
with NotInputVariable -> Others)
| Others -> Others)
(Vars Why3.Term.Mls.empty)
(* Collects input feature values (these are typically coming from a data point (* Collects input feature values (these are typically coming from a data point
of a data set). of a data set).
......
...@@ -20,8 +20,13 @@ ...@@ -20,8 +20,13 @@
(* *) (* *)
(**************************************************************************) (**************************************************************************)
type position_input_vars = int Why3.Term.Mls.t [@@deriving show] (** Terms forming vectors in input to models. *)
(** Map input variable lsymbols to corresponding position in input vectors. *) type input_terms =
| Vars of int Why3.Term.Mls.t
(** A map from input variable lsymbols to corresponding positions in input
vectors. *)
| Others (** Generic terms. *)
[@@deriving show]
type features = Float.t Why3.Term.Mls.t [@@deriving show] type features = Float.t Why3.Term.Mls.t [@@deriving show]
(** Input feature values. *) (** Input feature values. *)
...@@ -34,7 +39,7 @@ and interval = float option * float option [@@deriving show] ...@@ -34,7 +39,7 @@ and interval = float option * float option [@@deriving show]
type label = int [@@deriving show] type label = int [@@deriving show]
(** Output label. *) (** Output label. *)
val input_vars : position_input_vars Why3.Trans.trans val input_terms : input_terms Why3.Trans.trans
val input_features : val input_features :
Why3.Env.env -> vars:Why3.Term.lsymbol list -> features Why3.Trans.trans Why3.Env.env -> vars:Why3.Term.lsymbol list -> features Why3.Trans.trans
......
...@@ -273,7 +273,12 @@ let answer_saver limit config env config_prover ~proof_strategy task = ...@@ -273,7 +273,12 @@ let answer_saver limit config env config_prover ~proof_strategy task =
(svm_filename, Language.svm_abstraction_of_string svm_abstraction) (svm_filename, Language.svm_abstraction_of_string svm_abstraction)
| Some _ -> assert false (* By construction of the meta. *) | Some _ -> assert false (* By construction of the meta. *)
in in
let vars = Why3.Term.Mls.keys (Trans.apply Utils.input_vars task) in let vars =
match Trans.apply Utils.input_terms task with
| Utils.Vars mls -> Why3.Term.Mls.keys mls
| Others ->
Logging.user_error (fun m -> m "Cannot determine input variables")
in
let dataset : Csv.t = let dataset : Csv.t =
let features = Trans.apply (Utils.input_features env ~vars) task in let features = Trans.apply (Utils.input_features env ~vars) task in
let features = let features =
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment