From 84d0102ccfc5b33577cea27913bb0af1fcaf1f97 Mon Sep 17 00:00:00 2001
From: Maxime Jacquemin <maxime.jacquemin@cea.fr>
Date: Wed, 10 Jan 2024 12:53:33 +0100
Subject: [PATCH] [Kernel] Field as datatype and improved complexity for finite

---
 src/kernel_services/analysis/filter/field.ml  |  5 +--
 src/kernel_services/analysis/filter/finite.ml | 34 +++++++++++--------
 .../analysis/filter/finite.mli                |  8 +++--
 src/kernel_services/analysis/filter/linear.ml | 14 ++++----
 4 files changed, 33 insertions(+), 28 deletions(-)

diff --git a/src/kernel_services/analysis/filter/field.ml b/src/kernel_services/analysis/filter/field.ml
index b65a11387b0..fd00c817945 100644
--- a/src/kernel_services/analysis/filter/field.ml
+++ b/src/kernel_services/analysis/filter/field.ml
@@ -23,12 +23,9 @@
 module type S = sig
 
   module Types : sig type scalar end
+  include Datatype.S_with_collections with type t = Types.scalar
   open Types
 
-  val compare : scalar -> scalar -> int
-  val equal : scalar -> scalar -> bool
-  val pretty : Format.formatter -> scalar -> unit
-
   val zero     : scalar
   val one      : scalar
   val infinity : scalar
diff --git a/src/kernel_services/analysis/filter/finite.ml b/src/kernel_services/analysis/filter/finite.ml
index ae751948a48..06b52b28a72 100644
--- a/src/kernel_services/analysis/filter/finite.ml
+++ b/src/kernel_services/analysis/filter/finite.ml
@@ -23,33 +23,37 @@
 open Nat.Types
 
 module Types = struct
-  type 'n finite = First : 'n succ finite | Next : 'n finite -> 'n succ finite
+  type 'n formal = First : 'n succ formal | Next : 'n formal -> 'n succ formal
+  type 'n finite = { value : int ; formal : 'n formal }
 end
 
 open Types
 
-let rec weaken : type n. n finite -> n succ finite =
-  function First -> First | Next n -> Next (weaken n)
+let first = { value = 0 ; formal = First }
+let next { value ; formal } = { value = value + 1 ; formal = Next formal }
+let to_int { value ; _ } = value
+let ( == ) l r = l.value = r.value
 
 let rec of_int : type n. n succ nat -> int -> n succ finite = fun limit n ->
   match limit with
-  | Succ Zero -> First
-  | Succ (Succ _) when n <= 0 -> First
-  | Succ (Succ limit) -> Next (of_int (Succ limit) (n - 1))
-
-let to_int finite =
-  let rec aux : type n. int -> n finite -> int = fun acc ->
-    function First -> acc | Next n -> aux (acc + 1) n
-  in aux 0 finite
+  | Succ Zero -> first
+  | Succ Succ _ when n <= 0 -> first
+  | Succ Succ limit -> next (of_int (Succ limit) (n - 1))
 
 let rec of_nat : type n. n succ nat -> n succ finite = function
-  | Succ Zero -> First
-  | Succ (Succ n) -> Next (of_nat (Succ n))
+  | Succ Zero -> first
+  | Succ (Succ n) -> next (of_nat (Succ n))
+
+(* We use Obj.magic here to avoid the O(n) long but trivial proof *)
+let weaken : type n. n finite -> n succ finite =
+  fun { value ; formal } -> { value ; formal = Obj.magic formal }
 
+(* Non tail-rec to perform the computation in the natural order *)
 let rec fold f n acc =
   match n with
-  | First -> f First acc
-  | Next n -> f (Next n) (fold f (weaken n) acc)
+  | { formal = First ; _ } as n -> f n acc
+  | { formal = Next formal ; value } as n ->
+    f n (fold f (weaken { value = value - 1 ; formal }) acc)
 
 let for_each (type n) acc (n : n nat) (f : n finite -> 'a -> 'a) =
   match n with Zero -> acc | Succ _ as n -> fold f (of_nat n) acc
diff --git a/src/kernel_services/analysis/filter/finite.mli b/src/kernel_services/analysis/filter/finite.mli
index 46485ca4fc1..64f56b9c068 100644
--- a/src/kernel_services/analysis/filter/finite.mli
+++ b/src/kernel_services/analysis/filter/finite.mli
@@ -22,12 +22,14 @@
 
 open Nat.Types
 
-module Types : sig
-  type 'n finite = First : 'n succ finite | Next : 'n finite -> 'n succ finite
-end
+module Types : sig type 'n finite end
 
 open Types
 
+val first : 'n succ finite
+val next : 'n finite -> 'n succ finite
+val ( == ) : 'n finite -> 'n finite -> bool
+
 val weaken : 'n finite -> 'n succ finite
 val of_int : 'n succ nat -> int -> 'n succ finite
 val to_int : 'n finite -> int
diff --git a/src/kernel_services/analysis/filter/linear.ml b/src/kernel_services/analysis/filter/linear.ml
index 3368505e2da..50a828d9d91 100644
--- a/src/kernel_services/analysis/filter/linear.ml
+++ b/src/kernel_services/analysis/filter/linear.ml
@@ -47,8 +47,9 @@ module Space (Field : Field.S) = struct
       Parray.pretty ~pp_sep Field.pretty fmt data
 
     let init size f =
-      let index n = Finite.of_int size n |> f in
-      let data = Parray.init (Nat.to_int size) index in
+      let data = Parray.init (Nat.to_int size) (fun _ -> Field.zero) in
+      let set i data = Parray.set data (Finite.to_int i) (f i) in
+      let data = Finite.for_each data size set in
       M { data ; rows = size ; cols = Nat.one }
 
     let size (type n) (M vector : n vector) : n nat = vector.rows
@@ -92,14 +93,15 @@ module Space (Field : Field.S) = struct
 
     let pretty (type n m) fmt (M m : (n, m) matrix) =
       Finite.for_each () m.rows @@ fun i () ->
-      (match i with First -> () | Next _ -> Format.fprintf fmt "@ ") ;
+      (if Finite.(i == first) then () else Format.fprintf fmt "@ ") ;
       Format.fprintf fmt "@[<h>%a@]" Vector.pretty (row i (M m))
 
     let init n m init =
       let rows = Nat.to_int n and cols = Nat.to_int m in
-      let row i = Finite.of_int n (i  /  cols) in
-      let col i = Finite.of_int m (i mod cols) in
-      let data = Parray.init (rows * cols) @@ fun i -> init (row i) (col i) in
+      let t = Parray.init (rows * cols) (fun _ -> Field.zero) in
+      let index i j = Finite.to_int i * cols + Finite.to_int j in
+      let set i j data = Parray.set data (index i j) (init i j) in
+      let data = Finite.(for_each t n @@ fun i t -> for_each t m (set i)) in
       M { data ; rows = n ; cols = m }
 
     let zero n m = init n m (fun _ _ -> Field.zero)
-- 
GitLab