diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index 177220b93172bdb6d8daa30af36807f5a3bac3fb..9bb96dd31c1ae22eb8e76203e4ba47d5e1bd67e4 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -95,6 +95,28 @@ let simplify_goal env input_variables =
 let trans_nn_apply env =
   Trans.bind Utils.get_input_variables (simplify_goal env)
 
+let get_input_variables =
+  let add i acc = function
+    | { Term.t_node = Tapp (vs, []); _ } -> Term.Mls.add vs i acc
+    | arg ->
+      invalid_arg
+        (Fmt.str "No direct variable in application: %a" Pretty.print_term arg)
+  in
+  let rec aux acc (term : Term.term) =
+    match term.t_node with
+    | Term.Tapp
+        ( { ls_name; _ },
+          [ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] )
+      when String.equal ls_name.id_string (Ident.op_infix "%%") -> (
+      match (Language.lookup_nn_classifier ls1, Language.lookup_vector ls2) with
+      | Some { nn_inputs; _ }, Some n ->
+        assert (nn_inputs = n && n = List.length args);
+        List.foldi ~init:acc ~f:add args
+      | _ -> acc)
+    | _ -> Term.t_fold aux acc term
+  in
+  Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) Term.Mls.empty
+
 (* Create logic symbols for output variables and simplify the formula. *)
 let simplify_goal _env input_variables =
   let rec aux hls (term : Term.term) =
@@ -125,19 +147,21 @@ let simplify_goal _env input_variables =
         let index = Number.to_small_integer i in
         let hout =
           Hashtbl.update_and_return hls nn.nn_filename ~f:(fun hout ->
-            let ls =
+            let create_ls_output () =
               let id = Ident.id_fresh "y" in
               Term.create_fsymbol id [] nn.nn_ty_elt
             in
             match hout with
             | None ->
               let hout = Hashtbl.create (module Int) in
+              let ls = create_ls_output () in
               Hashtbl.add_exn hout ~key:index ~data:ls;
               hout
             | Some hout ->
               Hashtbl.update hout index ~f:(fun lsout ->
                 match lsout with
                 | None ->
+                  let ls = create_ls_output () in
                   Hashtbl.add_exn hout ~key:index ~data:ls;
                   ls
                 | Some ls -> ls);
@@ -174,5 +198,4 @@ let simplify_goal _env input_variables =
         Task.add_decl acc decl)
     None
 
-let trans_nn_classifier env =
-  Trans.bind Utils.get_input_variables (simplify_goal env)
+let trans_nn_classifier env = Trans.bind get_input_variables (simplify_goal env)
diff --git a/src/transformations/utils.ml b/src/transformations/utils.ml
index 26966f6c5c2384e1c292c79bf483dea134d958c5..892ad38e3ad8c6876d12bc05a132e40e003c710b 100644
--- a/src/transformations/utils.ml
+++ b/src/transformations/utils.ml
@@ -60,15 +60,6 @@ let get_input_variables =
   in
   let rec aux acc (term : Term.term) =
     match term.t_node with
-    | Term.Tapp
-        ( { ls_name; _ },
-          [ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] )
-      when String.equal ls_name.id_string (Ident.op_infix "%%") -> (
-      match (Language.lookup_nn_classifier ls1, Language.lookup_vector ls2) with
-      | Some { nn_inputs; _ }, Some n ->
-        assert (nn_inputs = n && n = List.length args);
-        List.foldi ~init:acc ~f:add args
-      | _ -> acc)
     | Term.Tapp (ls, args) -> (
       match Language.lookup_loaded_nets ls with
       | None -> acc