From 66fce59317f8ff8e0990c42d3f1d5e6615d09bb3 Mon Sep 17 00:00:00 2001
From: Maxime Jacquemin <maxime.jacquemin@cea.fr>
Date: Wed, 31 Jan 2024 15:43:07 +0100
Subject: [PATCH] [Kernel] Simplifications and improvements

- Change Finite.for_each parameters order for a more natural one
- Used Format boxes for some pretty_printers
- Used Finite.for_each instead of a nasty for loop in vector product
- Utilitary function for index computation in matrices
- Simpler computation of the spectral radius
---
 src/kernel_services/analysis/filter/finite.ml |  2 +-
 .../analysis/filter/finite.mli                |  2 +-
 src/kernel_services/analysis/filter/linear.ml | 43 +++++++++--------
 .../analysis/filter/linear_filter.ml          | 47 +++++++++----------
 4 files changed, 49 insertions(+), 45 deletions(-)

diff --git a/src/kernel_services/analysis/filter/finite.ml b/src/kernel_services/analysis/filter/finite.ml
index a3851bc9961..11af2a7281b 100644
--- a/src/kernel_services/analysis/filter/finite.ml
+++ b/src/kernel_services/analysis/filter/finite.ml
@@ -36,7 +36,7 @@ let to_int : type n. n finite -> int = fun n -> n
 let of_int : type n. n succ nat -> int -> n succ finite option =
   fun limit n -> if 0 <= n && n < Nat.to_int limit then Some n else None
 
-let for_each (type n) acc (limit : n nat) (f : n finite -> 'a -> 'a) =
+let for_each (type n) (f : n finite -> 'a -> 'a) (limit : n nat) acc =
   let acc = ref acc in
   for i = 0 to Nat.to_int limit - 1 do acc := f i !acc done ;
   !acc
diff --git a/src/kernel_services/analysis/filter/finite.mli b/src/kernel_services/analysis/filter/finite.mli
index 36d4f1ebd55..ed5e8b695cb 100644
--- a/src/kernel_services/analysis/filter/finite.mli
+++ b/src/kernel_services/analysis/filter/finite.mli
@@ -40,4 +40,4 @@ val to_int : 'n finite -> int
 
 (* The call [for_each acc limit f] folds over each finite elements of a set of
    cardinal limit, computing f at each step. The function complexity is O(n). *)
-val for_each : 'a -> 'n nat -> ('n finite -> 'a -> 'a) -> 'a
+val for_each : ('n finite -> 'a -> 'a) -> 'n nat -> 'a -> 'a
diff --git a/src/kernel_services/analysis/filter/linear.ml b/src/kernel_services/analysis/filter/linear.ml
index 4fde14ffda3..b4a6430d512 100644
--- a/src/kernel_services/analysis/filter/linear.ml
+++ b/src/kernel_services/analysis/filter/linear.ml
@@ -43,26 +43,25 @@ module Space (Field : Field.S) = struct
     let init size f =
       let data = Parray.init (Nat.to_int size) (fun _ -> Field.zero) in
       let set i data = Parray.set data (Finite.to_int i) (f i) in
-      let data = Finite.for_each data size set in
+      let data = Finite.for_each set size data in
       M { data ; rows = size ; cols = Nat.one }
 
     let size (type n) (M vector : n vector) : n nat = vector.rows
     let repeat n size = init size (fun _ -> n)
     let zero size = repeat Field.zero size
 
-    let set (type n) (i : n finite) n (M vec : n vector) : n vector =
-      M { vec with data = Parray.set vec.data (Finite.to_int i) n }
+    let get (type n) (i : n finite) (M vec : n vector) : scalar =
+      Parray.get vec.data (Finite.to_int i)
+
+    let set (type n) (i : n finite) scalar (M vec : n vector) : n vector =
+      M { vec with data = Parray.set vec.data (Finite.to_int i) scalar }
 
     let norm (type n) (M vector : n vector) =
       Parray.fold (fun _ a acc -> Field.(abs a + acc)) vector.data Field.zero
 
-    let ( * ) (type n) (M l : n vector) (M r : n vector) =
-      let inner = ref Field.zero in
-      let get i v = Parray.get v.data i in
-      for i = 0 to Nat.to_int l.rows - 1 do
-        inner := Field.(!inner + (get i l) * (get i r))
-      done ;
-      !inner
+    let ( * ) (type n) (l : n vector) (r : n vector) =
+      let inner i acc = Field.(acc + get i l * get i r) in
+      Finite.for_each inner (size l) Field.zero
 
   end
 
@@ -70,13 +69,15 @@ module Space (Field : Field.S) = struct
 
   module Matrix = struct
 
+    let index cols i j = i * Nat.to_int cols + j
+
     let get (type n m) (i : n finite) (j : m finite) (M m : (n, m) matrix) =
       let i = Finite.to_int i and j = Finite.to_int j in
-      Parray.get m.data (i * Nat.to_int m.cols + j)
+      Parray.get m.data (index m.cols i j)
 
     let set (type n m) i j num (M m : (n, m) matrix) : (n, m) matrix =
       let i = Finite.to_int i and j = Finite.to_int j in
-      let data = Parray.set m.data (i * Nat.to_int m.cols + j) num in
+      let data = Parray.set m.data (index m.cols i j) num in
       M { m with data }
 
     let row row (M m) = Vector.init m.cols @@ fun i -> get row i (M m)
@@ -86,20 +87,24 @@ module Space (Field : Field.S) = struct
       fun (M m) -> m.rows, m.cols
 
     let pretty (type n m) fmt (M m : (n, m) matrix) =
-      Finite.for_each () m.rows @@ fun i () ->
-      (if Finite.(i == first) then () else Format.fprintf fmt "@ ") ;
-      Format.fprintf fmt "@[<h>%a@]" Vector.pretty (row i (M m))
+      let open Format in
+      let not_first i = not Finite.(i == first) in
+      let newline fmt i = if not_first i then pp_print_newline fmt () in
+      let row fmt i = fprintf fmt "@[<h>%a@]" Vector.pretty (row i (M m)) in
+      let pp_line fmt i () = newline fmt i ; row fmt i in
+      let pp fmt () = Finite.for_each (pp_line fmt) m.rows () in
+      Format.fprintf fmt "@[<v>%a@]" pp ()
 
     let init n m init =
       let rows = Nat.to_int n and cols = Nat.to_int m in
       let t = Parray.init (rows * cols) (fun _ -> Field.zero) in
-      let index i j = Finite.to_int i * cols + Finite.to_int j in
+      let index i j = index m (Finite.to_int i) (Finite.to_int j) in
       let set i j data = Parray.set data (index i j) (init i j) in
-      let data = Finite.(for_each t n @@ fun i t -> for_each t m (set i)) in
+      let data = Finite.(for_each (fun i t -> for_each (set i) m t) n t) in
       M { data ; rows = n ; cols = m }
 
     let zero n m = init n m (fun _ _ -> Field.zero)
-    let id n = Finite.for_each (zero n n) n @@ fun i m -> set i i Field.one m
+    let id n = Finite.for_each (fun i m -> set i i Field.one m) n (zero n n)
 
     type ('n, 'm) add = ('n, 'm) matrix -> ('n, 'm) matrix -> ('n, 'm) matrix
     let ( + ) : type n m. (n, m) add = fun (M l) (M r) ->
@@ -113,7 +118,7 @@ module Space (Field : Field.S) = struct
 
     let norm : type n m. (n, m) matrix -> scalar = fun (M m) ->
       let max i res = Field.max res (col i (M m) |> Vector.norm) in
-      Finite.for_each Field.zero m.cols max
+      Finite.for_each max m.cols Field.zero
 
   end
 
diff --git a/src/kernel_services/analysis/filter/linear_filter.ml b/src/kernel_services/analysis/filter/linear_filter.ml
index da1722d5d25..8bfad7c20c6 100644
--- a/src/kernel_services/analysis/filter/linear_filter.ml
+++ b/src/kernel_services/analysis/filter/linear_filter.ml
@@ -31,28 +31,25 @@ module Make (Field : Field.S) = struct
 
 
 
-  let rec first_steps max steps matrix exponent =
-    let steps = (matrix, exponent) :: steps in
-    if Field.(Matrix.norm matrix < one) then
-      if exponent * 2 > max then Some steps
-      else first_steps max steps Matrix.(matrix * matrix) (exponent * 2)
-    else if exponent <= max then
-      first_steps max steps Matrix.(matrix * matrix) (exponent * 2)
-    else None
-
-  let rec refine max = function
+  let rec first_steps target steps matrix exponent =
+    let steps' = (matrix, exponent) :: steps in
+    if exponent * 2 > target
+    then if exponent <= target then steps' else steps
+    else first_steps target steps' Matrix.(matrix * matrix) (exponent * 2)
+
+  let rec refine target = function
     | [] -> None
-    | [ (matrix, exponent) ] -> Some (Matrix.norm matrix, exponent)
+    | [ (matrix, _) ] ->
+      let norm = Matrix.norm matrix in
+      if Field.(norm < one) then Some norm else None
     | (matrix, exponent) :: (matrix', exponent') :: previous ->
       let exponent'' = exponent + exponent' in
-      if exponent'' > max
-      then refine max ((matrix, exponent) :: previous)
-      else refine max ((Matrix.(matrix * matrix'), exponent'') :: previous)
+      if exponent'' > target
+      then refine target ((matrix, exponent) :: previous)
+      else refine target ((Matrix.(matrix * matrix'), exponent'') :: previous)
 
-  let find_exponent max_acceptable_exponent base =
-    let open Option.Operators in
-    let* steps = first_steps max_acceptable_exponent [] base 1 in
-    refine max_acceptable_exponent steps
+  let find_spectral_radius target base =
+    first_steps target [] base 1 |> refine target
 
 
 
@@ -71,9 +68,11 @@ module Make (Field : Field.S) = struct
 
   type ('n, 'm) formatter = Format.formatter -> ('n, 'm) filter -> unit
   let pretty : type n m. (n, m) formatter = fun fmt (Filter f) ->
-    Format.fprintf fmt "Filter:@.@." ;
-    Format.fprintf fmt "- State :@.@.  @[<v>%a@]@.@." Matrix.pretty f.state ;
-    Format.fprintf fmt "- Input :@.@.  @[<v>%a@]@.@." Matrix.pretty f.input
+    Format.fprintf fmt "@[<v>" ;
+    Format.fprintf fmt "Filter:@ @ " ;
+    Format.fprintf fmt "- State :@ @   @[<v>%a@]@ @ " Matrix.pretty f.state ;
+    Format.fprintf fmt "- Input :@ @   @[<v>%a@]@ @ " Matrix.pretty f.input ;
+    Format.fprintf fmt "@]"
 
   let sum order p norm stop =
     let ( + ) res acc = Field.(res + norm acc) in
@@ -81,13 +80,13 @@ module Make (Field : Field.S) = struct
     aux (Matrix.id order, Field.zero) (stop - 1)
 
   type ('n, 'm) invariant = ('n, 'm) filter -> int -> Field.scalar option
-  let invariant : type n m. (n, m) invariant = fun (Filter f) max ->
+  let invariant : type n m. (n, m) invariant = fun (Filter f) exponent ->
     let open Option.Operators in
     let order, _ = Matrix.dimensions f.input in
-    let+ spectral, exponant = find_exponent max f.state in
+    let+ spectral = find_spectral_radius exponent f.state in
     let power p = Matrix.(f.state * p) in
     let norm  p = Matrix.(p * f.input |> norm) in
-    let sum = sum order power norm exponant in
+    let sum = sum order power norm exponent in
     let bound = Field.(Vector.norm f.measure * sum / (one - spectral)) in
     let order = Field.of_int (Nat.to_int order) in
     Field.(bound / order)
-- 
GitLab