From 08de5dc6944ac5f90d66a6748fc6b8b7887c5711 Mon Sep 17 00:00:00 2001 From: Maxime Jacquemin <maxime.jacquemin@cea.fr> Date: Fri, 2 Feb 2024 10:56:06 +0100 Subject: [PATCH] [Kernel] Simpler and directly memoized Matrix powers --- src/kernel_services/analysis/filter/linear.ml | 32 ++++++++----------- .../analysis/filter/linear.mli | 3 +- .../analysis/filter/linear_filter.ml | 22 +++---------- 3 files changed, 19 insertions(+), 38 deletions(-) diff --git a/src/kernel_services/analysis/filter/linear.ml b/src/kernel_services/analysis/filter/linear.ml index 629e482e1c9..1dd2fc44fd0 100644 --- a/src/kernel_services/analysis/filter/linear.ml +++ b/src/kernel_services/analysis/filter/linear.ml @@ -147,25 +147,19 @@ module Space (Field : Field.S) = struct let max i res = Field.max res (row i (M m) |> sum) in Finite.for_each max m.rows Field.zero - let rec fast_power target steps matrix exponent = - let steps' = (matrix, exponent) :: steps in - let exponent' = Stdlib.(exponent * 2) in - if exponent' > target then if exponent <= target then steps' else steps - else fast_power target steps' (matrix * matrix) exponent' - - let rec refine target = function - | [] -> assert false - | [ (matrix, _) ] -> matrix - | (matrix, exponent) :: (matrix', exponent') :: previous -> - let exponent'' = Stdlib.(exponent + exponent') in - if exponent'' > target - then refine target ((matrix, exponent) :: previous) - else refine target (((matrix * matrix'), exponent'') :: previous) - - let power (type n) (M matrix : (n, n) matrix) exponent : (n, n) matrix = - if exponent < 0 then raise (Invalid_argument "negative exponent") - else if Stdlib.(exponent = 0) then id matrix.rows - else fast_power exponent [] (M matrix) 1 |> refine exponent + let power (type n) (M m : (n, n) matrix) : int -> (n, n) matrix = + let n = dimensions (M m) |> fst in + let cache = Datatype.Int.Hashtbl.create 17 in + let find i = Datatype.Int.Hashtbl.find_opt cache i in + let save i v = Datatype.Int.Hashtbl.add cache i v ; v in + let rec pow e = + if e < 0 then raise (Invalid_argument "negative exponent") ; + match find e with + | Some r -> r + | None when Stdlib.(e = 0) -> id n + | None when Stdlib.(e = 1) -> M m + | None -> let h = pow (e / 2) in save e (pow (e mod 2) * h * h) + in pow end diff --git a/src/kernel_services/analysis/filter/linear.mli b/src/kernel_services/analysis/filter/linear.mli index 3ab99f29e66..173b2c9bb4d 100644 --- a/src/kernel_services/analysis/filter/linear.mli +++ b/src/kernel_services/analysis/filter/linear.mli @@ -52,7 +52,8 @@ module Space (Field : Field.S) : sig val dimensions : ('m, 'n) matrix -> 'm nat * 'n nat val ( + ) : ('n, 'm) matrix -> ('n, 'm) matrix -> ('n, 'm) matrix val ( * ) : ('n, 'm) matrix -> ('m, 'p) matrix -> ('n, 'p) matrix - val power : ('n, 'n) matrix -> int -> ('n, 'n) matrix + (* Memoized, instantiate first on a matrix and then use it *) + val power : ('n, 'n) matrix -> (int -> ('n, 'n) matrix) end end diff --git a/src/kernel_services/analysis/filter/linear_filter.ml b/src/kernel_services/analysis/filter/linear_filter.ml index f216a5f5ca9..197357393b3 100644 --- a/src/kernel_services/analysis/filter/linear_filter.ml +++ b/src/kernel_services/analysis/filter/linear_filter.ml @@ -77,12 +77,6 @@ module Make (Field : Field.S) = struct - let memoized_powers m = - let cache = Datatype.Int.Hashtbl.create 17 in - let find i = Datatype.Int.Hashtbl.find cache i in - let save i v = Datatype.Int.Hashtbl.add cache i v ; v in - fun i -> try find i with Not_found -> save i (Matrix.power m i) - let check_convergence matrix = let norm = Matrix.norm matrix in if Field.(norm < one) then Some norm else None @@ -90,18 +84,18 @@ module Make (Field : Field.S) = struct type ('n, 'm) compute = ('n, 'm) filter -> int -> 'n invariant option let invariant : type n m. (n, m) compute = fun (Filter f) e -> let open Option.Operators in + let state = Matrix.power f.state in let measure = Vector.norm f.measure in - let powers = memoized_powers f.state in let order, _ = Matrix.dimensions f.input in let base i = Vector.base i order |> Matrix.transpose in let* StrictlyPositive exponent = Nat.of_strictly_positive_int e in - let+ spectral = powers (Nat.to_int exponent) |> check_convergence in + let+ spectral = state (Nat.to_int exponent) |> check_convergence in (* Computation of the inputs contribution for the i-th state dimension *) - let input base e = Matrix.(base * powers e * f.input |> norm) in + let input base e = Matrix.(base * state e * f.input |> norm) in let add_input base e res = Field.(res + input base (Finite.to_int e)) in let input i = Finite.for_each (base i |> add_input) exponent Field.zero in (* Computation of the center contribution for the i-th state dimension *) - let center e = Matrix.(powers e * f.center) in + let center e = Matrix.(state e * f.center) in let add_center e res = Matrix.(res + center (Finite.to_int e)) in let center = Finite.for_each add_center exponent (Vector.zero order) in let center i = Matrix.(base i * center |> norm) in @@ -112,12 +106,4 @@ module Make (Field : Field.S) = struct let upper i inv = set_upper i (bound Field.one i) inv in Finite.(invariant order |> for_each lower order |> for_each upper order) - (* - Powers : k * n^3 - - Input i : n x n mul + n x m mul + norm - -> n^2 + nm + 1 because of projection - - Center i : n x n mul + k * n x n add + n x n mul + norm - -> n^2 + kn^2 + n^2 + 1 because of projection - - Dimension i : O(kn^2 + nm) - Total : O(kn^3 + mn^2) - *) end -- GitLab