(* Web-export residue kept as a comment so that the file parses:
   "Skip to content" / "Snippets Groups Projects" / "common.ml 18.62 KiB" *)
(*************************************************************************)
(*  This file is part of Colibri2.                                       *)
(*                                                                       *)
(*  Copyright (C) 2014-2021                                              *)
(*    CEA   (Commissariat à l'énergie atomique et aux énergies           *)
(*           alternatives)                                               *)
(*    OCamlPro                                                           *)
(*                                                                       *)
(*  you can redistribute it and/or modify it under the terms of the GNU  *)
(*  Lesser General Public License as published by the Free Software      *)
(*  Foundation, version 2.1.                                             *)
(*                                                                       *)
(*  It is distributed in the hope that it will be useful,                *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of       *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the        *)
(*  GNU Lesser General Public License for more details.                  *)
(*                                                                       *)
(*  See the GNU Lesser General Public License version 2.1                *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).           *)
(*************************************************************************)

open Colibri2_core
open Colibri2_popop_lib
open Popop_stdlib
module HT = Datastructure.Hashtbl (DInt)

(** Boolean option registered as ["Array.no-wegraph"] and exposed as the
    command-line flag [--no-wegraph]. Per its help text, setting it disables
    the array theory's weak equivalency graph. *)
let no_wegraph =
  let flag_info =
    Cmdliner.Arg.info [ "no-wegraph" ]
      ~doc:"Don't use the array theory's weak equivalency graph"
  in
  Options.register ~pp:Fmt.bool "Array.no-wegraph"
    Cmdliner.Arg.(value (flag flag_info))

(** Boolean option registered as ["Array.no-res-ext"] and exposed as the
    command-line flag [--no-res-ext]. Per its help text, setting it lifts the
    restriction on the array extensionality rule. *)
let no_res_ext =
  let flag_info =
    Cmdliner.Arg.info [ "no-res-ext" ]
      ~doc:"Don't restrict the array extensionality rule"
  in
  Options.register ~pp:Fmt.bool "Array.no-res-ext"
    Cmdliner.Arg.(value (flag flag_info))

(** Boolean option exposed as the command-line flag [--no-res-aup]. Per its
    help text, setting it lifts the restriction on the ⇑ (select over store)
    rule.

    NOTE(review): the option is registered under the name ["Array.res-aup"]
    (without the "no-" prefix carried by the flag and by the sibling
    ["Array.no-res-ext"]); confirm whether this is intentional. *)
let no_res_aup =
  let flag_info =
    Cmdliner.Arg.info [ "no-res-aup" ]
      ~doc:"Don't restrict the array's ⇑ (select over store) rule"
  in
  Options.register ~pp:Fmt.bool "Array.res-aup"
    Cmdliner.Arg.(value (flag flag_info))

(** Boolean option exposed as the command-line flag [--array-ext-comb]
    ("Extended combinators").

    NOTE(review): the option is registered under the name ["Array.res-comb"],
    which does not match the flag name; confirm whether this is intentional. *)
let extended_comb =
  let flag_info =
    Cmdliner.Arg.info [ "array-ext-comb" ] ~doc:"Extended combinators"
  in
  Options.register ~pp:Fmt.bool "Array.res-comb"
    Cmdliner.Arg.(value (flag flag_info))

(** Boolean option registered as ["Array.blast-rule"] and exposed as the
    command-line flag [--array-blast-rule]. Per its help text, setting it
    enables the array blast rule when it is suitable. *)
let blast_rule =
  let flag_info =
    Cmdliner.Arg.info [ "array-blast-rule" ]
      ~doc:"Use the array blast rule when it is suitable"
  in
  Options.register ~pp:Fmt.bool "Array.blast-rule"
    Cmdliner.Arg.(value (flag flag_info))

(** Boolean option registered as ["Array.def-values"] and exposed as the
    command-line flag [--array-def-values]. Per its help text, setting it
    enables the inference rules for default values. *)
let default_values =
  let flag_info =
    Cmdliner.Arg.info [ "array-def-values" ]
      ~doc:"Use inference rules for default values"
  in
  Options.register ~pp:Fmt.bool "Array.def-values"
    Cmdliner.Arg.(value (flag flag_info))

(** Debug flag named ["Array"]; used by this file's [Debug.dprintf] calls. *)
let debug =
  let desc = "Debugging messages of the array theory" in
  Debug.register_info_flag ~desc "Array"
(** [convert ~subst env] forwards to
    [Colibri2_theories_quantifiers.Subst.convert], with the empty ground
    substitution as [~subst_old] and [subst] as [~subst_new]. The remaining
    arguments of the underlying function are left to the caller.

    NOTE(review): the meaning of the [None] argument is not visible in this
    file; check the signature of [Subst.convert]. *)
let convert ~subst env =
  Colibri2_theories_quantifiers.Subst.convert ~subst_old:Ground.Subst.empty
    ~subst_new:subst env None

(* Exceptions raised by the array theory helpers. Human-readable messages for
   all of them are installed with [Printexc.register_printer] in this file. *)

(** The node has no array ground type (raised by [get_array_gty]). *)
exception Not_An_Array of Node.t

(** The ground type is not an array type (raised by [array_gty_args]). *)
exception Not_An_Array_gty of Ground.Ty.t

(** The type is not an array type (raised by [array_ty_args]). *)
exception Not_An_Array_ty of Expr.ty

(** The node has no ground type at all (raised by [get_node_ty]). *)
exception Type_Not_Set of Node.t

(** [NoIdFound (dom_name, n)]: node [n] has no id in the id domain named
    [dom_name] (raised by the [get_id] function of [MkIdDom]). *)
exception NoIdFound of string * Node.t

(** [Not_a_neighbour (kn, rn, k)]: [rn] was expected to be a neighbour of [kn]
    through [k]. Not raised in the part of the file visible here. *)
exception Not_a_neighbour of (Node.t * Node.t * Node.t)

(** [Empty_neighbour_set (kn, rn, k)]: like [Not_a_neighbour], but the
    neighbour set of [kn] is empty. Not raised in the part of the file visible
    here. *)
exception Empty_neighbour_set of (Node.t * Node.t * Node.t)

(* Install a printer giving a readable message for each exception of the array
   theory; any other exception is left to the remaining printers. *)
let () =
  let message_of_exn exn =
    match exn with
    | Not_An_Array node -> Some (Fmt.str "%a is not an array!" Node.pp node)
    | Not_An_Array_gty gty ->
        Some (Fmt.str "%a is not an array ground type!" Ground.Ty.pp gty)
    | Not_An_Array_ty ty ->
        Some (Fmt.str "%a is not an array type!" Expr.Ty.pp ty)
    | Type_Not_Set node ->
        Some (Fmt.str "the type of the node %a was not set!" Node.pp node)
    | NoIdFound (dom_name, node) ->
        Some (Fmt.str "get_id(%s) of %a: No Id found!" dom_name Node.pp node)
    | Not_a_neighbour (kn, rn, k) ->
        let msg =
          Fmt.str "%a was expected to be a neighbour of %a through %a" Node.pp
            rn Node.pp kn Node.pp k
        in
        Some msg
    | Empty_neighbour_set (kn, rn, k) ->
        let msg =
          Fmt.str
            "%a was expected to be a neighbour of %a through %a, but %a's \
             neighbour set is empty"
            Node.pp rn Node.pp kn Node.pp k Node.pp kn
        in
        Some msg
    | _ -> None
  in
  Printexc.register_printer message_of_exn

(** Shared type variables, types, term variables and terms used to state the
    types of the array builtins and to build terms over arrays. Every binding
    is created once, when the module is initialised. *)
module STV = struct
  (* Type variables. *)
  let ind_ty_var = Expr.Ty.Var.mk "ind_ty"
  let val_ty_var = Expr.Ty.Var.mk "val_ty"
  let alpha_ty_var = Expr.Ty.Var.mk "alpha"
  let a_ty_var = Expr.Ty.Var.mk "a"
  let b_ty_var = Expr.Ty.Var.mk "b"
  let c_ty_var = Expr.Ty.Var.mk "c"

  (* The types made of each of the variables above. *)
  let ind_ty = Expr.Ty.of_var ind_ty_var
  let val_ty = Expr.Ty.of_var val_ty_var
  let a_ty = Expr.Ty.of_var a_ty_var
  let b_ty = Expr.Ty.of_var b_ty_var
  let c_ty = Expr.Ty.of_var c_ty_var
  let alpha_ty = Expr.Ty.of_var alpha_ty_var

  (* Array types: the first argument is the index type, the second the value
     type (see [array_ty_args]). *)
  let array_ty = Expr.Ty.array ind_ty val_ty
  let array_ty_ab = Expr.Ty.array a_ty b_ty
  let array_ty_ac = Expr.Ty.array a_ty c_ty
  let array_ty_alpha = Expr.Ty.array ind_ty alpha_ty

  (* Term variable constructors: a variable of type [ind_ty], [val_ty] or
     [array_ty] respectively. *)
  let term_of_var = Expr.Term.of_var
  let mk_index_var name = Expr.Term.Var.mk name ind_ty
  let mk_value_var name = Expr.Term.Var.mk name val_ty
  let mk_array_var name = Expr.Term.Var.mk name array_ty

  (* Index variables [i], [j], [k] and value variable [v] ... *)
  let vi = mk_index_var "i"
  let vj = mk_index_var "j"
  let vk = mk_index_var "k"
  let vv = mk_value_var "v"

  (* ... and the corresponding terms. *)
  let ti = term_of_var vi
  let tj = term_of_var vj
  let tk = term_of_var vk
  let tv = term_of_var vv

  (* Array variables [a], [b] and the corresponding terms. *)
  let va = mk_array_var "a"
  let vb = mk_array_var "b"
  let ta = term_of_var va
  let tb = term_of_var vb
end

(** [replicate n v] is the list made of [n] copies of [v].
    @raise Invalid_argument if [n < 0] (from [List.init]). *)
let replicate n v = List.init n (Fun.const v)
(* Short aliases for the Dolmen term constructors used throughout the array
   theory: array store, array select, and application of a term constant to
   type and term arguments. *)
let mk_store_term = Expr.Term.Array.store
let mk_select_term = Expr.Term.Array.select
let apply_cst = Expr.Term.apply_cst

(** [array_ty_args ty] is the pair [(index type, value type)] of the array
    type [ty].
    @raise Not_An_Array_ty
      if [ty] is not the [Array] type constructor applied to exactly two
      arguments. *)
let array_ty_args (ty : Expr.ty) : Expr.ty * Expr.ty =
  match ty with
  | { ty_descr = TyApp ({ builtin = Expr.Array; _ }, [ ind_ty; val_ty ]); _ } ->
      (ind_ty, val_ty)
  | _ -> raise (Not_An_Array_ty ty)

(** [array_gty_args gty] is the pair [(index type, value type)] of the ground
    array type [gty].
    @raise Not_An_Array_gty
      if [gty] is not the [Array] type constructor applied to exactly two
      arguments. *)
let array_gty_args (gty : Ground.Ty.t) : Ground.Ty.t * Ground.Ty.t =
  match gty with
  | { app = { builtin = Expr.Array; _ }; args = [ ind_gty; val_gty ] } ->
      (ind_gty, val_gty)
  | _ -> raise (Not_An_Array_gty gty)

(** [get_node_ty env n] is the first ground type in the set of types attached
    to node [n], in the order given by [Ground.Ty.S.elements].
    @raise Type_Not_Set if no type is attached to [n]. *)
let get_node_ty env n =
  let tys = Ground.tys env n |> Ground.Ty.S.elements in
  match tys with [] -> raise (Type_Not_Set n) | first :: _ -> first

(** [get_array_gty env n] is the first array ground type among the types
    attached to node [n].
    @raise Not_An_Array if none of the types of [n] is an array type. *)
let get_array_gty env n =
  let is_array_gty : Ground.Ty.t -> bool = function
    | { app = { builtin = Expr.Array; _ }; _ } -> true
    | _ -> false
  in
  (* The whole lookup is the scrutinee so that any [Not_found] it raises is
     turned into [Not_An_Array]. *)
  match List.find is_array_gty (Ground.Ty.S.elements (Ground.tys env n)) with
  | gty -> gty
  | exception Not_found -> raise (Not_An_Array n)

(** [get_array_gty_args env n] is the pair [(index type, value type)] of the
    array ground type of node [n], obtained with [get_array_gty] and
    [array_gty_args]; it raises the exceptions those two raise. *)
let get_array_gty_args env n = array_gty_args (get_array_gty env n)

(** [add_array_gty env n ind_gty val_gty] attaches to node [n] the ground
    array type whose arguments are [ind_gty] and [val_gty]. *)
let add_array_gty env n ind_gty val_gty =
  Ground.add_ty env n (Ground.Ty.array ind_gty val_gty)

(** Array builtins added by Colibri2 on top of Dolmen's [select]/[store]:
    their builtin tags, their typed constants, helpers applying them to terms,
    and their registration in the typer under the [colibri2_array_*] names. *)
module Builtin = struct
  (** Additional array Builtins *)
  type _ Expr.t +=
    | Array_diff
          (** [Array_diff: 'a 'b. ('a, 'b) Array -> ('a, 'b) Array -> 'a] *)
    | Array_const  (** [Array_const: 'a 'b. 'b -> ('a, 'b) Array] *)
    | Array_map
          (** [Array_map: 'a 'b 'c. ('b -> ... -> 'b -> 'c) -> ('a, 'b) Array
              -> ... -> ('a, 'b) Array -> ('a, 'c) Array] *)
    | Array_default_index
          (** [Array_default_index: 'a 'b. ('a, 'b) Array -> 'a] *)
    | Array_default_value
          (** [Array_default_value: 'a 'b. ('a, 'b) Array -> 'b] *)

  let array_diff : Dolmen_std.Expr.term_cst =
    Expr.Id.mk ~name:"colibri2_array_diff" ~builtin:Array_diff
      (Dolmen_std.Path.global "colibri2_array_diff")
      (Expr.Ty.pi
         [ STV.a_ty_var; STV.b_ty_var ]
         (Expr.Ty.arrow [ STV.array_ty_ab; STV.array_ty_ab ] STV.a_ty))

  let array_const : Dolmen_std.Expr.term_cst =
    Expr.Id.mk ~name:"colibri2_array_const" ~builtin:Array_const
      (Dolmen_std.Path.global "colibri2_array_const")
      (Expr.Ty.pi
         [ STV.a_ty_var; STV.b_ty_var ]
         (Expr.Ty.arrow [ STV.b_ty ] STV.array_ty_ab))

  (** [array_map i] is the map constant taking a function of [i] arguments
      followed by [i] arrays. The constant itself is cached per arity:
      [Expr.Id.mk] builds a new identifier on each call, so building it anew
      every time would make two requests for the same arity return different
      symbols. *)
  let array_map : int -> Dolmen_std.Expr.term_cst =
    let cache = DInt.H.create 13 in
    fun i ->
      match DInt.H.find cache i with
      | cst -> cst
      | exception Not_found ->
          let ty =
            Expr.Ty.arrow
              (Expr.Ty.arrow (replicate i STV.b_ty) STV.c_ty
              :: replicate i STV.array_ty_ab)
              STV.array_ty_ac
          in
          let cst =
            Expr.Id.mk ~name:"colibri2_array_map" ~builtin:Array_map
              (Dolmen_std.Path.global "colibri2_array_map")
              (Expr.Ty.pi [ STV.a_ty_var; STV.b_ty_var; STV.c_ty_var ] ty)
          in
          DInt.H.add cache i cst;
          cst

  let array_default_index : Dolmen_std.Expr.term_cst =
    Expr.Id.mk ~name:"colibri2_array_default_index" ~builtin:Array_default_index
      (Dolmen_std.Path.global "colibri2_array_default_index")
      (Expr.Ty.pi
         [ STV.a_ty_var; STV.b_ty_var ]
         (Expr.Ty.arrow [ STV.array_ty_ab ] STV.a_ty))

  let array_default_value : Dolmen_std.Expr.term_cst =
    Expr.Id.mk ~name:"colibri2_array_default_value" ~builtin:Array_default_value
      (Dolmen_std.Path.global "colibri2_array_default_value")
      (Expr.Ty.pi
         [ STV.a_ty_var; STV.b_ty_var ]
         (Expr.Ty.arrow [ STV.array_ty_ab ] STV.b_ty))

  (** [apply_array_diff a b] applies [array_diff] to [a] and [b], instantiated
      with the index and value types of [a].
      @raise Not_An_Array_ty if the type of [a] is not an array type. *)
  let apply_array_diff a b =
    let ind_ty, val_ty = array_ty_args a.Expr.term_ty in
    apply_cst array_diff [ ind_ty; val_ty ] [ a; b ]

  (** [apply_array_const v] applies [array_const] to [v]. The value type is
      the type of [v]; the index type is a fresh wildcard type variable. *)
  let apply_array_const v =
    apply_cst array_const
      [ Expr.Ty.of_var (Expr.Ty.Var.wildcard ()); v.Expr.term_ty ]
      [ v ]

  (** [apply_array_def_index a] applies [array_default_index] to [a].
      @raise Not_An_Array_ty if the type of [a] is not an array type. *)
  let apply_array_def_index a =
    let ind_ty, val_ty = array_ty_args a.Expr.term_ty in
    apply_cst array_default_index [ ind_ty; val_ty ] [ a ]

  (** [apply_array_def_value a] applies [array_default_value] to [a].
      @raise Not_An_Array_ty if the type of [a] is not an array type. *)
  let apply_array_def_value a =
    let ind_ty, val_ty = array_ty_args a.Expr.term_ty in
    apply_cst array_default_value [ ind_ty; val_ty ] [ a ]

  (** [apply_array_map f_arity f_term bitl] applies [array_map f_arity] to the
      function term [f_term] followed by the arrays [bitl]. The type arguments
      are the index and value types of the first array and the return type of
      [f_term].
      @raise Failure
        if [bitl] is empty or the type of [f_term] is not an arrow type.
      @raise Not_An_Array_ty
        if the type of the first element of [bitl] is not an array type. *)
  let apply_array_map f_arity f_term bitl =
    match (bitl, f_term) with
    | h :: _, Expr.{ term_ty = { ty_descr = Arrow (_, ret_ty); _ }; _ } ->
        let bi_ind_ty, bi_val_ty = array_ty_args h.Expr.term_ty in
        apply_cst (array_map f_arity)
          [ bi_ind_ty; bi_val_ty; ret_ty ]
          (f_term :: bitl)
    | _, _ ->
        failwith "array_map needs to be applied to a function and n > 0 arrays"

  (* Make the typer recognise the [colibri2_array_*] symbols and build the
     corresponding applications. *)
  let () =
    (* Unary builtin whose type arguments are the index and value types of its
       array argument. *)
    let app1 env s f =
      `Term
        (Dolmen_type.Base.term_app1
           (module Dolmen_loop.Typer.T)
           env s
           (fun a ->
             let ind_ty, val_ty = array_ty_args a.term_ty in
             apply_cst f [ ind_ty; val_ty ] [ a ]))
    in
    Expr.add_builtins (fun env s ->
        match s with
        | Dolmen_loop.Typer.T.Id
            { ns = Term; name = Simple "colibri2_array_diff" } ->
            `Term
              (Dolmen_type.Base.term_app2
                 (module Dolmen_loop.Typer.T)
                 env s
                 (fun a b -> apply_array_diff a b))
        | Dolmen_loop.Typer.T.Id
            { ns = Term; name = Simple "colibri2_array_const" } ->
            (* [array_const] is quantified over two type variables (index and
               value type), so both must be supplied; [apply_array_const] does
               so, using a wildcard for the index type. *)
            `Term
              (Dolmen_type.Base.term_app1
                 (module Dolmen_loop.Typer.T)
                 env s
                 (fun a -> apply_array_const a))
        | Dolmen_loop.Typer.T.Id
            { ns = Term; name = Simple "colibri2_array_default_index" } ->
            app1 env s array_default_index
        | Dolmen_loop.Typer.T.Id
            { ns = Term; name = Simple "colibri2_array_default_value" } ->
            app1 env s array_default_value
        | Dolmen_loop.Typer.T.Id
            { ns = Term; name = Simple "colibri2_array_map" } ->
            `Term
              (Dolmen_type.Base.term_app_list
                 (module Dolmen_loop.Typer.T)
                 env s
                 (function
                   | f_term :: t -> apply_array_map (List.length t) f_term t
                   | _ ->
                       failwith
                         "array_map needs to be applied to a function and n > \
                          0 arrays"))
        | _ -> `Not_found)
end

(** [sem_to_node s] is the node of the ground term [s], obtained with
    [Ground.index] followed by [Ground.node]. *)
let sem_to_node s = Ground.node (Ground.index s)

(** [mk_subst term_l ty_l] is the ground substitution whose term bindings are
    given by the association list [term_l] and whose type bindings are given
    by the association list [ty_l]. *)
let mk_subst term_l ty_l =
  let open Ground.Subst in
  let term = Expr.Term.Var.M.of_list term_l in
  let ty = Expr.Ty.Var.M.of_list ty_l in
  { term; ty }

(** [ground_apply env cstr tyl nl] builds the ground application of the
    constant [cstr] to the type arguments [tyl] and the argument nodes [nl],
    and returns its node. The node is not registered here. *)
let ground_apply env cstr tyl nl =
  sem_to_node (Ground.apply env cstr tyl (IArray.of_list nl))

(* Node constructors built on [ground_apply]. None of them registers the
   resulting node. *)

(** [mk_or env l] is the node of the disjunction of the nodes [l]. *)
let mk_or env l = ground_apply env Expr.Term.Const.or_ [] l

(** [mk_and env l] is the node of the conjunction of the nodes [l]. *)
let mk_and env l = ground_apply env Expr.Term.Const.and_ [] l

(** [mk_distinct env l gty] is the node of [distinct] applied to the nodes
    [l], with [gty] as its single type argument. *)
let mk_distinct env l gty =
  ground_apply env (Expr.Term.Const.distinct (List.length l)) [ gty ] l

(** [mk_select env a k ind_gty val_gty] is the node of [select a k] for an
    array with index type [ind_gty] and value type [val_gty]. *)
let mk_select env a k ind_gty val_gty =
  ground_apply env Expr.Term.Const.Array.select [ ind_gty; val_gty ] [ a; k ]

(** [mk_store env a k v ind_gty val_gty] is the node of [store a k v] for an
    array with index type [ind_gty] and value type [val_gty]. *)
let mk_store env a k v ind_gty val_gty =
  ground_apply env Expr.Term.Const.Array.store [ ind_gty; val_gty ] [ a; k; v ]

(** [mk_array_diff env a b ind_gty val_gty] is a registered node for the
    [array_diff] of [a] and [b]. If [array_diff a b] is already registered it
    is returned; otherwise, if [array_diff b a] is already registered, that
    one is returned; otherwise [array_diff a b] is registered and returned. *)
let mk_array_diff env a b ind_gty val_gty =
  let diff_of x y =
    ground_apply env Builtin.array_diff [ ind_gty; val_gty ] [ x; y ]
  in
  let diff_ab = diff_of a b in
  match Egraph.is_registered env diff_ab with
  | true -> diff_ab
  | false -> (
      (* The symmetric application is only built when it is needed. *)
      let diff_ba = diff_of b a in
      match Egraph.is_registered env diff_ba with
      | true -> diff_ba
      | false ->
          Egraph.register env diff_ab;
          diff_ab)

(** [mk_distinct_arrays env a b ind_gty val_gty] is the node of
    [distinct (select a d) (select b d)], where [d] is the node returned by
    [mk_array_diff] for [a] and [b].

    The two arguments of [distinct] are results of [select], so the type
    argument given to [distinct] is the value type [val_gty], not the array
    type. *)
let mk_distinct_arrays env a b ind_gty val_gty =
  let diffn = mk_array_diff env a b ind_gty val_gty in
  mk_distinct env
    [
      mk_select env a diffn ind_gty val_gty;
      mk_select env b diffn ind_gty val_gty;
    ]
    val_gty

(** [mk_array_const env v val_gty] is the node of [array_const] applied to
    [v], with [val_gty] as value type. The index type is obtained by
    converting a fresh wildcard type variable with the empty substitution.

    NOTE(review): how [Ground.Ty.convert] treats a type variable that the
    substitution does not bind is not visible here; confirm this cannot
    fail. *)
let mk_array_const env v val_gty =
  let wildcard = Expr.Ty.of_var (Expr.Ty.Var.wildcard ()) in
  let ind_gty = Ground.Ty.convert Ground.Subst.empty.ty wildcard in
  ground_apply env Builtin.array_const [ ind_gty; val_gty ] [ v ]

(** [distinct_arrays_term a b] is the term
    [distinct (select a d) (select b d)] where [d] is [array_diff a b]. *)
let distinct_arrays_term a b =
  let witness = Builtin.apply_array_diff a b in
  let sel_a = mk_select_term a witness in
  let sel_b = mk_select_term b witness in
  Expr.Term.distinct [ sel_a; sel_b ]

(** [do_mk_eq env a b] builds the equality node between [a] and [b],
    registers it and sets it to true. *)
let do_mk_eq env a b =
  let eq_node = Equality.equality env [ a; b ] in
  let () = Egraph.register env eq_node in
  Boolean.set_true env eq_node

(** [do_mk_eq_if_neq env a b] does nothing when [a] and [b] are already equal
    in [env]; otherwise it behaves like [do_mk_eq env a b]. *)
let do_mk_eq_if_neq env a b =
  if Egraph.is_equal env a b then () else do_mk_eq env a b

(** Interface of a table stored in the solver environment: every operation
    takes the environment ([Egraph.wt]) in which the table lives. The
    implementations in this file ([MkIHT], [SHT]) delegate each operation to
    the function of the same name of a [Datastructure.Hashtbl] instance. *)
module type HTS = sig
  type key
  type t

  val set : Egraph.wt -> key -> t -> unit
  (** [set env k v] binds [k] to [v]. *)

  val find : Egraph.wt -> key -> t
  (** [find env k] is the value bound to [k].
      NOTE(review): the behaviour when [k] is unbound is that of the
      underlying [Datastructure.Hashtbl.find]; not visible here. *)

  val find_opt : Egraph.wt -> key -> t option
  (** Option-returning variant of [find]. *)

  val change : f:(t option -> t option) -> Egraph.wt -> key -> unit
  (** [change ~f env k] updates the binding of [k] through [f]. *)

  val remove : Egraph.wt -> key -> unit
  (** [remove env k] removes the binding of [k]. *)

  val iter : f:(key -> t -> unit) -> Egraph.wt -> unit
  (** [iter ~f env] applies [f] to the bindings of the table. *)

  val fold : (key -> t -> 'a -> 'a) -> Egraph.wt -> 'a -> 'a
  (** [fold f env acc] folds [f] over the bindings of the table. *)

  val pp : Format.formatter -> Egraph.wt -> unit
  (** Prints the name of the table followed by its bindings. *)
end

(** [MkIHT (V)] is a table from integer keys to values of type [V.t], stored
    in the environment. Applying the functor creates one table, named
    [V.name], backed by the file-level [HT] ([Datastructure.Hashtbl (DInt)]). *)
module MkIHT (V : sig
  type t

  val pp : t Pp.pp
  val name : string
end) : HTS with type key = int and type t = V.t = struct
  type key = int
  type t = V.t

  (* The table shared by all the operations below. *)
  let db = HT.create V.pp V.name
  let set (env : Egraph.wt) i v = HT.set db env i v
  let find (env : Egraph.wt) = HT.find db env
  let find_opt (env : Egraph.wt) = HT.find_opt db env
  let change ~f (env : Egraph.wt) i = HT.change f db env i
  let remove (env : Egraph.wt) i = HT.remove db env i
  let iter = HT.iter db
  let fold f env acc = HT.fold f db env acc

  (* Prints "<name>:[(k -> (v)); ...]". *)
  let pp fmt env =
    Fmt.pf fmt "%s:[%a]" V.name
      (fun fmt () ->
        iter env ~f:(fun i v -> Fmt.pf fmt "(%d -> (%a));@ " i V.pp v))
      ()
end

(** Interface of an id domain, which attaches an integer id to nodes (see
    [MkIdDom] for the implementation). *)
module type IdDomSig = sig
  val register_merge_hook :
    Egraph.wt ->
    (Egraph.wt -> Node.t * int -> Node.t * int -> bool -> unit) ->
    unit
  (** [register_merge_hook env f] adds [f] to the hooks run when two nodes
      that both have an id are merged. [f] receives each node with its id, and
      the boolean given to the domain's merge function. *)

  val register_new_id_hook :
    Egraph.wt -> (Egraph.wt -> int -> Node.t -> unit) -> unit
  (** [register_new_id_hook env f] adds [f] to the hooks run by [set_id] when
      it gives a new id to a node. [f] receives the new id and the node. *)

  val set_id : Egraph.wt -> Node.t -> unit
  (** [set_id env n] gives a fresh id to [n] if it has none, and runs the
      new-id hooks; it leaves an existing id unchanged. *)

  val get_id : Egraph.wt -> Node.t -> int
  (** [get_id env n] is the id of [n].
      @raise NoIdFound if [n] has no id. *)
end

(** Global id counter, shared by every instance of [MkIdDom]. [set_id]
    increments it before using its value, so the first id handed out is 1. *)
let id_counter = ref 0

(** [MkIdDom (N)] is an id domain named after [N.name]: it registers a domain
    holding an integer id per node, together with two lists of hooks, run
    when a node receives a new id and when two nodes with ids are merged. *)
module MkIdDom (N : sig
  val name : string
end) : IdDomSig = struct
  let merge_hooks = Datastructure.Push.create Fmt.nop (N.name ^ ".merge_hooks")

  let register_merge_hook env (f : Egraph.wt -> 'a -> 'a -> bool -> unit) =
    Datastructure.Push.push merge_hooks env f

  module D = struct
    include Dom.Make (struct
      type t = DInt.t

      let equal = DInt.equal
      let pp = DInt.pp
      let is_singleton _ _ = None

      let key =
        Dom.Kind.create
          (module struct
            type nonrec t = t

            let name = N.name ^ ".DOM"
          end)

      (* When both sides have an id, the first one is kept. *)
      let inter _ v _ = Some v
    end)

    (* Shadows the [merge] of [Dom.Make]: run the merge hooks first, then
       delegate to the original [merge]. *)
    let merge env (d1, n1) (d2, n2) b =
      (match (d1, d2) with
      | Some id1, Some id2 ->
          Datastructure.Push.iter merge_hooks env ~f:(fun f ->
              f env (n1, id1) (n2, id2) b)
      (* id1 always stays, b allows to determine which one will become the
         representative *)
      | _ ->
          (* Ideally, should be unreachable, but it is with the test
             "./colibri2/tests/solve/smt_array/unsat/ite1.smt2"? *)
          ());
      merge env (d1, n1) (d2, n2) b
  end

  let () = Dom.register (module D)
  let new_id_hooks = Datastructure.Push.create Fmt.nop (N.name ^ ".new_id_hooks")

  let register_new_id_hook env (f : Egraph.wt -> int -> Node.t -> unit) =
    Datastructure.Push.push new_id_hooks env f

  (** [set_id env n] gives [n] a fresh id if it has none, then runs the new-id
      hooks on it; an existing id is left unchanged. *)
  let set_id env n =
    match Egraph.get_dom env D.key n with
    | None ->
        incr id_counter;
        (* Read the counter once: it is global, so a hook that allocates
           another id would otherwise change the value seen by the hooks that
           run after it. *)
        let id = !id_counter in
        Debug.dprintf4 debug "set_id(%s) of %a: none -> %d" N.name Node.pp n id;
        D.set_dom env n id;
        Datastructure.Push.iter new_id_hooks env ~f:(fun new_id_hook ->
            new_id_hook env id n)
    | Some id ->
        Debug.dprintf4 debug "set_id(%s) of %a: %d" N.name Node.pp n id

  (** [get_id env n] is the id of [n].
      @raise NoIdFound if [n] has no id in this domain. *)
  let get_id env n =
    match Egraph.get_dom env D.key n with
    | Some id -> id
    | None -> raise (NoIdFound (N.name, n))
end

(** [SHT (K) (V)] is a table from keys of type [K.t] to values of type [V.t],
    stored in the environment. Every key given to a keyed operation is first
    normalised with [K.sort], so two keys with the same [K.sort] image denote
    the same entry. Applying the functor creates one table, named [V.name]. *)
module SHT (K : sig
  include Datatype

  val sort : t -> t
  val pp : t Pp.pp
end) (V : sig
  type t

  val name : string
  val pp : t Pp.pp
end) : HTS with type key = K.t and type t = V.t = struct
  type key = K.t
  type t = V.t

  module HT = Datastructure.Hashtbl (K)

  (* The table shared by all the operations below. *)
  let db = HT.create V.pp V.name

  (* Runs the table operation [f] on the normalised key. *)
  let aux f env k = f db env (K.sort k)
  let remove = aux HT.remove
  let set = aux HT.set
  let find = aux HT.find
  let find_opt = aux HT.find_opt
  let change ~f = aux (HT.change f)

  (* [iter] and [fold] visit the keys as they are stored, i.e. already
     normalised. *)
  let iter = HT.iter db
  let fold f env acc = HT.fold f db env acc

  (* Prints "<name>:[(k -> v); ...]". *)
  let pp fmt env =
    Fmt.pf fmt "%s:[%a]" V.name
      (fun fmt () ->
        iter env ~f:(fun i v -> Fmt.pf fmt "(%a -> %a);@ " K.pp i V.pp v))
      ()
end

(** Datatype of keys made of four integers followed by two ground types, with
    derived equality, ordering, hashing and printing.
    NOTE(review): presumably four node ids and the index/value types of the
    arrays involved; confirm against the users of this module. *)
module I4 = struct
  module T = struct
    type t = DInt.t * DInt.t * DInt.t * DInt.t * Ground.Ty.t * Ground.Ty.t
    [@@deriving eq, ord, hash, show]
  end

  include T
  include MkDatatype (T)
end

(** Datatype of keys made of three integers followed by two ground types, with
    derived equality, ordering, hashing and printing.
    NOTE(review): presumably three node ids and the index/value types of the
    arrays involved; confirm against the users of this module. *)
module I3 = struct
  module T = struct
    type t = DInt.t * DInt.t * DInt.t * Ground.Ty.t * Ground.Ty.t
    [@@deriving eq, ord, hash, show]
  end

  include T
  include MkDatatype (T)
end

(** Datatype of keys made of two integers followed by two ground types, with
    derived equality, ordering, hashing and printing.
    NOTE(review): presumably two node ids and the index/value types of the
    arrays involved; confirm against the users of this module. *)
module I2 = struct
  module T = struct
    type t = DInt.t * DInt.t * Ground.Ty.t * Ground.Ty.t
    [@@deriving eq, ord, hash, show]
  end

  include T
  include MkDatatype (T)
end

(** Datatype of keys made of one integer followed by two ground types, with
    derived equality, ordering, hashing and printing.
    NOTE(review): presumably one node id and the index/value types of the
    array involved; confirm against the users of this module. *)
module I1 = struct
  module T = struct
    type t = DInt.t * Ground.Ty.t * Ground.Ty.t [@@deriving eq, ord, hash, show]
  end

  include T
  include MkDatatype (T)
end

(* (id1,id2,id3,id4,ty1,ty2) *)
(* (id1,id2,ty1,ty2) *)

(* id1 -> (id2,id3,id4,ty1,ty2) *)
(* id1 -> (id2,ty1,ty2) *)