Skip to content
Snippets Groups Projects
Commit 66fce593 authored by Maxime Jacquemin's avatar Maxime Jacquemin Committed by David Bühler
Browse files

[Kernel] Simplifications and improvements

- Change Finite.for_each parameters order for a more natural one
- Used Format boxes for some pretty_printers
- Used Finite.for_each instead of a nasty for loop in vector product
- Utilitary function for index computation in matrices
- Simpler computation of the spectral radius
parent d48aabe4
No related branches found
No related tags found
No related merge requests found
...@@ -36,7 +36,7 @@ let to_int : type n. n finite -> int = fun n -> n ...@@ -36,7 +36,7 @@ let to_int : type n. n finite -> int = fun n -> n
let of_int : type n. n succ nat -> int -> n succ finite option = let of_int : type n. n succ nat -> int -> n succ finite option =
fun limit n -> if 0 <= n && n < Nat.to_int limit then Some n else None fun limit n -> if 0 <= n && n < Nat.to_int limit then Some n else None
let for_each (type n) acc (limit : n nat) (f : n finite -> 'a -> 'a) = let for_each (type n) (f : n finite -> 'a -> 'a) (limit : n nat) acc =
let acc = ref acc in let acc = ref acc in
for i = 0 to Nat.to_int limit - 1 do acc := f i !acc done ; for i = 0 to Nat.to_int limit - 1 do acc := f i !acc done ;
!acc !acc
...@@ -40,4 +40,4 @@ val to_int : 'n finite -> int ...@@ -40,4 +40,4 @@ val to_int : 'n finite -> int
(* The call [for_each acc limit f] folds over each finite elements of a set of (* The call [for_each acc limit f] folds over each finite elements of a set of
cardinal limit, computing f at each step. The function complexity is O(n). *) cardinal limit, computing f at each step. The function complexity is O(n). *)
val for_each : 'a -> 'n nat -> ('n finite -> 'a -> 'a) -> 'a val for_each : ('n finite -> 'a -> 'a) -> 'n nat -> 'a -> 'a
...@@ -43,26 +43,25 @@ module Space (Field : Field.S) = struct ...@@ -43,26 +43,25 @@ module Space (Field : Field.S) = struct
let init size f = let init size f =
let data = Parray.init (Nat.to_int size) (fun _ -> Field.zero) in let data = Parray.init (Nat.to_int size) (fun _ -> Field.zero) in
let set i data = Parray.set data (Finite.to_int i) (f i) in let set i data = Parray.set data (Finite.to_int i) (f i) in
let data = Finite.for_each data size set in let data = Finite.for_each set size data in
M { data ; rows = size ; cols = Nat.one } M { data ; rows = size ; cols = Nat.one }
let size (type n) (M vector : n vector) : n nat = vector.rows let size (type n) (M vector : n vector) : n nat = vector.rows
let repeat n size = init size (fun _ -> n) let repeat n size = init size (fun _ -> n)
let zero size = repeat Field.zero size let zero size = repeat Field.zero size
let set (type n) (i : n finite) n (M vec : n vector) : n vector = let get (type n) (i : n finite) (M vec : n vector) : scalar =
M { vec with data = Parray.set vec.data (Finite.to_int i) n } Parray.get vec.data (Finite.to_int i)
let set (type n) (i : n finite) scalar (M vec : n vector) : n vector =
M { vec with data = Parray.set vec.data (Finite.to_int i) scalar }
let norm (type n) (M vector : n vector) = let norm (type n) (M vector : n vector) =
Parray.fold (fun _ a acc -> Field.(abs a + acc)) vector.data Field.zero Parray.fold (fun _ a acc -> Field.(abs a + acc)) vector.data Field.zero
let ( * ) (type n) (M l : n vector) (M r : n vector) = let ( * ) (type n) (l : n vector) (r : n vector) =
let inner = ref Field.zero in let inner i acc = Field.(acc + get i l * get i r) in
let get i v = Parray.get v.data i in Finite.for_each inner (size l) Field.zero
for i = 0 to Nat.to_int l.rows - 1 do
inner := Field.(!inner + (get i l) * (get i r))
done ;
!inner
end end
...@@ -70,13 +69,15 @@ module Space (Field : Field.S) = struct ...@@ -70,13 +69,15 @@ module Space (Field : Field.S) = struct
module Matrix = struct module Matrix = struct
let index cols i j = i * Nat.to_int cols + j
let get (type n m) (i : n finite) (j : m finite) (M m : (n, m) matrix) = let get (type n m) (i : n finite) (j : m finite) (M m : (n, m) matrix) =
let i = Finite.to_int i and j = Finite.to_int j in let i = Finite.to_int i and j = Finite.to_int j in
Parray.get m.data (i * Nat.to_int m.cols + j) Parray.get m.data (index m.cols i j)
let set (type n m) i j num (M m : (n, m) matrix) : (n, m) matrix = let set (type n m) i j num (M m : (n, m) matrix) : (n, m) matrix =
let i = Finite.to_int i and j = Finite.to_int j in let i = Finite.to_int i and j = Finite.to_int j in
let data = Parray.set m.data (i * Nat.to_int m.cols + j) num in let data = Parray.set m.data (index m.cols i j) num in
M { m with data } M { m with data }
let row row (M m) = Vector.init m.cols @@ fun i -> get row i (M m) let row row (M m) = Vector.init m.cols @@ fun i -> get row i (M m)
...@@ -86,20 +87,24 @@ module Space (Field : Field.S) = struct ...@@ -86,20 +87,24 @@ module Space (Field : Field.S) = struct
fun (M m) -> m.rows, m.cols fun (M m) -> m.rows, m.cols
let pretty (type n m) fmt (M m : (n, m) matrix) = let pretty (type n m) fmt (M m : (n, m) matrix) =
Finite.for_each () m.rows @@ fun i () -> let open Format in
(if Finite.(i == first) then () else Format.fprintf fmt "@ ") ; let not_first i = not Finite.(i == first) in
Format.fprintf fmt "@[<h>%a@]" Vector.pretty (row i (M m)) let newline fmt i = if not_first i then pp_print_newline fmt () in
let row fmt i = fprintf fmt "@[<h>%a@]" Vector.pretty (row i (M m)) in
let pp_line fmt i () = newline fmt i ; row fmt i in
let pp fmt () = Finite.for_each (pp_line fmt) m.rows () in
Format.fprintf fmt "@[<v>%a@]" pp ()
let init n m init = let init n m init =
let rows = Nat.to_int n and cols = Nat.to_int m in let rows = Nat.to_int n and cols = Nat.to_int m in
let t = Parray.init (rows * cols) (fun _ -> Field.zero) in let t = Parray.init (rows * cols) (fun _ -> Field.zero) in
let index i j = Finite.to_int i * cols + Finite.to_int j in let index i j = index m (Finite.to_int i) (Finite.to_int j) in
let set i j data = Parray.set data (index i j) (init i j) in let set i j data = Parray.set data (index i j) (init i j) in
let data = Finite.(for_each t n @@ fun i t -> for_each t m (set i)) in let data = Finite.(for_each (fun i t -> for_each (set i) m t) n t) in
M { data ; rows = n ; cols = m } M { data ; rows = n ; cols = m }
let zero n m = init n m (fun _ _ -> Field.zero) let zero n m = init n m (fun _ _ -> Field.zero)
let id n = Finite.for_each (zero n n) n @@ fun i m -> set i i Field.one m let id n = Finite.for_each (fun i m -> set i i Field.one m) n (zero n n)
type ('n, 'm) add = ('n, 'm) matrix -> ('n, 'm) matrix -> ('n, 'm) matrix type ('n, 'm) add = ('n, 'm) matrix -> ('n, 'm) matrix -> ('n, 'm) matrix
let ( + ) : type n m. (n, m) add = fun (M l) (M r) -> let ( + ) : type n m. (n, m) add = fun (M l) (M r) ->
...@@ -113,7 +118,7 @@ module Space (Field : Field.S) = struct ...@@ -113,7 +118,7 @@ module Space (Field : Field.S) = struct
let norm : type n m. (n, m) matrix -> scalar = fun (M m) -> let norm : type n m. (n, m) matrix -> scalar = fun (M m) ->
let max i res = Field.max res (col i (M m) |> Vector.norm) in let max i res = Field.max res (col i (M m) |> Vector.norm) in
Finite.for_each Field.zero m.cols max Finite.for_each max m.cols Field.zero
end end
......
...@@ -31,28 +31,25 @@ module Make (Field : Field.S) = struct ...@@ -31,28 +31,25 @@ module Make (Field : Field.S) = struct
let rec first_steps max steps matrix exponent = let rec first_steps target steps matrix exponent =
let steps = (matrix, exponent) :: steps in let steps' = (matrix, exponent) :: steps in
if Field.(Matrix.norm matrix < one) then if exponent * 2 > target
if exponent * 2 > max then Some steps then if exponent <= target then steps' else steps
else first_steps max steps Matrix.(matrix * matrix) (exponent * 2) else first_steps target steps' Matrix.(matrix * matrix) (exponent * 2)
else if exponent <= max then
first_steps max steps Matrix.(matrix * matrix) (exponent * 2) let rec refine target = function
else None
let rec refine max = function
| [] -> None | [] -> None
| [ (matrix, exponent) ] -> Some (Matrix.norm matrix, exponent) | [ (matrix, _) ] ->
let norm = Matrix.norm matrix in
if Field.(norm < one) then Some norm else None
| (matrix, exponent) :: (matrix', exponent') :: previous -> | (matrix, exponent) :: (matrix', exponent') :: previous ->
let exponent'' = exponent + exponent' in let exponent'' = exponent + exponent' in
if exponent'' > max if exponent'' > target
then refine max ((matrix, exponent) :: previous) then refine target ((matrix, exponent) :: previous)
else refine max ((Matrix.(matrix * matrix'), exponent'') :: previous) else refine target ((Matrix.(matrix * matrix'), exponent'') :: previous)
let find_exponent max_acceptable_exponent base = let find_spectral_radius target base =
let open Option.Operators in first_steps target [] base 1 |> refine target
let* steps = first_steps max_acceptable_exponent [] base 1 in
refine max_acceptable_exponent steps
...@@ -71,9 +68,11 @@ module Make (Field : Field.S) = struct ...@@ -71,9 +68,11 @@ module Make (Field : Field.S) = struct
type ('n, 'm) formatter = Format.formatter -> ('n, 'm) filter -> unit type ('n, 'm) formatter = Format.formatter -> ('n, 'm) filter -> unit
let pretty : type n m. (n, m) formatter = fun fmt (Filter f) -> let pretty : type n m. (n, m) formatter = fun fmt (Filter f) ->
Format.fprintf fmt "Filter:@.@." ; Format.fprintf fmt "@[<v>" ;
Format.fprintf fmt "- State :@.@. @[<v>%a@]@.@." Matrix.pretty f.state ; Format.fprintf fmt "Filter:@ @ " ;
Format.fprintf fmt "- Input :@.@. @[<v>%a@]@.@." Matrix.pretty f.input Format.fprintf fmt "- State :@ @ @[<v>%a@]@ @ " Matrix.pretty f.state ;
Format.fprintf fmt "- Input :@ @ @[<v>%a@]@ @ " Matrix.pretty f.input ;
Format.fprintf fmt "@]"
let sum order p norm stop = let sum order p norm stop =
let ( + ) res acc = Field.(res + norm acc) in let ( + ) res acc = Field.(res + norm acc) in
...@@ -81,13 +80,13 @@ module Make (Field : Field.S) = struct ...@@ -81,13 +80,13 @@ module Make (Field : Field.S) = struct
aux (Matrix.id order, Field.zero) (stop - 1) aux (Matrix.id order, Field.zero) (stop - 1)
type ('n, 'm) invariant = ('n, 'm) filter -> int -> Field.scalar option type ('n, 'm) invariant = ('n, 'm) filter -> int -> Field.scalar option
let invariant : type n m. (n, m) invariant = fun (Filter f) max -> let invariant : type n m. (n, m) invariant = fun (Filter f) exponent ->
let open Option.Operators in let open Option.Operators in
let order, _ = Matrix.dimensions f.input in let order, _ = Matrix.dimensions f.input in
let+ spectral, exponant = find_exponent max f.state in let+ spectral = find_spectral_radius exponent f.state in
let power p = Matrix.(f.state * p) in let power p = Matrix.(f.state * p) in
let norm p = Matrix.(p * f.input |> norm) in let norm p = Matrix.(p * f.input |> norm) in
let sum = sum order power norm exponant in let sum = sum order power norm exponent in
let bound = Field.(Vector.norm f.measure * sum / (one - spectral)) in let bound = Field.(Vector.norm f.measure * sum / (one - spectral)) in
let order = Field.of_int (Nat.to_int order) in let order = Field.of_int (Nat.to_int order) in
Field.(bound / order) Field.(bound / order)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment