From ac2d7d78846a57f08cc0e6a2bb06e9790278727e Mon Sep 17 00:00:00 2001
From: Maxime Jacquemin <maxime.jacquemin@cea.fr>
Date: Fri, 2 Feb 2024 10:36:03 +0100
Subject: [PATCH] [Kernel] Detailled comments and method improvements

We can now represents filters that are not centered around zero.
The invariant computation now computes bounds for each state dimension.
The module is documented.
---
 src/kernel_services/analysis/filter/finite.ml |   1 +
 .../analysis/filter/finite.mli                |   1 +
 src/kernel_services/analysis/filter/linear.ml |  70 ++++++++++--
 .../analysis/filter/linear.mli                |  14 ++-
 .../analysis/filter/linear_filter.ml          | 107 +++++++++++-------
 .../analysis/filter/linear_filter.mli         |  65 ++++++++---
 .../analysis/filter/linear_filter_test.ml     |  24 +++-
 7 files changed, 207 insertions(+), 75 deletions(-)

diff --git a/src/kernel_services/analysis/filter/finite.ml b/src/kernel_services/analysis/filter/finite.ml
index 5ad748ccb5b..2799fe8141f 100644
--- a/src/kernel_services/analysis/filter/finite.ml
+++ b/src/kernel_services/analysis/filter/finite.ml
@@ -25,6 +25,7 @@ open Nat
 type 'n finite = int
 
 let first  : type n. n succ finite = 0
+let last   : type n. n succ nat -> n succ finite = fun n -> Nat.to_int n - 1
 let next   : type n. n finite -> n succ finite = fun n -> n + 1
 let ( = )  : type n. n finite -> n finite -> bool = fun l r -> l = r
 let to_int : type n. n finite -> int = fun n -> n
diff --git a/src/kernel_services/analysis/filter/finite.mli b/src/kernel_services/analysis/filter/finite.mli
index 28b54b04e5a..0a3b5fb6949 100644
--- a/src/kernel_services/analysis/filter/finite.mli
+++ b/src/kernel_services/analysis/filter/finite.mli
@@ -25,6 +25,7 @@ open Nat
 type 'n finite
 
 val first : 'n succ finite
+val last  : 'n succ nat -> 'n succ finite
 val next  : 'n finite -> 'n succ finite
 val ( = ) : 'n finite -> 'n finite -> bool
 
diff --git a/src/kernel_services/analysis/filter/linear.ml b/src/kernel_services/analysis/filter/linear.ml
index 8cd5288d91d..629e482e1c9 100644
--- a/src/kernel_services/analysis/filter/linear.ml
+++ b/src/kernel_services/analysis/filter/linear.ml
@@ -35,10 +35,29 @@ module Space (Field : Field.S) = struct
 
 
 
+  type 'n row = Format.formatter -> 'n finite -> unit
+  let pretty (type n m) (row : n row) fmt (M m : (n, m) matrix) =
+    let cut () = Format.pp_print_cut fmt () in
+    let first () = Format.fprintf fmt "@[<h>⌈%a⌉@]" row Finite.first in
+    let mid i = Format.fprintf fmt "@[<h>|%a|@]" row i in
+    let last () = Format.fprintf fmt "@[<h>⌋%a⌊@]" row Finite.(last m.rows) in
+    let row i () =
+      if Finite.(i = first) then first ()
+      else if Finite.(i = last m.rows) then (cut () ; last ())
+      else (cut () ; mid i)
+    in
+    Format.pp_open_vbox fmt 0 ;
+    Finite.for_each row m.rows () ;
+    Format.pp_close_box fmt ()
+
+
+
   module Vector = struct
 
-    let pretty (type n) fmt (M { data ; _ } : n vector) =
-      Parray.pretty ~sep:"@ " Field.pretty fmt data
+    let pretty_row (type n) fmt (M { data ; _ } : n vector) =
+      Format.pp_open_hbox fmt () ;
+      Parray.pretty ~sep:"@ " Field.pretty fmt data ;
+      Format.pp_close_box fmt ()
 
     let init size f =
       let data = Parray.init (Nat.to_int size) (fun _ -> Field.zero) in
@@ -53,16 +72,24 @@ module Space (Field : Field.S) = struct
     let get (type n) (i : n finite) (M vec : n vector) : scalar =
       Parray.get vec.data (Finite.to_int i)
 
+    let pretty (type n) fmt (vector : n vector) =
+      let get fmt (i : n finite) = Field.pretty fmt (get i vector) in
+      pretty get fmt vector
+
     let set (type n) (i : n finite) scalar (M vec : n vector) : n vector =
       M { vec with data = Parray.set vec.data (Finite.to_int i) scalar }
 
-    let norm (type n) (M vector : n vector) =
-      Parray.fold (fun _ a acc -> Field.(abs a + acc)) vector.data Field.zero
+    let norm (type n) (v : n vector) : scalar =
+      let max i r = Field.(max (abs (get i v)) r) in
+      Finite.for_each max (size v) Field.zero
 
     let ( * ) (type n) (l : n vector) (r : n vector) =
       let inner i acc = Field.(acc + get i l * get i r) in
       Finite.for_each inner (size l) Field.zero
 
+    let base (type n) (i : n succ finite) (dimension : n succ nat) =
+      zero dimension |> set i Field.one
+
   end
 
 
@@ -87,10 +114,8 @@ module Space (Field : Field.S) = struct
       fun (M m) -> m.rows, m.cols
 
     let pretty (type n m) fmt (M m : (n, m) matrix) =
-      let cut i = if not Finite.(i = first) then Format.pp_print_cut fmt () in
-      let row i = Format.fprintf fmt "@[<h>%a@]" Vector.pretty (row i (M m)) in
-      let pretty () = Finite.for_each (fun i () -> cut i ; row i) m.rows () in
-      Format.pp_open_vbox fmt 0 ; pretty () ; Format.pp_close_box fmt ()
+      let row fmt i = Vector.pretty_row fmt (row i (M m)) in
+      pretty row fmt (M m)
 
     let init n m init =
       let rows = Nat.to_int n and cols = Nat.to_int m in
@@ -103,6 +128,9 @@ module Space (Field : Field.S) = struct
     let zero n m = init n m (fun _ _ -> Field.zero)
     let id n = Finite.for_each (fun i m -> set i i Field.one m) n (zero n n)
 
+    let transpose : type n m. (n, m) matrix -> (m, n) matrix =
+      fun (M m) -> init m.cols m.rows (fun j i -> get i j (M m))
+
     type ('n, 'm) add = ('n, 'm) matrix -> ('n, 'm) matrix -> ('n, 'm) matrix
     let ( + ) : type n m. (n, m) add = fun (M l) (M r) ->
       let ( + ) i j = Field.(get i j (M l) + get i j (M r)) in
@@ -114,8 +142,30 @@ module Space (Field : Field.S) = struct
       init l.rows r.cols ( * )
 
     let norm : type n m. (n, m) matrix -> scalar = fun (M m) ->
-      let max i res = Field.max res (col i (M m) |> Vector.norm) in
-      Finite.for_each max m.cols Field.zero
+      let add v j r = Field.(abs (Vector.get j v) + r) in
+      let sum v = Finite.for_each (add v) (Vector.size v) Field.zero in
+      let max i res = Field.max res (row i (M m) |> sum) in
+      Finite.for_each max m.rows Field.zero
+
+    let rec fast_power target steps matrix exponent =
+      let steps' = (matrix, exponent) :: steps in
+      let exponent' = Stdlib.(exponent * 2) in
+      if exponent' > target then if exponent <= target then steps' else steps
+      else fast_power target steps' (matrix * matrix) exponent'
+
+    let rec refine target = function
+      | [] -> assert false
+      | [ (matrix, _) ] -> matrix
+      | (matrix, exponent) :: (matrix', exponent') :: previous ->
+        let exponent'' = Stdlib.(exponent + exponent') in
+        if exponent'' > target
+        then refine target ((matrix, exponent) :: previous)
+        else refine target (((matrix * matrix'), exponent'') :: previous)
+
+    let power (type n) (M matrix : (n, n) matrix) exponent : (n, n) matrix =
+      if exponent < 0 then raise (Invalid_argument "negative exponent")
+      else if Stdlib.(exponent = 0) then id matrix.rows
+      else fast_power exponent [] (M matrix) 1 |> refine exponent
 
   end
 
diff --git a/src/kernel_services/analysis/filter/linear.mli b/src/kernel_services/analysis/filter/linear.mli
index 7c4ee980aad..3ab99f29e66 100644
--- a/src/kernel_services/analysis/filter/linear.mli
+++ b/src/kernel_services/analysis/filter/linear.mli
@@ -33,22 +33,26 @@ module Space (Field : Field.S) : sig
 
   module Vector : sig
     val pretty : Format.formatter -> 'n vector -> unit
-    val size : 'n vector -> 'n nat
-    val norm : 'n vector -> scalar
-    val zero : 'n succ nat -> 'n succ vector
+    val zero   : 'n succ nat -> 'n succ vector
+    val base   : 'n succ finite -> 'n succ nat -> 'n succ vector
     val repeat : scalar -> 'n succ nat -> 'n succ vector
-    val set : 'n finite -> scalar -> 'n vector -> 'n vector
+    val set    : 'n finite -> scalar -> 'n vector -> 'n vector
+    val size   : 'n vector -> 'n nat
+    val norm   : 'n vector -> scalar
   end
 
   module Matrix : sig
     val pretty : Format.formatter -> ('n, 'm) matrix -> unit
     val id : 'n succ nat -> ('n succ, 'n succ) matrix
     val zero : 'n succ nat -> 'm succ nat -> ('n succ, 'm succ) matrix
+    val get : 'n finite -> 'm finite -> ('n, 'm) matrix -> scalar
     val set : 'n finite -> 'm finite -> scalar -> ('n, 'm) matrix -> ('n, 'm) matrix
+    val norm : ('n, 'm) matrix -> scalar
+    val transpose : ('n, 'm) matrix -> ('m, 'n) matrix
     val dimensions : ('m, 'n) matrix -> 'm nat * 'n nat
     val ( + ) : ('n, 'm) matrix -> ('n, 'm) matrix -> ('n, 'm) matrix
     val ( * ) : ('n, 'm) matrix -> ('m, 'p) matrix -> ('n, 'p) matrix
-    val norm : ('n, 'm) matrix -> scalar
+    val power : ('n, 'n) matrix -> int -> ('n, 'n) matrix
   end
 
 end
diff --git a/src/kernel_services/analysis/filter/linear_filter.ml b/src/kernel_services/analysis/filter/linear_filter.ml
index 4a32b30dd04..f216a5f5ca9 100644
--- a/src/kernel_services/analysis/filter/linear_filter.ml
+++ b/src/kernel_services/analysis/filter/linear_filter.ml
@@ -20,36 +20,13 @@
 (*                                                                        *)
 (**************************************************************************)
 
-open Nat
-
 
 
 module Make (Field : Field.S) = struct
 
   module Linear = Linear.Space (Field)
   open Linear
-
-
-
-  let rec first_steps target steps matrix exponent =
-    let steps' = (matrix, exponent) :: steps in
-    if exponent * 2 > target
-    then if exponent <= target then steps' else steps
-    else first_steps target steps' Matrix.(matrix * matrix) (exponent * 2)
-
-  let rec refine target = function
-    | [] -> None
-    | [ (matrix, _) ] ->
-      let norm = Matrix.norm matrix in
-      if Field.(norm < one) then Some norm else None
-    | (matrix, exponent) :: (matrix', exponent') :: previous ->
-      let exponent'' = exponent + exponent' in
-      if exponent'' > target
-      then refine target ((matrix, exponent) :: previous)
-      else refine target ((Matrix.(matrix * matrix'), exponent'') :: previous)
-
-  let find_spectral_radius target base =
-    first_steps target [] base 1 |> refine target
+  open Nat
 
 
 
@@ -59,12 +36,14 @@ module Make (Field : Field.S) = struct
   and ('n, 'm) data =
     { state : ('n, 'n) matrix
     ; input : ('n, 'm) matrix
+    ; center  : 'n vector
     ; measure : 'm vector
     }
 
 
 
-  let create state input measure = Filter { state ; input ; measure }
+  let create ~state ~input ~center ~measure =
+    Filter { state ; input ; center ; measure }
 
   type ('n, 'm) formatter = Format.formatter -> ('n, 'm) filter -> unit
   let pretty : type n m. (n, m) formatter = fun fmt (Filter f) ->
@@ -74,21 +53,71 @@ module Make (Field : Field.S) = struct
     Format.fprintf fmt "- Input :@ @   @[<v>%a@]@ @ " Matrix.pretty f.input ;
     Format.fprintf fmt "@]"
 
-  let sum order p norm stop =
-    let ( + ) res acc = Field.(res + norm acc) in
-    let rec aux (m, r) i = if i >= 0 then aux (p m, r + m) (i - 1) else r in
-    aux (Matrix.id order, Field.zero) (stop - 1)
 
-  type ('n, 'm) invariant = ('n, 'm) filter -> int -> Field.scalar option
-  let invariant : type n m. (n, m) invariant = fun (Filter f) exponent ->
+
+  type 'n invariant = ('n, zero succ succ) matrix
+
+  let invariant rows =
+    Matrix.zero rows Nat.(zero |> succ |> succ)
+
+  let lower i invariant =
+    Matrix.get i Finite.first invariant
+
+  let upper i invariant =
+    Matrix.get i Finite.(next first) invariant
+
+  let bounds i invariant =
+    (lower i invariant, upper i invariant)
+
+  let set_lower i bound invariant =
+    Matrix.set i Finite.first bound invariant
+
+  let set_upper i bound invariant =
+    Matrix.set i Finite.(next first)  bound invariant
+
+
+
+  let memoized_powers m =
+    let cache = Datatype.Int.Hashtbl.create 17 in
+    let find i = Datatype.Int.Hashtbl.find cache i in
+    let save i v = Datatype.Int.Hashtbl.add cache i v ; v in
+    fun i -> try find i with Not_found -> save i (Matrix.power m i)
+
+  let check_convergence matrix =
+    let norm = Matrix.norm matrix in
+    if Field.(norm < one) then Some norm else None
+
+  type ('n, 'm) compute = ('n, 'm) filter -> int -> 'n invariant option
+  let invariant : type n m. (n, m) compute = fun (Filter f) e ->
     let open Option.Operators in
+    let measure = Vector.norm f.measure in
+    let powers = memoized_powers f.state in
     let order, _ = Matrix.dimensions f.input in
-    let+ spectral = find_spectral_radius exponent f.state in
-    let power p = Matrix.(f.state * p) in
-    let norm  p = Matrix.(p * f.input |> norm) in
-    let sum = sum order power norm exponent in
-    let bound = Field.(Vector.norm f.measure * sum / (one - spectral)) in
-    let order = Field.of_int (Nat.to_int order) in
-    Field.(bound / order)
-
+    let base i = Vector.base i order |> Matrix.transpose in
+    let* StrictlyPositive exponent = Nat.of_strictly_positive_int e in
+    let+ spectral = powers (Nat.to_int exponent) |> check_convergence in
+    (* Computation of the inputs contribution for the i-th state dimension *)
+    let input base e = Matrix.(base * powers e * f.input |> norm) in
+    let add_input base e res = Field.(res + input base (Finite.to_int e)) in
+    let input i = Finite.for_each (base i |> add_input) exponent Field.zero in
+    (* Computation of the center contribution for the i-th state dimension *)
+    let center e = Matrix.(powers e * f.center) in
+    let add_center e res = Matrix.(res + center (Finite.to_int e)) in
+    let center = Finite.for_each add_center exponent (Vector.zero order) in
+    let center i = Matrix.(base i * center |> norm) in
+    (* Bounds computation for each state dimension *)
+    let numerator sign i = Field.(center i + sign * measure * input i) in
+    let bound sign i = Field.(numerator sign i / (one - spectral)) in
+    let lower i inv = set_lower i (bound Field.(neg one) i) inv in
+    let upper i inv = set_upper i (bound Field.one i) inv in
+    Finite.(invariant order |> for_each lower order |> for_each upper order)
+
+  (* - Powers : k * n^3
+     - Input i : n x n mul + n x m mul + norm
+        -> n^2 + nm + 1 because of projection
+     - Center i : n x n mul + k * n x n add + n x n mul + norm
+        -> n^2 + kn^2 + n^2 + 1 because of projection
+     - Dimension i : O(kn^2 + nm)
+     Total : O(kn^3 + mn^2)
+   *)
 end
diff --git a/src/kernel_services/analysis/filter/linear_filter.mli b/src/kernel_services/analysis/filter/linear_filter.mli
index 6a16b9b8f5b..e22d3f7bfb2 100644
--- a/src/kernel_services/analysis/filter/linear_filter.mli
+++ b/src/kernel_services/analysis/filter/linear_filter.mli
@@ -21,7 +21,22 @@
 (**************************************************************************)
 
 open Nat
+open Finite
 
+(* A filter corresponds to the recursive equation
+   X[k + 1] = AX[k] + Bε[k + 1] + C where :
+   - n is the filter's order and m its number of inputs ;
+   - X[k]∈ ℝ^n is the filter's state at iteration [k] ;
+   - ε[k]∈ ℝ^m is the filters's inputs at iteration [k] ;
+   - A∈ ℝ^(n×n) is the filter's state matrix ;
+   - B∈ ℝ^(n×m) is the filter's input matrix ;
+   - C∈ ℝ^n is the filter's center.
+
+   The goal of this module is to compute filters invariants, i.e bounds for
+   each of the filter's state dimensions when the iterations goes to infinity.
+   To do so, it only suppose that, at each iteration, each input εi is bounded
+   by [-λi .. λi]. Each input is thus supposed centered around zero but each
+   one can have different bounds. *)
 module Make (Field : Field.S) : sig
 
   module Linear : module type of Linear.Space (Field)
@@ -30,25 +45,45 @@ module Make (Field : Field.S) : sig
      with n state variables) and with m inputs. *)
   type ('n, 'm) filter
 
+  (* Create a filter's representation. The inputs are as following :
+     - state is the filter's state matrix ;
+     - input is the filter's input matrix ;
+     - center is the filter's center ;
+     - measure is a vector representing upper bounds for the filter's inputs. *)
   val create :
-    ('n succ, 'n succ) Linear.matrix ->
-    ('n succ, 'm succ) Linear.matrix ->
-    'm succ Linear.vector -> ('n succ, 'm succ) filter
+    state : ('n succ, 'n succ) Linear.matrix ->
+    input : ('n succ, 'm succ) Linear.matrix ->
+    center : 'n succ Linear.vector ->
+    measure : 'm succ Linear.vector ->
+    ('n succ, 'm succ) filter
 
   val pretty : Format.formatter -> ('n, 'm) filter -> unit
 
-  (* Invariant computation. The computation of [invariant filter max] relies on
+  (* Representation of a filter's invariant. Bounds for each dimension can be
+     accessed using the corresponding functions. *)
+  type 'n invariant
+  val lower : 'n finite -> 'n invariant -> Field.scalar
+  val upper : 'n finite -> 'n invariant -> Field.scalar
+  val bounds : 'n finite -> 'n invariant -> Field.scalar * Field.scalar
+
+  (* Invariant computation. The computation of [invariant filter k] relies on
      the search of an exponent such as the norm of the state matrix is strictly
-     lower than one. This search depth is bounded by [max]. If no exponent is
-     found before this limit is reached, the function returns None. If an
-     exponent [e] is found, the invariant computation complexity is bounded by
-     O(e * (n^3 + n^2 * m)) with [n] the filter's order and [m] its number of
-     inputs. Only the invariant's upper bound [λ] is returned, the filter's
-     invariant is thus bounded by ±λ. The only thing that limit the optimality
-     of this bound is [max], the initial search depth. However, for most simple
-     filters, a depth of 200 will gives an exact upper bound up to at least ten
-     digits, which is more than enough. Moreover, for those simple filters, the
-     computation is immediate, even when using rational numbers. *)
-  val invariant : ('n, 'm) filter -> int -> Field.scalar option
+     lower than one. For the filter to converge, there must exist an α such as,
+     for every β greater than α, ||A^β|| < 1 with A the filter's state matrix.
+     As such, the search does not have to find α, but instead any exponent such
+     as the property is satisfies. As the computed invariant will be more
+     precise with a larger exponent, the computation always uses [k], the
+     largest authorized exponent, and thus only check that indeed ||A^k|| < 1.
+     If the property is not verified, the function returns None as it cannot
+     prove that the filter converges.
+
+     The only thing limiting the invariant optimality is [k]. However, for most
+     simple filters, k = 200 will gives exact bounds up to at least ten digits,
+     which is more than enough. Moreover, for those simple filters, the
+     computation is immediate, even when using rational numbers. Indeed, the
+     invariant computation complexity is bounded by O(kn^3 + mn^2) with [n]
+     the filter's order and [m] its number of inputs. It is thus linear in
+     the targeted exponent. *)
+  val invariant : ('n, 'm) filter -> int -> 'n invariant option
 
 end
diff --git a/src/kernel_services/analysis/filter/linear_filter_test.ml b/src/kernel_services/analysis/filter/linear_filter_test.ml
index 84149c657d9..a6c6d12f0c1 100644
--- a/src/kernel_services/analysis/filter/linear_filter_test.ml
+++ b/src/kernel_services/analysis/filter/linear_filter_test.ml
@@ -66,9 +66,17 @@ let max_exponent = 200
 let fin size n = Finite.of_int size n |> Option.get
 let set row col i j n = Linear.Matrix.set (fin row i) (fin col j) n
 
-let pretty_invariant fmt = function
+let pretty_bounds invariant fmt i =
+  let l, u = Filter.bounds i invariant in
+  Format.fprintf fmt "@[<h>[%a .. %a]@]" Rational.pretty l Rational.pretty u
+
+let pretty_invariant order fmt = function
   | None -> Format.fprintf fmt "%s" (Unicode.top_string ())
-  | Some invariant -> Format.fprintf fmt "%a" Rational.pretty invariant
+  | Some invariant ->
+    let pp f i = pretty_bounds invariant f i in
+    let pp f i = Format.fprintf f "@[<h>* %d : %a@]@," (Finite.to_int i) pp i in
+    let pretty fmt () = Finite.for_each (fun i () -> pp fmt i) order () in
+    Format.fprintf fmt "@[<v>%a@]" pretty ()
 
 
 
@@ -87,10 +95,12 @@ module Circle = struct
 
   let measure = Linear.Vector.repeat Rational.one order
 
+  let center = Linear.Vector.zero order
+
   let compute () =
-    let filter = Filter.create state input measure in
+    let filter = Filter.create state input center measure in
     let invariant = Filter.invariant filter max_exponent in
-    Kernel.result "Circle : %a@." pretty_invariant invariant
+    Kernel.result "@[<v>Circle :@,%a@,@]" (pretty_invariant order) invariant
 
 end
 
@@ -115,10 +125,12 @@ module Simple = struct
 
   let measure = Linear.Vector.repeat (Rational.of_float 0.1) delay
 
+  let center = Linear.Vector.repeat Rational.one order
+
   let compute () =
-    let filter = Filter.create state input measure in
+    let filter = Filter.create state input center measure in
     let invariant = Filter.invariant filter max_exponent in
-    Kernel.result "Simple : %a@." pretty_invariant invariant
+    Kernel.result "@[<v>Simple :@,%a@,@]" (pretty_invariant order) invariant
 
 end
 
-- 
GitLab