Skip to content
Snippets Groups Projects
Commit 0af3edb2 authored by Michele Alberti's avatar Michele Alberti
Browse files

[trans] Separate utility functions for gathering input symbols.

parent 0bebd680
No related branches found
No related tags found
No related merge requests found
......@@ -95,6 +95,28 @@ let simplify_goal env input_variables =
let trans_nn_apply env =
Trans.bind Utils.get_input_variables (simplify_goal env)
let get_input_variables =
let add i acc = function
| { Term.t_node = Tapp (vs, []); _ } -> Term.Mls.add vs i acc
| arg ->
invalid_arg
(Fmt.str "No direct variable in application: %a" Pretty.print_term arg)
in
let rec aux acc (term : Term.term) =
match term.t_node with
| Term.Tapp
( { ls_name; _ },
[ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] )
when String.equal ls_name.id_string (Ident.op_infix "%%") -> (
match (Language.lookup_nn_classifier ls1, Language.lookup_vector ls2) with
| Some { nn_inputs; _ }, Some n ->
assert (nn_inputs = n && n = List.length args);
List.foldi ~init:acc ~f:add args
| _ -> acc)
| _ -> Term.t_fold aux acc term
in
Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) Term.Mls.empty
(* Create logic symbols for output variables and simplify the formula. *)
let simplify_goal _env input_variables =
let rec aux hls (term : Term.term) =
......@@ -125,19 +147,21 @@ let simplify_goal _env input_variables =
let index = Number.to_small_integer i in
let hout =
Hashtbl.update_and_return hls nn.nn_filename ~f:(fun hout ->
let ls =
let create_ls_output () =
let id = Ident.id_fresh "y" in
Term.create_fsymbol id [] nn.nn_ty_elt
in
match hout with
| None ->
let hout = Hashtbl.create (module Int) in
let ls = create_ls_output () in
Hashtbl.add_exn hout ~key:index ~data:ls;
hout
| Some hout ->
Hashtbl.update hout index ~f:(fun lsout ->
match lsout with
| None ->
let ls = create_ls_output () in
Hashtbl.add_exn hout ~key:index ~data:ls;
ls
| Some ls -> ls);
......@@ -174,5 +198,4 @@ let simplify_goal _env input_variables =
Task.add_decl acc decl)
None
let trans_nn_classifier env =
Trans.bind Utils.get_input_variables (simplify_goal env)
let trans_nn_classifier env = Trans.bind get_input_variables (simplify_goal env)
......@@ -60,15 +60,6 @@ let get_input_variables =
in
let rec aux acc (term : Term.term) =
match term.t_node with
| Term.Tapp
( { ls_name; _ },
[ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] )
when String.equal ls_name.id_string (Ident.op_infix "%%") -> (
match (Language.lookup_nn_classifier ls1, Language.lookup_vector ls2) with
| Some { nn_inputs; _ }, Some n ->
assert (nn_inputs = n && n = List.length args);
List.foldi ~init:acc ~f:add args
| _ -> acc)
| Term.Tapp (ls, args) -> (
match Language.lookup_loaded_nets ls with
| None -> acc
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment