diff --git a/src/language.ml b/src/language.ml
index 33bd73063a80b3d472aa03f7e4b9fcbbf9945c54..34b21be09adb4860299d9b9d81a7ea55d7dbe409 100644
--- a/src/language.ml
+++ b/src/language.ml
@@ -8,6 +8,13 @@ open Base
 
 (* -- Support for the NNet neural network format. *)
 
+type nnet = {
+  nb_inputs : int;
+  nb_outputs : int;
+  ty_data : Why3.Ty.ty;
+  filename : string;
+}
+
 let loaded_nnets = Why3.Term.Hls.create 10
 
 let lookup_loaded_nnets = Why3.Term.Hls.find_opt loaded_nnets
@@ -34,7 +41,13 @@ let nnet_parser env _ filename _ =
         (List.init header.n_inputs ~f)
         (Ty.ty_tuple (List.init header.n_outputs ~f))
     in
-    Why3.Term.Hls.add loaded_nnets ls_nnet_apply filename;
+    Why3.Term.Hls.add loaded_nnets ls_nnet_apply
+      {
+        filename;
+        nb_inputs = header.n_inputs;
+        nb_outputs = header.n_outputs;
+        ty_data = nnet_input_type;
+      };
     let th_uc =
       Pmodule.add_pdecl ~vc:false th_uc
         (Pdecl.create_pure_decl (Decl.create_param_decl ls_nnet_apply))
diff --git a/src/language.mli b/src/language.mli
index 5a8c544502e299eca9ebafb2421b9cf94fe6a573..23ef26ce41c38fd194c55736af8953ab786b1a90 100644
--- a/src/language.mli
+++ b/src/language.mli
@@ -1,4 +1,11 @@
-val lookup_loaded_nnets : Why3.Term.lsymbol -> string option
+type nnet = {
+  nb_inputs : int;
+  nb_outputs : int;
+  ty_data : Why3.Ty.ty;
+  filename : string;
+}
+
+val lookup_loaded_nnets : Why3.Term.lsymbol -> nnet option
 (** Return the filename of an nnets Why3 representation *)
 
 val register_nnet_support : unit -> unit
diff --git a/src/transformations.ml b/src/transformations.ml
index 94877c2ca39fefb362bc642538e9f4401fb0a074..0067485960b953ab9c75057e7d2407da46544cd6 100644
--- a/src/transformations.ml
+++ b/src/transformations.ml
@@ -6,7 +6,7 @@ let get_input_variables =
     | Why3.Term.Tapp (ls, args) -> (
       match Language.lookup_loaded_nnets ls with
       | None -> acc
-      | Some _name ->
+      | Some _ ->
         let add acc = function
           | { Why3.Term.t_node = Tapp (vs, []); _ } -> Why3.Term.Sls.add vs acc
           | arg ->
@@ -21,11 +21,67 @@ let get_input_variables =
     (fun decl acc -> Why3.Decl.decl_fold aux acc decl)
     Why3.Term.Sls.empty
 
-let simplify_goal _input_variables = Why3.Trans.identity
+let simplify_goal env _input_variables =
+  let rec aux hls (term : Why3.Term.term) =
+    match term.t_node with
+    | Why3.Term.Tapp (ls, _) -> (
+      match Language.lookup_loaded_nnets ls with
+      | None -> Why3.Term.t_map (aux hls) term
+      | Some nnet ->
+        let outputs =
+          List.init nnet.nb_outputs ~f:(fun _ ->
+            let open Why3 in
+            let id = Ident.id_fresh "y" in
+            let ls = Term.create_fsymbol id [] nnet.ty_data in
+            hls := Why3.Decl.create_param_decl ls :: !hls;
+            Term.fs_app ls [] nnet.ty_data)
+        in
+        Why3.Term.t_tuple outputs)
+    | _ -> Why3.Term.t_map (aux hls) term
+  in
+  Why3.Trans.fold
+    (fun task_hd acc ->
+      match task_hd.task_decl.td_node with
+      | Use _ | Clone _ | Meta _ -> Why3.Task.add_tdecl acc task_hd.task_decl
+      | Decl decl ->
+        let hls = ref [] in
+        let map term =
+          let term = aux hls term in
+          if List.is_empty !hls
+          then term
+          else
+            let known =
+              List.fold (List.rev !hls) ~init:task_hd.task_known
+                ~f:Why3.Decl.known_add_decl
+            in
+            let engine =
+              Why3.Reduction_engine.create
+                {
+                  compute_defs = false;
+                  compute_builtin = true;
+                  compute_def_set = Why3.Term.Sls.empty;
+                }
+                env known
+            in
+            Why3.Reduction_engine.normalize ~limit:100 engine
+              Why3.Term.Mvs.empty term
+        in
+        let decl = Why3.Decl.decl_map map decl in
+        let acc =
+          List.fold (List.rev !hls) ~init:acc ~f:(fun acc ls ->
+            Why3.Task.add_decl acc ls)
+        in
+        Why3.Task.add_decl acc decl)
+    None
 
-let caisar_native_prover = Why3.Trans.bind get_input_variables simplify_goal
+let caisar_native_prover env =
+  Why3.Trans.seq
+    [
+      Why3.Trans.bind get_input_variables (simplify_goal env);
+      (* Why3.Simplify_formula.simplify_; *)
+    ]
 
 let init () =
-  Why3.Trans.register_transform
+  Why3.Trans.register_env_transform
     ~desc:"Transformation for provers that support loading neural networks."
     "caisar_native_prover" caisar_native_prover
diff --git a/tests/simple.t b/tests/simple.t
index b608821eed730a611d876e8e6d9c027290516a8a..c8b5c220e5bccac74d6f1db27859d31285ec3235 100644
--- a/tests/simple.t
+++ b/tests/simple.t
@@ -1172,8 +1172,16 @@ Test verify
   
   axiom H1 [@introduced] : lt x1 0.5
   
-  goal G : match nnet_apply x1 x2 x3 x4 x5 with
-    | Tuple5 y1 _ _ _ _ -> lt 0.0 y1 /\ lt y1 0.5
-    end
+  function y : t19
+  
+  function y1 : t19
+  
+  function y2 : t19
+  
+  function y3 : t19
+  
+  function y4 : t19
+  
+  goal G : lt 0.0 y4 /\ lt y4 0.5
   
   end