From 94016493e0d8d514b3da70e382d91713ecb692e3 Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Fri, 7 Apr 2023 11:31:09 +0200
Subject: [PATCH] [interpretation] Extension of current transformations.

---
 src/interpretation.ml                    |  2 +
 src/proof_strategy.ml                    | 33 ++++++++--
 src/transformations/native_nn_prover.ml  | 84 +++++++++++++++++++++++-
 src/transformations/native_nn_prover.mli |  3 +-
 src/transformations/utils.ml             | 37 ++++++++---
 src/transformations/utils.mli            |  6 +-
 6 files changed, 146 insertions(+), 19 deletions(-)

diff --git a/src/interpretation.ml b/src/interpretation.ml
index d812197..b5d97f8 100644
--- a/src/interpretation.ml
+++ b/src/interpretation.ml
@@ -386,6 +386,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
       term (term_of_caisar_op engine caisar_op ty)
     | _ -> invalid_arg (error_message ls)
   in
+
   [
     ( [ "interpretation" ],
       "Vector",
@@ -409,6 +410,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
       [
         ([ "read_classifier" ], None, read_classifier);
         ([ Ident.op_infix "@@" ], None, apply_classifier);
+        ([ Ident.op_infix "%%" ], None, apply_classifier);
       ] );
     ( [ "interpretation" ],
       "Dataset",
diff --git a/src/proof_strategy.ml b/src/proof_strategy.ml
index 041bee1..f4bccd1 100644
--- a/src/proof_strategy.ml
+++ b/src/proof_strategy.ml
@@ -22,17 +22,36 @@
 
 open Why3
 
-let do_apply_prover trans task =
+let apply_classic_prover env task =
   let nb = Trans.apply Utils.count_nn_apply task in
   match nb with
   | 0 -> task
-  | 1 -> Trans.apply trans task
+  | 1 -> Trans.apply (Nn2smt.trans env) task
   | _ ->
     invalid_arg "Two or more neural network applications are not supported yet"
 
-let apply_classic_prover env task = do_apply_prover (Nn2smt.trans env) task
-
 let apply_native_nn_prover env task =
-  do_apply_prover
-    (Trans.seq [ Introduction.introduce_premises; Native_nn_prover.trans env ])
-    task
+  let nb_nn_apply = Trans.apply Utils.count_nn_apply task in
+  let nb_nn_classifiers = Trans.apply Utils.count_nn_classifiers task in
+  match (nb_nn_apply, nb_nn_classifiers) with
+  | 0, 0 -> task
+  | 1, 0 ->
+    Trans.(
+      apply
+        (seq
+           [
+             Introduction.introduce_premises;
+             Native_nn_prover.trans_nn_apply env;
+           ]))
+      task
+  | 0, 1 ->
+    Trans.(
+      apply
+        (seq
+           [
+             Introduction.introduce_premises;
+             Native_nn_prover.trans_nn_classifier env;
+           ]))
+      task
+  | _ ->
+    invalid_arg "Two or more neural network applications are not supported yet"
diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index 3edcb00..bc705a6 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -92,5 +92,87 @@ let simplify_goal env input_variables =
         Task.add_decl acc decl)
     None
 
-let trans env =
+let trans_nn_apply env =
+  Trans.seq [ Trans.bind Utils.get_input_variables (simplify_goal env) ]
+
+(* Create logic symbols for output variables and simplify the formula. *)
+let simplify_goal _env input_variables =
+  let rec aux hls (term : Term.term) =
+    match term.t_node with
+    | Term.Tapp
+        ( ls_vget,
+          [
+            ({
+               t_node =
+                 Tapp
+                   ( ls_apply_classifier,
+                     [
+                       { t_node = Tapp (ls_nn_classifier, _); _ };
+                       { t_node = Tapp (ls_vector, _); _ };
+                     ] );
+               _;
+             } as _t1);
+            ({ t_node = Tconst (ConstInt i); _ } as _t2);
+          ] )
+      when String.equal ls_vget.ls_name.id_string (Ident.op_get "")
+           && String.equal ls_apply_classifier.ls_name.id_string
+                (Ident.op_infix "%%") -> (
+      match
+        ( Language.lookup_nn_classifier ls_nn_classifier,
+          Language.lookup_vector ls_vector )
+      with
+      | Some nn, Some _ ->
+        let index = Number.to_small_integer i in
+        let hout =
+          Hashtbl.update_and_return hls nn.nn_filename ~f:(fun hout ->
+            let ls =
+              let id = Ident.id_fresh "y" in
+              Term.create_fsymbol id [] nn.nn_ty_elt
+            in
+            match hout with
+            | None ->
+              let hout = Hashtbl.create (module Int) in
+              Hashtbl.add_exn hout ~key:index ~data:ls;
+              hout
+            | Some hout ->
+              Hashtbl.update hout index ~f:(fun lsout ->
+                match lsout with
+                | None ->
+                  Hashtbl.add_exn hout ~key:index ~data:ls;
+                  ls
+                | Some ls -> ls);
+              hout)
+        in
+        let ls_output = Hashtbl.find_exn hout index in
+        Term.fs_app ls_output [] nn.nn_ty_elt
+      | _ -> Term.t_map (aux hls) term)
+    | _ -> Term.t_map (aux hls) term
+  in
+  let htbl = Hashtbl.create (module String) in
+  Trans.fold
+    (fun task_hd acc ->
+      match task_hd.task_decl.td_node with
+      | Use _ | Clone _ | Meta _ -> Task.add_tdecl acc task_hd.task_decl
+      | Decl { d_node = Dparam ls; _ } -> (
+        let task = Task.add_tdecl acc task_hd.task_decl in
+        match Term.Mls.find_opt ls input_variables with
+        | None -> task
+        | Some pos -> Task.add_meta task Utils.meta_input [ MAls ls; MAint pos ]
+        )
+      | Decl decl ->
+        let decl = Decl.decl_map (fun term -> aux htbl term) decl in
+        let acc =
+          Hashtbl.fold htbl ~init:acc ~f:(fun ~key ~data acc ->
+            let acc = Task.add_meta acc Utils.meta_nn_filename [ MAstr key ] in
+            Hashtbl.fold data ~init:acc ~f:(fun ~key ~data acc ->
+              let acc =
+                let decl = Decl.create_param_decl data in
+                Task.add_decl acc decl
+              in
+              Task.add_meta acc Utils.meta_output [ MAls data; MAint key ]))
+        in
+        Task.add_decl acc decl)
+    None
+
+let trans_nn_classifier env =
   Trans.seq [ Trans.bind Utils.get_input_variables (simplify_goal env) ]
diff --git a/src/transformations/native_nn_prover.mli b/src/transformations/native_nn_prover.mli
index 936ff49..84ea944 100644
--- a/src/transformations/native_nn_prover.mli
+++ b/src/transformations/native_nn_prover.mli
@@ -22,4 +22,5 @@
 
 open Why3
 
-val trans : Env.env -> Task.task Trans.trans
+val trans_nn_apply : Env.env -> Task.task Trans.trans
+val trans_nn_classifier : Env.env -> Task.task Trans.trans
diff --git a/src/transformations/utils.ml b/src/transformations/utils.ml
index 1cfbdda..a2279ec 100644
--- a/src/transformations/utils.ml
+++ b/src/transformations/utils.ml
@@ -35,21 +35,40 @@ let count_nn_apply =
   in
   Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) 0
 
+let count_nn_classifiers =
+  let rec aux acc (term : Term.term) =
+    let acc = Term.t_fold aux acc term in
+    match term.t_node with
+    | Term.Tapp (ls, _) -> (
+      match Language.lookup_nn_classifier ls with
+      | None -> acc
+      | Some _ -> acc + 1)
+    | _ -> acc
+  in
+  Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) 0
+
 let get_input_variables =
+  let add i acc = function
+    | { Term.t_node = Tapp (vs, []); _ } -> Term.Mls.add vs i acc
+    | arg ->
+      invalid_arg
+        (Fmt.str "No direct variable in application: %a" Pretty.print_term arg)
+  in
   let rec aux acc (term : Term.term) =
     match term.t_node with
+    | Term.Tapp
+        ( { ls_name; _ },
+          [ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] )
+      when String.equal ls_name.id_string (Ident.op_infix "%%") -> (
+      match (Language.lookup_nn_classifier ls1, Language.lookup_vector ls2) with
+      | Some { nn_inputs; _ }, Some n ->
+        assert (nn_inputs = n && n = List.length args);
+        List.foldi ~init:acc ~f:add args
+      | _ -> acc)
     | Term.Tapp (ls, args) -> (
       match Language.lookup_loaded_nets ls with
       | None -> acc
-      | Some _ ->
-        let add i acc = function
-          | { Term.t_node = Tapp (vs, []); _ } -> Term.Mls.add vs i acc
-          | arg ->
-            invalid_arg
-              (Fmt.str "No direct variable in application: %a" Pretty.print_term
-                 arg)
-        in
-        List.foldi ~init:acc ~f:add args)
+      | Some _ -> List.foldi ~init:acc ~f:add args)
     | _ -> Term.t_fold aux acc term
   in
   Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) Term.Mls.empty
diff --git a/src/transformations/utils.mli b/src/transformations/utils.mli
index 12b8c13..8d54c04 100644
--- a/src/transformations/utils.mli
+++ b/src/transformations/utils.mli
@@ -25,8 +25,12 @@ open Why3
 val count_nn_apply : int Trans.trans
 (** Count the number of applications of [nn_apply]. *)
 
+val count_nn_classifiers : int Trans.trans
+(** Count the number of applications of a NN classifier. *)
+
 val get_input_variables : int Term.Mls.t Trans.trans
-(** Retrieve the input variables appearing as arguments of [nn_apply]. *)
+(** Retrieve the input variables appearing as arguments of [nn_apply] or a NN
+    classifier. *)
 
 val meta_input : Theory.meta
 (** Indicate the input position. *)
-- 
GitLab