From 451792e3a246a627e2c12f4acbf2caf2f8c9a5af Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr>
Date: Fri, 24 Sep 2021 16:37:06 +0200
Subject: [PATCH] [Transformation] Use Reduction_engine to simplify the formula

---
 src/language.ml        | 15 +++++++++-
 src/language.mli       |  9 +++++-
 src/transformations.ml | 64 +++++++++++++++++++++++++++++++++++++++---
 tests/simple.t         | 14 +++++++--
 4 files changed, 93 insertions(+), 9 deletions(-)

diff --git a/src/language.ml b/src/language.ml
index 33bd7306..34b21be0 100644
--- a/src/language.ml
+++ b/src/language.ml
@@ -8,6 +8,13 @@ open Base
 
 (* -- Support for the NNet neural network format. *)
 
+type nnet = {
+  nb_inputs : int;
+  nb_outputs : int;
+  ty_data : Why3.Ty.ty;
+  filename : string;
+}
+
 let loaded_nnets = Why3.Term.Hls.create 10
 
 let lookup_loaded_nnets = Why3.Term.Hls.find_opt loaded_nnets
@@ -34,7 +41,13 @@ let nnet_parser env _ filename _ =
         (List.init header.n_inputs ~f)
         (Ty.ty_tuple (List.init header.n_outputs ~f))
     in
-    Why3.Term.Hls.add loaded_nnets ls_nnet_apply filename;
+    Why3.Term.Hls.add loaded_nnets ls_nnet_apply
+      {
+        filename;
+        nb_inputs = header.n_inputs;
+        nb_outputs = header.n_outputs;
+        ty_data = nnet_input_type;
+      };
     let th_uc =
       Pmodule.add_pdecl ~vc:false th_uc
         (Pdecl.create_pure_decl (Decl.create_param_decl ls_nnet_apply))
diff --git a/src/language.mli b/src/language.mli
index 5a8c5445..23ef26ce 100644
--- a/src/language.mli
+++ b/src/language.mli
@@ -1,4 +1,11 @@
-val lookup_loaded_nnets : Why3.Term.lsymbol -> string option
+type nnet = {
+  nb_inputs : int;
+  nb_outputs : int;
+  ty_data : Why3.Ty.ty;
+  filename : string;
+}
+
+val lookup_loaded_nnets : Why3.Term.lsymbol -> nnet option
 (** Return the filename of an nnets Why3 representation *)
 
 val register_nnet_support : unit -> unit
diff --git a/src/transformations.ml b/src/transformations.ml
index 94877c2c..00674859 100644
--- a/src/transformations.ml
+++ b/src/transformations.ml
@@ -6,7 +6,7 @@ let get_input_variables =
     | Why3.Term.Tapp (ls, args) -> (
       match Language.lookup_loaded_nnets ls with
       | None -> acc
-      | Some _name ->
+      | Some _ ->
         let add acc = function
           | { Why3.Term.t_node = Tapp (vs, []); _ } -> Why3.Term.Sls.add vs acc
           | arg ->
@@ -21,11 +21,67 @@ let get_input_variables =
     (fun decl acc -> Why3.Decl.decl_fold aux acc decl)
     Why3.Term.Sls.empty
 
-let simplify_goal _input_variables = Why3.Trans.identity
+let simplify_goal env _input_variables =
+  let rec aux hls (term : Why3.Term.term) =
+    match term.t_node with
+    | Why3.Term.Tapp (ls, _) -> (
+      match Language.lookup_loaded_nnets ls with
+      | None -> Why3.Term.t_map (aux hls) term
+      | Some nnet ->
+        let outputs =
+          List.init nnet.nb_outputs ~f:(fun _ ->
+            let open Why3 in
+            let id = Ident.id_fresh "y" in
+            let ls = Term.create_fsymbol id [] nnet.ty_data in
+            hls := Why3.Decl.create_param_decl ls :: !hls;
+            Term.fs_app ls [] nnet.ty_data)
+        in
+        Why3.Term.t_tuple outputs)
+    | _ -> Why3.Term.t_map (aux hls) term
+  in
+  Why3.Trans.fold
+    (fun task_hd acc ->
+      match task_hd.task_decl.td_node with
+      | Use _ | Clone _ | Meta _ -> Why3.Task.add_tdecl acc task_hd.task_decl
+      | Decl decl ->
+        let hls = ref [] in
+        let map term =
+          let term = aux hls term in
+          if List.is_empty !hls
+          then term
+          else
+            let known =
+              List.fold (List.rev !hls) ~init:task_hd.task_known
+                ~f:Why3.Decl.known_add_decl
+            in
+            let engine =
+              Why3.Reduction_engine.create
+                {
+                  compute_defs = false;
+                  compute_builtin = true;
+                  compute_def_set = Why3.Term.Sls.empty;
+                }
+                env known
+            in
+            Why3.Reduction_engine.normalize ~limit:100 engine
+              Why3.Term.Mvs.empty term
+        in
+        let decl = Why3.Decl.decl_map map decl in
+        let acc =
+          List.fold (List.rev !hls) ~init:acc ~f:(fun acc ls ->
+            Why3.Task.add_decl acc ls)
+        in
+        Why3.Task.add_decl acc decl)
+    None
 
-let caisar_native_prover = Why3.Trans.bind get_input_variables simplify_goal
+let caisar_native_prover env =
+  Why3.Trans.seq
+    [
+      Why3.Trans.bind get_input_variables (simplify_goal env);
+      (* Why3.Simplify_formula.simplify_; *)
+    ]
 
 let init () =
-  Why3.Trans.register_transform
+  Why3.Trans.register_env_transform
     ~desc:"Transformation for provers that support loading neural networks."
     "caisar_native_prover" caisar_native_prover
diff --git a/tests/simple.t b/tests/simple.t
index b608821e..c8b5c220 100644
--- a/tests/simple.t
+++ b/tests/simple.t
@@ -1172,8 +1172,16 @@ Test verify
   
   axiom H1 [@introduced] : lt x1 0.5
   
-  goal G : match nnet_apply x1 x2 x3 x4 x5 with
-    | Tuple5 y1 _ _ _ _ -> lt 0.0 y1 /\ lt y1 0.5
-    end
+  function y : t19
+  
+  function y1 : t19
+  
+  function y2 : t19
+  
+  function y3 : t19
+  
+  function y4 : t19
+  
+  goal G : lt 0.0 y4 /\ lt y4 0.5
   
   end
-- 
GitLab