diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 177220b93172bdb6d8daa30af36807f5a3bac3fb..9bb96dd31c1ae22eb8e76203e4ba47d5e1bd67e4 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -95,6 +95,28 @@ let simplify_goal env input_variables = let trans_nn_apply env = Trans.bind Utils.get_input_variables (simplify_goal env) +let get_input_variables = + let add i acc = function + | { Term.t_node = Tapp (vs, []); _ } -> Term.Mls.add vs i acc + | arg -> + invalid_arg + (Fmt.str "No direct variable in application: %a" Pretty.print_term arg) + in + let rec aux acc (term : Term.term) = + match term.t_node with + | Term.Tapp + ( { ls_name; _ }, + [ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] ) + when String.equal ls_name.id_string (Ident.op_infix "%%") -> ( + match (Language.lookup_nn_classifier ls1, Language.lookup_vector ls2) with + | Some { nn_inputs; _ }, Some n -> + assert (nn_inputs = n && n = List.length args); + List.foldi ~init:acc ~f:add args + | _ -> acc) + | _ -> Term.t_fold aux acc term + in + Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) Term.Mls.empty + (* Create logic symbols for output variables and simplify the formula. *) let simplify_goal _env input_variables = let rec aux hls (term : Term.term) = @@ -125,19 +147,21 @@ let simplify_goal _env input_variables = let index = Number.to_small_integer i in let hout = Hashtbl.update_and_return hls nn.nn_filename ~f:(fun hout -> - let ls = + let create_ls_output () = let id = Ident.id_fresh "y" in Term.create_fsymbol id [] nn.nn_ty_elt in match hout with | None -> let hout = Hashtbl.create (module Int) in + let ls = create_ls_output () in Hashtbl.add_exn hout ~key:index ~data:ls; hout | Some hout -> Hashtbl.update hout index ~f:(fun lsout -> match lsout with | None -> + let ls = create_ls_output () in Hashtbl.add_exn hout ~key:index ~data:ls; ls | Some ls -> ls); @@ -174,5 +198,4 @@ let simplify_goal _env input_variables = Task.add_decl acc decl) None -let trans_nn_classifier env = - Trans.bind Utils.get_input_variables (simplify_goal env) +let trans_nn_classifier env = Trans.bind get_input_variables (simplify_goal env) diff --git a/src/transformations/utils.ml b/src/transformations/utils.ml index 26966f6c5c2384e1c292c79bf483dea134d958c5..892ad38e3ad8c6876d12bc05a132e40e003c710b 100644 --- a/src/transformations/utils.ml +++ b/src/transformations/utils.ml @@ -60,15 +60,6 @@ let get_input_variables = in let rec aux acc (term : Term.term) = match term.t_node with - | Term.Tapp - ( { ls_name; _ }, - [ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] ) - when String.equal ls_name.id_string (Ident.op_infix "%%") -> ( - match (Language.lookup_nn_classifier ls1, Language.lookup_vector ls2) with - | Some { nn_inputs; _ }, Some n -> - assert (nn_inputs = n && n = List.length args); - List.foldi ~init:acc ~f:add args - | _ -> acc) | Term.Tapp (ls, args) -> ( match Language.lookup_loaded_nets ls with | None -> acc