From 66fce59317f8ff8e0990c42d3f1d5e6615d09bb3 Mon Sep 17 00:00:00 2001 From: Maxime Jacquemin <maxime.jacquemin@cea.fr> Date: Wed, 31 Jan 2024 15:43:07 +0100 Subject: [PATCH] [Kernel] Simplifications and improvements - Change Finite.for_each parameters order for a more natural one - Used Format boxes for some pretty_printers - Used Finite.for_each instead of a nasty for loop in vector product - Utilitary function for index computation in matrices - Simpler computation of the spectral radius --- src/kernel_services/analysis/filter/finite.ml | 2 +- .../analysis/filter/finite.mli | 2 +- src/kernel_services/analysis/filter/linear.ml | 43 +++++++++-------- .../analysis/filter/linear_filter.ml | 47 +++++++++---------- 4 files changed, 49 insertions(+), 45 deletions(-) diff --git a/src/kernel_services/analysis/filter/finite.ml b/src/kernel_services/analysis/filter/finite.ml index a3851bc9961..11af2a7281b 100644 --- a/src/kernel_services/analysis/filter/finite.ml +++ b/src/kernel_services/analysis/filter/finite.ml @@ -36,7 +36,7 @@ let to_int : type n. n finite -> int = fun n -> n let of_int : type n. n succ nat -> int -> n succ finite option = fun limit n -> if 0 <= n && n < Nat.to_int limit then Some n else None -let for_each (type n) acc (limit : n nat) (f : n finite -> 'a -> 'a) = +let for_each (type n) (f : n finite -> 'a -> 'a) (limit : n nat) acc = let acc = ref acc in for i = 0 to Nat.to_int limit - 1 do acc := f i !acc done ; !acc diff --git a/src/kernel_services/analysis/filter/finite.mli b/src/kernel_services/analysis/filter/finite.mli index 36d4f1ebd55..ed5e8b695cb 100644 --- a/src/kernel_services/analysis/filter/finite.mli +++ b/src/kernel_services/analysis/filter/finite.mli @@ -40,4 +40,4 @@ val to_int : 'n finite -> int (* The call [for_each acc limit f] folds over each finite elements of a set of cardinal limit, computing f at each step. The function complexity is O(n). *) -val for_each : 'a -> 'n nat -> ('n finite -> 'a -> 'a) -> 'a +val for_each : ('n finite -> 'a -> 'a) -> 'n nat -> 'a -> 'a diff --git a/src/kernel_services/analysis/filter/linear.ml b/src/kernel_services/analysis/filter/linear.ml index 4fde14ffda3..b4a6430d512 100644 --- a/src/kernel_services/analysis/filter/linear.ml +++ b/src/kernel_services/analysis/filter/linear.ml @@ -43,26 +43,25 @@ module Space (Field : Field.S) = struct let init size f = let data = Parray.init (Nat.to_int size) (fun _ -> Field.zero) in let set i data = Parray.set data (Finite.to_int i) (f i) in - let data = Finite.for_each data size set in + let data = Finite.for_each set size data in M { data ; rows = size ; cols = Nat.one } let size (type n) (M vector : n vector) : n nat = vector.rows let repeat n size = init size (fun _ -> n) let zero size = repeat Field.zero size - let set (type n) (i : n finite) n (M vec : n vector) : n vector = - M { vec with data = Parray.set vec.data (Finite.to_int i) n } + let get (type n) (i : n finite) (M vec : n vector) : scalar = + Parray.get vec.data (Finite.to_int i) + + let set (type n) (i : n finite) scalar (M vec : n vector) : n vector = + M { vec with data = Parray.set vec.data (Finite.to_int i) scalar } let norm (type n) (M vector : n vector) = Parray.fold (fun _ a acc -> Field.(abs a + acc)) vector.data Field.zero - let ( * ) (type n) (M l : n vector) (M r : n vector) = - let inner = ref Field.zero in - let get i v = Parray.get v.data i in - for i = 0 to Nat.to_int l.rows - 1 do - inner := Field.(!inner + (get i l) * (get i r)) - done ; - !inner + let ( * ) (type n) (l : n vector) (r : n vector) = + let inner i acc = Field.(acc + get i l * get i r) in + Finite.for_each inner (size l) Field.zero end @@ -70,13 +69,15 @@ module Space (Field : Field.S) = struct module Matrix = struct + let index cols i j = i * Nat.to_int cols + j + let get (type n m) (i : n finite) (j : m finite) (M m : (n, m) matrix) = let i = Finite.to_int i and j = Finite.to_int j in - Parray.get m.data (i * Nat.to_int m.cols + j) + Parray.get m.data (index m.cols i j) let set (type n m) i j num (M m : (n, m) matrix) : (n, m) matrix = let i = Finite.to_int i and j = Finite.to_int j in - let data = Parray.set m.data (i * Nat.to_int m.cols + j) num in + let data = Parray.set m.data (index m.cols i j) num in M { m with data } let row row (M m) = Vector.init m.cols @@ fun i -> get row i (M m) @@ -86,20 +87,24 @@ module Space (Field : Field.S) = struct fun (M m) -> m.rows, m.cols let pretty (type n m) fmt (M m : (n, m) matrix) = - Finite.for_each () m.rows @@ fun i () -> - (if Finite.(i == first) then () else Format.fprintf fmt "@ ") ; - Format.fprintf fmt "@[<h>%a@]" Vector.pretty (row i (M m)) + let open Format in + let not_first i = not Finite.(i == first) in + let newline fmt i = if not_first i then pp_print_newline fmt () in + let row fmt i = fprintf fmt "@[<h>%a@]" Vector.pretty (row i (M m)) in + let pp_line fmt i () = newline fmt i ; row fmt i in + let pp fmt () = Finite.for_each (pp_line fmt) m.rows () in + Format.fprintf fmt "@[<v>%a@]" pp () let init n m init = let rows = Nat.to_int n and cols = Nat.to_int m in let t = Parray.init (rows * cols) (fun _ -> Field.zero) in - let index i j = Finite.to_int i * cols + Finite.to_int j in + let index i j = index m (Finite.to_int i) (Finite.to_int j) in let set i j data = Parray.set data (index i j) (init i j) in - let data = Finite.(for_each t n @@ fun i t -> for_each t m (set i)) in + let data = Finite.(for_each (fun i t -> for_each (set i) m t) n t) in M { data ; rows = n ; cols = m } let zero n m = init n m (fun _ _ -> Field.zero) - let id n = Finite.for_each (zero n n) n @@ fun i m -> set i i Field.one m + let id n = Finite.for_each (fun i m -> set i i Field.one m) n (zero n n) type ('n, 'm) add = ('n, 'm) matrix -> ('n, 'm) matrix -> ('n, 'm) matrix let ( + ) : type n m. (n, m) add = fun (M l) (M r) -> @@ -113,7 +118,7 @@ module Space (Field : Field.S) = struct let norm : type n m. (n, m) matrix -> scalar = fun (M m) -> let max i res = Field.max res (col i (M m) |> Vector.norm) in - Finite.for_each Field.zero m.cols max + Finite.for_each max m.cols Field.zero end diff --git a/src/kernel_services/analysis/filter/linear_filter.ml b/src/kernel_services/analysis/filter/linear_filter.ml index da1722d5d25..8bfad7c20c6 100644 --- a/src/kernel_services/analysis/filter/linear_filter.ml +++ b/src/kernel_services/analysis/filter/linear_filter.ml @@ -31,28 +31,25 @@ module Make (Field : Field.S) = struct - let rec first_steps max steps matrix exponent = - let steps = (matrix, exponent) :: steps in - if Field.(Matrix.norm matrix < one) then - if exponent * 2 > max then Some steps - else first_steps max steps Matrix.(matrix * matrix) (exponent * 2) - else if exponent <= max then - first_steps max steps Matrix.(matrix * matrix) (exponent * 2) - else None - - let rec refine max = function + let rec first_steps target steps matrix exponent = + let steps' = (matrix, exponent) :: steps in + if exponent * 2 > target + then if exponent <= target then steps' else steps + else first_steps target steps' Matrix.(matrix * matrix) (exponent * 2) + + let rec refine target = function | [] -> None - | [ (matrix, exponent) ] -> Some (Matrix.norm matrix, exponent) + | [ (matrix, _) ] -> + let norm = Matrix.norm matrix in + if Field.(norm < one) then Some norm else None | (matrix, exponent) :: (matrix', exponent') :: previous -> let exponent'' = exponent + exponent' in - if exponent'' > max - then refine max ((matrix, exponent) :: previous) - else refine max ((Matrix.(matrix * matrix'), exponent'') :: previous) + if exponent'' > target + then refine target ((matrix, exponent) :: previous) + else refine target ((Matrix.(matrix * matrix'), exponent'') :: previous) - let find_exponent max_acceptable_exponent base = - let open Option.Operators in - let* steps = first_steps max_acceptable_exponent [] base 1 in - refine max_acceptable_exponent steps + let find_spectral_radius target base = + first_steps target [] base 1 |> refine target @@ -71,9 +68,11 @@ module Make (Field : Field.S) = struct type ('n, 'm) formatter = Format.formatter -> ('n, 'm) filter -> unit let pretty : type n m. (n, m) formatter = fun fmt (Filter f) -> - Format.fprintf fmt "Filter:@.@." ; - Format.fprintf fmt "- State :@.@. @[<v>%a@]@.@." Matrix.pretty f.state ; - Format.fprintf fmt "- Input :@.@. @[<v>%a@]@.@." Matrix.pretty f.input + Format.fprintf fmt "@[<v>" ; + Format.fprintf fmt "Filter:@ @ " ; + Format.fprintf fmt "- State :@ @ @[<v>%a@]@ @ " Matrix.pretty f.state ; + Format.fprintf fmt "- Input :@ @ @[<v>%a@]@ @ " Matrix.pretty f.input ; + Format.fprintf fmt "@]" let sum order p norm stop = let ( + ) res acc = Field.(res + norm acc) in @@ -81,13 +80,13 @@ module Make (Field : Field.S) = struct aux (Matrix.id order, Field.zero) (stop - 1) type ('n, 'm) invariant = ('n, 'm) filter -> int -> Field.scalar option - let invariant : type n m. (n, m) invariant = fun (Filter f) max -> + let invariant : type n m. (n, m) invariant = fun (Filter f) exponent -> let open Option.Operators in let order, _ = Matrix.dimensions f.input in - let+ spectral, exponant = find_exponent max f.state in + let+ spectral = find_spectral_radius exponent f.state in let power p = Matrix.(f.state * p) in let norm p = Matrix.(p * f.input |> norm) in - let sum = sum order power norm exponant in + let sum = sum order power norm exponent in let bound = Field.(Vector.norm f.measure * sum / (one - spectral)) in let order = Field.of_int (Nat.to_int order) in Field.(bound / order) -- GitLab