From 08de5dc6944ac5f90d66a6748fc6b8b7887c5711 Mon Sep 17 00:00:00 2001
From: Maxime Jacquemin <maxime.jacquemin@cea.fr>
Date: Fri, 2 Feb 2024 10:56:06 +0100
Subject: [PATCH] [Kernel] Simpler and directly memoized Matrix powers

---
 src/kernel_services/analysis/filter/linear.ml | 32 ++++++++-----------
 .../analysis/filter/linear.mli                |  3 +-
 .../analysis/filter/linear_filter.ml          | 22 +++----------
 3 files changed, 19 insertions(+), 38 deletions(-)

diff --git a/src/kernel_services/analysis/filter/linear.ml b/src/kernel_services/analysis/filter/linear.ml
index 629e482e1c9..1dd2fc44fd0 100644
--- a/src/kernel_services/analysis/filter/linear.ml
+++ b/src/kernel_services/analysis/filter/linear.ml
@@ -147,25 +147,19 @@ module Space (Field : Field.S) = struct
       let max i res = Field.max res (row i (M m) |> sum) in
       Finite.for_each max m.rows Field.zero
 
-    let rec fast_power target steps matrix exponent =
-      let steps' = (matrix, exponent) :: steps in
-      let exponent' = Stdlib.(exponent * 2) in
-      if exponent' > target then if exponent <= target then steps' else steps
-      else fast_power target steps' (matrix * matrix) exponent'
-
-    let rec refine target = function
-      | [] -> assert false
-      | [ (matrix, _) ] -> matrix
-      | (matrix, exponent) :: (matrix', exponent') :: previous ->
-        let exponent'' = Stdlib.(exponent + exponent') in
-        if exponent'' > target
-        then refine target ((matrix, exponent) :: previous)
-        else refine target (((matrix * matrix'), exponent'') :: previous)
-
-    let power (type n) (M matrix : (n, n) matrix) exponent : (n, n) matrix =
-      if exponent < 0 then raise (Invalid_argument "negative exponent")
-      else if Stdlib.(exponent = 0) then id matrix.rows
-      else fast_power exponent [] (M matrix) 1 |> refine exponent
+    let power (type n) (M m : (n, n) matrix) : int -> (n, n) matrix =
+      let n = dimensions (M m) |> fst in
+      let cache = Datatype.Int.Hashtbl.create 17 in
+      let find i = Datatype.Int.Hashtbl.find_opt cache i in
+      let save i v = Datatype.Int.Hashtbl.add cache i v ; v in
+      let rec pow e =
+        if e < 0 then raise (Invalid_argument "negative exponent") ;
+        match find e with
+        | Some r -> r
+        | None when Stdlib.(e = 0) -> id n
+        | None when Stdlib.(e = 1) -> M m
+        | None -> let h = pow (e / 2) in save e (pow (e mod 2) * h * h)
+      in pow
 
   end
 
diff --git a/src/kernel_services/analysis/filter/linear.mli b/src/kernel_services/analysis/filter/linear.mli
index 3ab99f29e66..173b2c9bb4d 100644
--- a/src/kernel_services/analysis/filter/linear.mli
+++ b/src/kernel_services/analysis/filter/linear.mli
@@ -52,7 +52,8 @@ module Space (Field : Field.S) : sig
     val dimensions : ('m, 'n) matrix -> 'm nat * 'n nat
     val ( + ) : ('n, 'm) matrix -> ('n, 'm) matrix -> ('n, 'm) matrix
     val ( * ) : ('n, 'm) matrix -> ('m, 'p) matrix -> ('n, 'p) matrix
-    val power : ('n, 'n) matrix -> int -> ('n, 'n) matrix
+    (* Memoized, instantiate first on a matrix and then use it *)
+    val power : ('n, 'n) matrix -> (int -> ('n, 'n) matrix)
   end
 
 end
diff --git a/src/kernel_services/analysis/filter/linear_filter.ml b/src/kernel_services/analysis/filter/linear_filter.ml
index f216a5f5ca9..197357393b3 100644
--- a/src/kernel_services/analysis/filter/linear_filter.ml
+++ b/src/kernel_services/analysis/filter/linear_filter.ml
@@ -77,12 +77,6 @@ module Make (Field : Field.S) = struct
 
 
 
-  let memoized_powers m =
-    let cache = Datatype.Int.Hashtbl.create 17 in
-    let find i = Datatype.Int.Hashtbl.find cache i in
-    let save i v = Datatype.Int.Hashtbl.add cache i v ; v in
-    fun i -> try find i with Not_found -> save i (Matrix.power m i)
-
   let check_convergence matrix =
     let norm = Matrix.norm matrix in
     if Field.(norm < one) then Some norm else None
@@ -90,18 +84,18 @@ module Make (Field : Field.S) = struct
   type ('n, 'm) compute = ('n, 'm) filter -> int -> 'n invariant option
   let invariant : type n m. (n, m) compute = fun (Filter f) e ->
     let open Option.Operators in
+    let state = Matrix.power f.state in
     let measure = Vector.norm f.measure in
-    let powers = memoized_powers f.state in
     let order, _ = Matrix.dimensions f.input in
     let base i = Vector.base i order |> Matrix.transpose in
     let* StrictlyPositive exponent = Nat.of_strictly_positive_int e in
-    let+ spectral = powers (Nat.to_int exponent) |> check_convergence in
+    let+ spectral = state (Nat.to_int exponent) |> check_convergence in
     (* Computation of the inputs contribution for the i-th state dimension *)
-    let input base e = Matrix.(base * powers e * f.input |> norm) in
+    let input base e = Matrix.(base * state e * f.input |> norm) in
     let add_input base e res = Field.(res + input base (Finite.to_int e)) in
     let input i = Finite.for_each (base i |> add_input) exponent Field.zero in
     (* Computation of the center contribution for the i-th state dimension *)
-    let center e = Matrix.(powers e * f.center) in
+    let center e = Matrix.(state e * f.center) in
     let add_center e res = Matrix.(res + center (Finite.to_int e)) in
     let center = Finite.for_each add_center exponent (Vector.zero order) in
     let center i = Matrix.(base i * center |> norm) in
@@ -112,12 +106,4 @@ module Make (Field : Field.S) = struct
     let upper i inv = set_upper i (bound Field.one i) inv in
     Finite.(invariant order |> for_each lower order |> for_each upper order)
 
-  (* - Powers : k * n^3
-     - Input i : n x n mul + n x m mul + norm
-        -> n^2 + nm + 1 because of projection
-     - Center i : n x n mul + k * n x n add + n x n mul + norm
-        -> n^2 + kn^2 + n^2 + 1 because of projection
-     - Dimension i : O(kn^2 + nm)
-     Total : O(kn^3 + mn^2)
-   *)
 end
-- 
GitLab