From 472b3bc77587543a006df53ee93e8cf89789f156 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr>
Date: Wed, 3 Apr 2024 16:01:05 +0200
Subject: [PATCH] [Interpretation] Convert `nn@@vv` into `mk_vector
 (nn@@vv[0])` ...

   So an invariant for normal formula is that application of nn are always of the shape `nn@@vv[cst]`
---
 src/interpretation/interpreter_theory.ml | 102 +++++++++++++----------
 src/interpretation/interpreter_types.ml  |   8 ++
 src/interpretation/interpreter_types.mli |   2 +
 tests/interpretation_fail.t              |   5 +-
 4 files changed, 70 insertions(+), 47 deletions(-)

diff --git a/src/interpretation/interpreter_theory.ml b/src/interpretation/interpreter_theory.ml
index 1787610..9690498 100644
--- a/src/interpretation/interpreter_theory.ml
+++ b/src/interpretation/interpreter_theory.ml
@@ -35,13 +35,17 @@ let fail_on_unexpected_argument ls =
 module Vector = struct
   let (get : _ IRE.builtin) =
    fun engine ls vl ty ->
+    let th_model = Symbols.Model.create (IRE.user_env engine).ITypes.env in
     match vl with
     | [
      Term
-       ({ t_node = Tapp (_ (* @@ *), [ { t_node = Tapp (ls, _); _ }; _ ]); _ }
-       as _t1);
+       ({
+          t_node = Tapp (ls_atat (* @@ *), [ { t_node = Tapp (ls, _); _ }; _ ]);
+          _;
+        } as _t1);
      Term ({ t_node = Tconst (ConstInt i); _ } as t2);
-    ] -> (
+    ]
+      when Why3.Term.ls_equal ls_atat th_model.atat -> (
       let i = Why3.Number.to_small_integer i in
       if i < 0
       then
@@ -304,50 +308,60 @@ module NN = struct
     | _ -> fail_on_unexpected_argument ls
 
   let apply : _ IRE.builtin =
-   fun engine ls vl _ty ->
+   fun engine ls vl ty ->
     match vl with
-    | [
-     Term ({ t_node = Tapp (ls1, []); _ } as t1);
-     Term ({ t_node = Tapp (ls2, tl2); _ } as t2);
-    ] -> (
-      match (ITypes.op_of_ls engine ls1, ITypes.op_of_ls engine ls2) with
-      | Model (NN (nn, _)), Vector v ->
-        let nn =
-          match Language.lookup_nn nn with
-          | None ->
-            Logging.code_error ~src:Logging.src_interpret_goal (fun m ->
-              m "Cannot find neural network model from lsymbol %a"
-                Why3.Pretty.print_ls nn)
-          | Some nn -> nn
-        in
-        let length_v =
-          match Language.lookup_vector v with
-          | None ->
-            Logging.code_error ~src:Logging.src_interpret_goal (fun m ->
-              m "Cannot find vector from lsymbol %a" Why3.Pretty.print_ls v)
-          | Some n ->
-            if List.length tl2 <> n
-            then
+    | [ Term t1; Term t2 ] -> (
+      match ITypes.op_of_term engine t1 with
+      | Some (Model (NN (nn, _)), []) ->
+        let { ITypes.env; _ } = IRE.user_env engine in
+        let nn = Option.value_exn (Language.lookup_nn nn) in
+        (match ITypes.op_of_term engine t2 with
+        | Some (Vector v, tl2) ->
+          let length_v =
+            match Language.lookup_vector v with
+            | None ->
               Logging.code_error ~src:Logging.src_interpret_goal (fun m ->
-                m
-                  "Mismatch between (container) vector length and number of \
-                   (contained) input variables.");
-            n
+                m "Cannot find vector from lsymbol %a" Why3.Pretty.print_ls v)
+            | Some n ->
+              if List.length tl2 <> n
+              then
+                Logging.code_error ~src:Logging.src_interpret_goal (fun m ->
+                  m
+                    "Mismatch between (container) vector length and number of \
+                     (contained) input variables.");
+              n
+          in
+          if nn.nn_nb_inputs <> length_v
+          then
+            Logging.user_error ?loc:t2.t_loc (fun m ->
+              m
+                "Unexpected vector of length %d in input to neural network \
+                 model '%s',@ which expects input vectors of length %d"
+                length_v nn.nn_filename nn.nn_nb_inputs)
+        | _ ->
+          Logging.user_error ?loc:t1.t_loc (fun m ->
+            m "Unexpected neural network model application: %a"
+              Why3.Pretty.print_term t2));
+        let th = Why3.Env.read_theory env [ "caisar"; "types" ] "Vector" in
+        let get =
+          Why3.Theory.ns_find_ls th.th_export [ Why3.Ident.op_get "" ]
         in
-        if nn.nn_nb_inputs <> length_v
-        then
-          Logging.user_error ?loc:t2.t_loc (fun m ->
-            m
-              "Unexpected vector of length %d in input to neural network model \
-               '%s',@ which expects input vectors of length %d"
-              length_v nn.nn_filename nn.nn_nb_inputs)
-        else IRE.reconstruct_term ()
-      | Model (SVM _), _ ->
-        (* Should be already catched by the Why3 typing. *)
-        assert false
-      | _, _ ->
-        Logging.user_error ?loc:t1.t_loc (fun m ->
-          m "Unexpected neural network model application"))
+        let t0 = Why3.Term.t_app ls [ t1; t2 ] ty in
+        let args =
+          List.init nn.nn_nb_outputs ~f:(fun i ->
+            ( Why3.Term.fs_app get
+                [
+                  t0;
+                  Why3.Term.t_const
+                    (Why3.Constant.int_const_of_int i)
+                    Why3.Ty.ty_int;
+                ]
+                nn.nn_ty_elt,
+              nn.nn_ty_elt ))
+        in
+        let op = ITypes.Vector (Language.create_vector env nn.nn_nb_outputs) in
+        IRE.value_term (ITypes.term_of_op ~args engine op ty)
+      | _ -> IRE.reconstruct_term ())
     | _ -> fail_on_unexpected_argument ls
 
   let builtins : _ IRE.built_in_theories =
diff --git a/src/interpretation/interpreter_types.ml b/src/interpretation/interpreter_types.ml
index d538870..0e015b9 100644
--- a/src/interpretation/interpreter_types.ml
+++ b/src/interpretation/interpreter_types.ml
@@ -103,6 +103,14 @@ let term_of_op ?(args = []) engine interpreter_op ty =
   let t_args, ty_args = List.unzip args in
   Why3.Term.t_app_infer (ls_of_op engine interpreter_op ty_args ty) t_args
 
+let op_of_term engine t =
+  match t.Why3.Term.t_node with
+  | Tapp (ls, args) -> (
+    match op_of_ls engine ls with
+    | exception Stdlib.Not_found -> None
+    | v -> Some (v, args))
+  | _ -> None
+
 let interpreter_env ~cwd env =
   {
     ls_of_op = Hashtbl.Poly.create ();
diff --git a/src/interpretation/interpreter_types.mli b/src/interpretation/interpreter_types.mli
index 8eb1ce8..2481c48 100644
--- a/src/interpretation/interpreter_types.mli
+++ b/src/interpretation/interpreter_types.mli
@@ -54,6 +54,8 @@ type interpreter_env = private {
 
 val op_of_ls : interpreter_env IRE.engine -> Why3.Term.lsymbol -> interpreter_op
 
+val op_of_term : interpreter_env IRE.engine -> Why3.Term.term -> (interpreter_op * Why3.Term.term list) option
+
 val term_of_op :
   ?args:(Why3.Term.term * Why3.Ty.ty) list ->
   interpreter_env IRE.engine ->
diff --git a/tests/interpretation_fail.t b/tests/interpretation_fail.t
index a389622..7d68cd9 100644
--- a/tests/interpretation_fail.t
+++ b/tests/interpretation_fail.t
@@ -146,7 +146,7 @@ Test interpret fail
   > EOF
 
   $ caisar verify --prover nnenum file.mlw
-  [ERROR] "file.mlw", line 12, characters 24-26:
+  [ERROR] "file.mlw", line 12, characters 14-23:
           Index constant 10 is out-of-bounds [0,4]
 
   $ cat > file.mlw <<EOF
@@ -183,5 +183,4 @@ Test interpret fail
   > EOF
 
   $ caisar verify --prover SAVer file.mlw
-  [ERROR] "file.mlw", line 10, characters 35-36:
-          Index constant 4 is out-of-bounds [0,1]
+  [ERROR] Cannot find feature for input variable 'x'
-- 
GitLab