From 7e95ff144383e0c716394f7ba1c23c11da2fe8e9 Mon Sep 17 00:00:00 2001
From: hra687261 <hichem.ait-el-hara@ocamlpro.com>
Date: Tue, 18 Oct 2022 11:13:27 +0200
Subject: [PATCH] [Array] fix a couple of bugs

---
 colibri2/bin/options.ml          | 13 +++--
 colibri2/stdlib/debug.ml         | 63 ++++++----------------
 colibri2/stdlib/flags.ml         | 65 +++++++++++++++++++++++
 colibri2/stdlib/flags.mli        | 15 ++++++
 colibri2/stdlib/std_sig.ml       | 19 ++++++-
 colibri2/theories/array/array.ml | 90 +++++++++++++++++++++++---------
 6 files changed, 190 insertions(+), 75 deletions(-)
 create mode 100644 colibri2/stdlib/flags.ml
 create mode 100644 colibri2/stdlib/flags.mli

diff --git a/colibri2/bin/options.ml b/colibri2/bin/options.ml
index bb0f7552c..54b90bc2d 100644
--- a/colibri2/bin/options.ml
+++ b/colibri2/bin/options.ml
@@ -48,7 +48,7 @@ let gc_opts minor_heap_size major_heap_increment space_overhead max_overhead
 
 let mk_state theories gc gc_opt bt colors time_limit size_limit input_lang
     input_mode input header_check header_licenses header_lang_version type_check
-    debug max_warn step_limit debug_flags show_steps check_status
+    debug max_warn step_limit debug_flags flags show_steps check_status
     dont_print_result learning limit_last_effort print_success negate_goal
     options =
   let last_effort_limit =
@@ -67,8 +67,11 @@ let mk_state theories gc gc_opt bt colors time_limit size_limit input_lang
                  (int_of_float (f *. float step_limit))))
   in
   List.iter
-    Colibri2_stdlib.Debug.(fun s -> set_flag (lookup_flag s))
+    Colibri2_stdlib.Flags.Debug.(fun s -> set_flag (lookup_flag s))
     debug_flags;
+  List.iter
+    Colibri2_stdlib.Flags.Solve.(fun s -> set_flag (lookup_flag s))
+    flags;
   (if debug then
    Colibri2_stdlib.Debug.(
      List.iter (fun (_, f, info, _) -> if info then set_flag f) (list_flags ())));
@@ -382,6 +385,10 @@ let state theories =
     let doc = Format.asprintf "Debug flags." in
     Arg.(value & opt_all string [] & info [ "debug-flag" ] ~docs ~doc)
   in
+  let flags =
+    let doc = Format.asprintf "Solving flags." in
+    Arg.(value & opt_all string [] & info [ "flag" ] ~docs ~doc)
+  in
   let check_status =
     let doc =
       Format.asprintf
@@ -447,6 +454,6 @@ let state theories =
     const (mk_state theories)
     $ gc $ gc_t $ bt $ colors $ time $ size $ in_lang $ in_mode $ input
     $ header_check $ header_licenses $ header_lang_version $ typing $ debug
-    $ max_warn $ step_limit $ debug_flags $ show_steps $ check_status
+    $ max_warn $ step_limit $ debug_flags $ flags $ show_steps $ check_status
     $ dont_print_result $ learning $ last_effort_limit $ print_success
     $ negate_goal $ other_options)
diff --git a/colibri2/stdlib/debug.ml b/colibri2/stdlib/debug.ml
index 8ad188f34..3bab44b25 100644
--- a/colibri2/stdlib/debug.ml
+++ b/colibri2/stdlib/debug.ml
@@ -26,62 +26,37 @@ let () =
          print_endline "Stopped by user";
          exit 1))
 
-exception UnknownFlag of string
+type flag = Flags.flag
 
-type flag = bool ref
+let _true = Flags._true
+
+include Flags.Debug
 
-let _true = ref true
 let todo = _true
-let modifiable s = not (Base.phys_equal s _true)
-let flag_table = Hashtbl.create 17
-let fst3 (flag, _, _) = flag
-let snd3 (_, info, _) = info
-let thd3 (_, _, desc) = desc
 
 let gen_register_flag (desc : (unit, unit, unit, unit, unit, unit) format6) s
     info =
-  try fst3 (Hashtbl.find flag_table s)
-  with Not_found ->
-    let flag = ref false in
-    Hashtbl.replace flag_table s (flag, info, desc);
-    flag
+  gen_register_flag ~rep:(fun flag -> (flag, info, desc)) s
 
 let register_info_flag ~desc s = gen_register_flag desc s true
 let register_flag ~desc s = gen_register_flag desc s false
-
-let list_flags () =
-  Hashtbl.fold
-    (fun s (v, info, desc) acc -> (s, v, info, desc) :: acc)
-    flag_table []
-
-let lookup_flag s =
-  try fst3 (Hashtbl.find flag_table s) with Not_found -> raise (UnknownFlag s)
+let list_flags () = map_list (fun s (f, i, d) -> (s, f, i, d))
 
 let is_info_flag s =
-  try snd3 (Hashtbl.find flag_table s) with Not_found -> raise (UnknownFlag s)
+  match lookup s with
+  | _, i, _ -> i
+  | exception Not_found -> raise (Flags.UnknownFlag s)
 
 let flag_desc s =
-  try thd3 (Hashtbl.find flag_table s) with Not_found -> raise (UnknownFlag s)
-
-let test_flag s = !s
-let test_noflag s = not !s
-
-let set_flag s =
-  assert (modifiable s);
-  s := true
-
-let unset_flag s =
-  assert (modifiable s);
-  s := false
-
-let toggle_flag s =
-  assert (modifiable s);
-  s := not !s
+  match lookup s with
+  | _, _, desc -> desc
+  | exception Not_found -> raise (Flags.UnknownFlag s)
 
 let () =
   Printexc.register_printer (fun e ->
       match e with
-      | UnknownFlag s -> Some (Format.asprintf "unknown debug flag `%s'" s)
+      | Flags.UnknownFlag s ->
+          Some (Format.asprintf "unknown debug flag `%s'" s)
       | _ -> None)
 
 let stack_trace =
@@ -223,11 +198,7 @@ module Args = struct
     in
     let list () =
       (if !opt_list_flags then
-       let list =
-         Hashtbl.fold
-           (fun s (_, info, desc) acc -> (s, info, desc) :: acc)
-           flag_table []
-       in
+       let list = map_list (fun s (_, info, desc) -> (s, info, desc)) in
        let pp fmt (p, info, desc) =
          Format.fprintf fmt "@[%s%s  @[%( %)@]@]" p
            (if info then " *" else "")
@@ -245,9 +216,7 @@ module Args = struct
 
   let opt_list_flags = ref []
   let add_flag s = opt_list_flags := s :: !opt_list_flags
-
-  let add_all_flags () =
-    Hashtbl.iter (fun s (_, info, _) -> if info then add_flag s) flag_table
+  let add_all_flags () = iter (fun s (_, info, _) -> if info then add_flag s)
 
   let remove_flag s =
     opt_list_flags :=
diff --git a/colibri2/stdlib/flags.ml b/colibri2/stdlib/flags.ml
new file mode 100644
index 000000000..0763cc51d
--- /dev/null
+++ b/colibri2/stdlib/flags.ml
@@ -0,0 +1,65 @@
+type flag = bool ref
+
+let _true = ref true
+let todo = _true
+let modifiable s = not (Base.phys_equal s _true)
+
+exception UnknownFlag of string
+
+let () =
+  Printexc.register_printer (fun e ->
+      match e with
+      | UnknownFlag s -> Some (Format.asprintf "unknown debug flag `%s'" s)
+      | _ -> None)
+
+module Make (S : sig
+  type t
+
+  val flag : t -> bool ref
+end) : Std_sig.Flags with type t = S.t = struct
+  type t = S.t
+
+  let table : (string, t) Hashtbl.t = Hashtbl.create 17
+
+  let gen_register_flag ~rep s =
+    try S.flag (Hashtbl.find table s)
+    with Not_found ->
+      let flag = ref false in
+      Hashtbl.replace table s (rep flag);
+      flag
+
+  let lookup s = Hashtbl.find table s
+
+  let lookup_flag s =
+    try S.flag (lookup s) with Not_found -> raise (UnknownFlag s)
+
+  let to_list () = Hashtbl.fold (fun s v acc -> (s, v) :: acc) table []
+  let iter f = Hashtbl.iter f table
+  let map_list f = Hashtbl.map_list f table
+  let test_flag s = !s
+  let test_noflag s = not !s
+
+  let set_flag s =
+    assert (modifiable s);
+    s := true
+
+  let unset_flag s =
+    assert (modifiable s);
+    s := false
+
+  let toggle_flag s =
+    assert (modifiable s);
+    s := not !s
+end
+
+module Debug = Make (struct
+  type t = flag * bool * (unit, unit, unit, unit, unit, unit) format6
+
+  let flag ((f, _, _) : t) = f
+end)
+
+module Solve = Make (struct
+  type t = flag * (unit, unit, unit, unit, unit, unit) format6
+
+  let flag ((f, _) : t) = f
+end)
diff --git a/colibri2/stdlib/flags.mli b/colibri2/stdlib/flags.mli
new file mode 100644
index 000000000..b129948c6
--- /dev/null
+++ b/colibri2/stdlib/flags.mli
@@ -0,0 +1,15 @@
+type flag = bool ref
+
+val _true : bool ref
+val todo : bool ref
+val modifiable : bool ref -> bool
+
+exception UnknownFlag of string
+
+module Debug :
+  Std_sig.Flags
+    with type t = flag * bool * (unit, unit, unit, unit, unit, unit) format6
+
+module Solve :
+  Std_sig.Flags
+    with type t = flag * (unit, unit, unit, unit, unit, unit) format6
diff --git a/colibri2/stdlib/std_sig.ml b/colibri2/stdlib/std_sig.ml
index c766b99c0..b4ad085b4 100644
--- a/colibri2/stdlib/std_sig.ml
+++ b/colibri2/stdlib/std_sig.ml
@@ -39,9 +39,26 @@ end
 
 (* module type Datatype = sig
  *   include OrderedHashedType
- * 
+ *
  *   module M : Map_intf.PMap with type key = t
  *   module S : Map_intf.Set with type 'a M.t = 'a M.t
  *                            and type M.key = M.key
  *   module H : Exthtbl.Hashtbl.S with type key = t
  * end *)
+
+module type Flags = sig
+  type t
+
+  val table : (string, t) Hashtbl.t
+  val gen_register_flag : rep:(bool ref -> t) -> string -> bool ref
+  val lookup : string -> t
+  val lookup_flag : string -> bool ref
+  val to_list : unit -> (string * t) list
+  val iter : (string -> t -> unit) -> unit
+  val map_list : (string -> t -> 'a) -> 'a list
+  val test_flag : 'a ref -> 'a
+  val test_noflag : bool ref -> bool
+  val set_flag : bool ref -> unit
+  val unset_flag : bool ref -> unit
+  val toggle_flag : bool ref -> unit
+end
diff --git a/colibri2/theories/array/array.ml b/colibri2/theories/array/array.ml
index 1bf368e48..f16b4f608 100644
--- a/colibri2/theories/array/array.ml
+++ b/colibri2/theories/array/array.ml
@@ -22,10 +22,26 @@
 open Colibri2_core
 open Colibri2_popop_lib
 
-let restrict_ext = ref false
-let restrict_aup = ref false
-let extended_comb = ref false
-let default_values = ref false
+let restrict_ext =
+  Colibri2_stdlib.Flags.Solve.gen_register_flag
+    ~rep:(fun flag -> (flag, "Restrict the extensionality rule"))
+    "Array.res-ext"
+
+let restrict_aup =
+  Colibri2_stdlib.Flags.Solve.gen_register_flag
+    ~rep:(fun flag -> (flag, "Restrict the ⇑ rule"))
+    "res-aup"
+
+let extended_comb =
+  Colibri2_stdlib.Flags.Solve.gen_register_flag
+    ~rep:(fun flag -> (flag, "Extended combinators"))
+    "ext-comb"
+
+let default_values =
+  Colibri2_stdlib.Flags.Solve.gen_register_flag
+    ~rep:(fun flag -> (flag, "Default values"))
+    "def-values"
+
 let db = Datastructure.Push.create Ground.pp "Array.db"
 let debug = Debug.register_info_flag ~desc:"For array theory" "Array"
 let stats = Debug.register_stats_int "Array.rule"
@@ -84,30 +100,48 @@ module Builtin = struct
             app2 env s array_diff
         | Dolmen_loop.Typer.T.Id { ns = Term; name = Simple "array_const" } ->
             app1 env s array_const
+        | Dolmen_loop.Typer.T.Id
+            { ns = Term; name = Simple "array_default_index" } ->
+            app1 env s array_default_index
+        | Dolmen_loop.Typer.T.Id
+            { ns = Term; name = Simple "array_default_value" } ->
+            app1 env s array_default_value
         | _ -> `Not_found)
 end
 
 (* Helper functions  *)
 let is_array env n =
-  Ground.Ty.S.exists
-    (function { app = { builtin = Expr.Array; _ }; _ } -> true | _ -> false)
-    (Ground.tys env n)
+  let res =
+    Ground.Ty.S.exists
+      (function { app = { builtin = Expr.Array; _ }; _ } -> true | _ -> false)
+      (Ground.tys env n)
+  in
+  Debug.dprintf3 debug "is_array %a: %b" Node.pp n res;
+  res
 
 let is_foreign env n =
-  match Egraph.get_dom env Foreign_dom.key n with
-  | Some IsForeign -> true
-  | _ -> false
+  let res =
+    match Egraph.get_dom env Foreign_dom.key n with
+    | Some IsForeign -> true
+    | _ -> false
+  in
+  Debug.dprintf3 debug "is_foreign %a: %b" Node.pp n res;
+  res
 
 let is_nonlinear env n =
-  match Egraph.get_dom env Linearity_dom.key n with
-  | Some NonLinear -> true
-  | _ -> false
+  let res =
+    match Egraph.get_dom env Linearity_dom.key n with
+    | Some NonLinear -> true
+    | _ -> false
+  in
+  Debug.dprintf3 debug "is_nonlinear %a: %b" Node.pp n res;
+  res
 
 let apply_cst cst args = Expr.Term.apply_cst cst [] args
 let apply_diff a b = apply_cst Builtin.array_diff [ a; b ]
 let apply_const v = apply_cst Builtin.array_const [ v ]
 let apply_def_index a = apply_cst Builtin.array_default_index [ a ]
-let apply_def_value a = apply_cst Builtin.array_default_index [ a ]
+let apply_def_value a = apply_cst Builtin.array_default_value [ a ]
 
 (* Generalized, Efficient Array Decision Procedures. de Moura, Bjorner *)
 module Theory = struct
@@ -131,13 +165,20 @@ module Theory = struct
   let ta = term_of_var va
   let tb = term_of_var vb
 
-  let distinct_term =
-    Expr.Term._or
-      [
-        Expr.Term.eq ta tb;
-        (let diff = apply_diff ta tb in
-         Expr.Term.neq (Expr.Term.select ta diff) (Expr.Term.select tb diff));
-      ]
+  let distinct_term_node ~subst env ta tb =
+    let diff_term = apply_diff ta tb in
+    let diff_eq = Expr.Term.eq diff_term (apply_diff tb ta) in
+    let diff_eq_node = convert ~subst env diff_eq in
+    Egraph.register env diff_eq_node;
+    Boolean.set_true env diff_eq_node;
+    convert ~subst env
+    @@ Expr.Term._or
+         [
+           Expr.Term.eq ta tb;
+           Expr.Term.neq
+             (Expr.Term.select ta diff_term)
+             (Expr.Term.select tb diff_term);
+         ]
 
   (* ⇓: a ≡ b[i <- v], a[j] |> (i = j) \/ a[j] = b[j] *)
   let adown_pattern, adown_run =
@@ -214,7 +255,7 @@ module Theory = struct
         if
           (is_registered && is_false) || (is_foreign env an && is_foreign env bn)
         then (
-          let v = convert ~subst env distinct_term in
+          let v = distinct_term_node ~subst env ta tb in
           Egraph.register env v;
           Boolean.set_true env v)
       in
@@ -275,12 +316,13 @@ module Theory = struct
     (const_pattern, const_run)
 
   let init env =
-    let l = [ (const_pattern, const_run); (adown_pattern, adown_run) ] in
+    let l = [ (adown_pattern, adown_run) ] in
     let l = if !restrict_ext then (rext_pattern, rext_run) :: l else l in
     let l =
       if !restrict_aup then (raup_pattern, raup_run) :: l
       else (aup_pattern, aup_run) :: l
     in
+    let l = if !extended_comb then (const_pattern, const_run) :: l else l in
     List.iter (fun (p, r) -> InvertedPath.add_callback env p r) l
 
   let new_array env f =
@@ -297,7 +339,7 @@ module Theory = struct
               }
           in
           Debug.dprintf2 debug "Found ext with %a" Ground.Subst.pp subst;
-          let n = convert ~subst env distinct_term in
+          let n = distinct_term_node ~subst env ta tb in
           Egraph.register env n;
           Boolean.set_true env n);
       Datastructure.Push.push db env f);
-- 
GitLab