From 8ca73d5f2924ddb81804e9afeacd4941c2ad3ec2 Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Wed, 31 May 2023 16:36:43 +0200
Subject: [PATCH] [trans] Better code style.

---
 src/transformations/native_nn_prover.ml | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index a5a1d23..8d3d943 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -42,11 +42,9 @@ let collect_input_vars =
           ] )
       when String.equal ls1.ls_name.id_string (Ident.op_infix "@@") -> (
       match (Language.lookup_nn ls2, Language.lookup_vector ls3) with
-      | Some { nn_nb_inputs; _ }, Some vector_length -> (
+      | Some { nn_nb_inputs; _ }, Some vector_length ->
         assert (nn_nb_inputs = vector_length && vector_length = List.length tl);
-        match Term.Hls.find_opt hls ls3 with
-        | None -> List.foldi ~init:mls ~f:add tl
-        | Some _ -> mls)
+        if Term.Hls.mem hls ls3 then mls else List.foldi ~init:mls ~f:add tl
       | _, _ -> mls)
     | _ -> Term.t_fold do_collect mls term
   in
@@ -66,17 +64,17 @@ let create_output_vars =
     | Term.Tapp (ls1 (* @@ *), [ { t_node = Tapp (ls2 (* nn *), _); _ }; _ ])
       when String.equal ls1.ls_name.id_string (Ident.op_infix "@@") -> (
       match Language.lookup_nn ls2 with
-      | Some { nn_nb_outputs; nn_ty_elt; _ } -> (
-        match Term.Mterm.find_opt term mt with
-        | None ->
+      | Some { nn_nb_outputs; nn_ty_elt; _ } ->
+        if Term.Mterm.mem term mt
+        then mt
+        else
           let output_vars =
             List.init nn_nb_outputs ~f:(fun index ->
               ( index,
                 Term.create_fsymbol (Ident.id_fresh "caisar_y") [] nn_ty_elt ))
           in
           Term.Mterm.add term output_vars mt
-        | Some _ -> mt)
-      | _ -> mt)
+      | None -> mt)
     | _ -> Term.t_fold do_create mt term
   in
   Trans.fold_decl
-- 
GitLab