From 4ba0e4123c1385fbab1b6a8b329040bdb1293414 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Loi=CC=88c=20Correnson?= <loic.correnson@cea.fr>
Date: Fri, 21 Jun 2024 09:40:13 +0200
Subject: [PATCH] [kernel] fix import module lookup strategy

---
 .../ast_queries/logic_typing.ml               | 108 ++++++++++--------
 .../ast_queries/logic_typing.mli              |   2 +-
 .../ast_queries/logic_utils.ml                |   3 +
 .../ast_queries/logic_utils.mli               |   3 +
 tests/spec/module.i                           |   2 +-
 tests/spec/oracle/module.res.oracle           |   2 +-
 6 files changed, 69 insertions(+), 51 deletions(-)

diff --git a/src/kernel_services/ast_queries/logic_typing.ml b/src/kernel_services/ast_queries/logic_typing.ml
index 07dcb4ebc9..376d046580 100644
--- a/src/kernel_services/ast_queries/logic_typing.ml
+++ b/src/kernel_services/ast_queries/logic_typing.ml
@@ -660,7 +660,7 @@ sig
     Cil_types.location -> Logic_ptree.relation option ->
     Cil_types.term -> Cil_types.term -> Cil_types.logic_type
 
-  val add_import : ?alias:string -> string -> unit
+  val add_import : ?current:bool -> ?alias:string -> string -> unit
   val clear_imports : unit -> unit
   val push_imports : unit -> unit
   val pop_imports : unit -> unit
@@ -731,62 +731,74 @@ struct
       s.long_prefix s.short_prefix
   [@@ warning "-32"]
 
-  let scopes : scope list Stack.t = Stack.create ()
-  let imported : scope list ref = ref []
+  let scopes : (scope option * scope list) Stack.t = Stack.create ()
 
-  let open_scope ~name ?alias () =
-    let short = match alias with Some a -> a | None ->
-      List.hd @@ List.rev @@ String.split_on_char ':'  name
-    in imported := {
-      long_prefix = name ^ "::";
-      short_prefix = short ^ "::";
-    } :: !imported
+  let current_scope : scope option ref = ref None
+  let imported_scopes : scope list ref = ref []
 
-  let clear_imports () = Stack.clear scopes ; imported := []
-  let push_imports () = Stack.push !imported scopes
-  let pop_imports () = imported := Stack.pop scopes
+  let current_scopes () =
+    match !current_scope with
+    | None -> !imported_scopes
+    | Some s -> s :: !imported_scopes
 
-  let add_import ?alias name =
-    match alias with
-    | Some _ -> open_scope ~name ?alias ()
-    | None ->
-      begin
-        match List.rev @@ String.split_on_char ':' name with
-        | alias::_ -> open_scope ~name ~alias ()
-        | [] -> open_scope ~name ()
-      end
+  let clear_imports () =
+    begin
+      Stack.clear scopes ;
+      current_scope := None ;
+      imported_scopes := [] ;
+    end
+  let push_imports () =
+    Stack.push (!current_scope,!imported_scopes) scopes
+  let pop_imports () =
+    begin
+      let c,s = Stack.pop scopes in
+      current_scope := c ;
+      imported_scopes := s ;
+    end
 
-  let find_import find a =
-    let xs = String.split_on_char ':' a in
-    let n = List.length xs in
-    if n = 1 then (* unqualified name *)
-      match List.find_map (fun s -> find (s.long_prefix ^ a)) !imported with
-      | Some _ as result -> result
-      | None -> find a
-    else
-    if n = 3 then (* single module qualified name *)
-      let is_short s = String.starts_with ~prefix:s.short_prefix a in
-      match List.find_opt is_short !imported with
+  let add_import ?(current=false) ?alias name =
+    begin
+      let short = match alias with Some a -> a | None ->
+        List.hd @@ List.rev @@ Logic_utils.longident name in
+      let s = {
+        long_prefix = name ^ "::";
+        short_prefix = short ^ "::";
+      } in
+      if current then
+        current_scope := Some s
+      else
+        imported_scopes := s :: !imported_scopes ;
+    end
+
+  let find_import fn a =
+    let find_opt b = try Some (fn b) with Not_found -> None in
+    if Logic_utils.is_qualified a then
+      let in_scope s = String.starts_with ~prefix:s.short_prefix a in
+      find_opt @@
+      match List.find_opt in_scope @@ current_scopes () with
+      | None -> a
       | Some s ->
-        let x = List.hd @@ List.rev xs in
-        find (s.long_prefix ^ x)
-      | None -> find a
-    else (* long qualified name *) find a
+        let p = String.length s.short_prefix in
+        let n = String.length a in
+        s.long_prefix ^ String.sub a p (n-p)
+    else
+      let find_in_scope s = find_opt (s.long_prefix ^ a) in
+      match List.find_map find_in_scope @@ current_scopes () with
+      | None -> find_opt a
+      | Some _ as result -> result
 
   let resolve_ltype =
-    find_import
-      begin fun t ->
-        try Some (Logic_env.find_logic_type t)
-        with Not_found -> None
-      end
+    find_import Logic_env.find_logic_type
 
   let resolve_lapp f env =
     try Some (Lfun [Lenv.find_logic_info f env]) with Not_found ->
       find_import
         begin fun a ->
-          try Some (Ctor (Logic_env.find_logic_ctor a)) with Not_found ->
-          match Logic_env.find_all_logic_functions a with
-          | [] -> None | ls -> Some (Lfun ls)
+          try
+            Ctor (Logic_env.find_logic_ctor a)
+          with Not_found ->
+            let ls = Logic_env.find_all_logic_functions a in
+            if ls <> [] then Lfun ls else raise Not_found
         end f
 
   let rollback = Queue.create ()
@@ -4309,21 +4321,21 @@ struct
           "Duplicated module %s (first occurrence at %a)"
           id Cil_printer.pp_location oldloc in
       push_imports () ;
-      open_scope ~name () ;
+      add_import ~current:true name ;
       let l = List.filter_map (decl ~context) decls in
       pop_imports () ;
       ignore (Logic_env.Modules.memo ~change (fun _ -> loc) name);
       Some (Dmodule(name,l,[],loc))
 
     | LDimport(None,name,alias) ->
-      open_scope ~name ?alias () ; None
+      add_import ?alias name ; None
 
     | LDimport(Some driver,name,alias) ->
       let decls = ref [] in
       let builder = make_module_builder decls name in
       let path = Logic_utils.longident name in
       Extensions.importer driver ~builder ~loc path ;
-      open_scope ~name ?alias () ;
+      add_import ?alias name ;
       Some (Dmodule(name,List.rev !decls,[],loc))
 
     | LDtype(name,l,def) ->
diff --git a/src/kernel_services/ast_queries/logic_typing.mli b/src/kernel_services/ast_queries/logic_typing.mli
index de61bec34f..7726e105af 100644
--- a/src/kernel_services/ast_queries/logic_typing.mli
+++ b/src/kernel_services/ast_queries/logic_typing.mli
@@ -182,7 +182,7 @@ sig
     Cil_types.term -> Cil_types.term -> Cil_types.logic_type
 
   (** Open module in local environment. *)
-  val add_import : ?alias:string -> string -> unit
+  val add_import : ?current:bool -> ?alias:string -> string -> unit
   val clear_imports : unit -> unit
   val push_imports : unit -> unit
   val pop_imports : unit -> unit
diff --git a/src/kernel_services/ast_queries/logic_utils.ml b/src/kernel_services/ast_queries/logic_utils.ml
index cf42d5119f..0ec2e482ea 100644
--- a/src/kernel_services/ast_queries/logic_utils.ml
+++ b/src/kernel_services/ast_queries/logic_utils.ml
@@ -870,6 +870,9 @@ let is_same_builtin_profile l1 l2 =
   is_same_list (fun (_,t1) (_,t2) -> is_same_type t1 t2)
     l1.bl_profile l2.bl_profile
 
+let is_qualified a =
+  try ignore @@ String.index a ':' ; true with Not_found -> false
+
 let longident = Str.split @@ Str.regexp_string "::"
 
 let mem_logic_function f =
diff --git a/src/kernel_services/ast_queries/logic_utils.mli b/src/kernel_services/ast_queries/logic_utils.mli
index 5d0f400732..43a080b0f4 100644
--- a/src/kernel_services/ast_queries/logic_utils.mli
+++ b/src/kernel_services/ast_queries/logic_utils.mli
@@ -34,6 +34,9 @@ exception Not_well_formed of location * string
 (** exception raised when an unknown extension is called. *)
 exception Unknown_ext
 
+(** Test if the given string contains ':' (long-identifiers). *)
+val is_qualified : string -> bool
+
 (** Split a long-identifier into the list of its components.
     eg. ["A::B::(<:)"] is split into [["A";"B";"(<:)"]].
     Returns a singleton for regular identifiers.
diff --git a/tests/spec/module.i b/tests/spec/module.i
index 28f9198080..80f34a59aa 100644
--- a/tests/spec/module.i
+++ b/tests/spec/module.i
@@ -12,7 +12,7 @@
   module foo::bar {
     import Foo \as X;
     logic t inv(X::t x);
-    logic t opN(t x, integer n) = n >= 0 ? X::opN(x,n) : X::opN(inv(x),-n);
+    logic t opN(t x, integer n) = n >= 0 ? X::opN(x,n) : opN(inv(x),-n);
   }
   import Foo \as A;
   import foo::bar \as B;
diff --git a/tests/spec/oracle/module.res.oracle b/tests/spec/oracle/module.res.oracle
index 5f7e65ce72..0f5fca9d76 100644
--- a/tests/spec/oracle/module.res.oracle
+++ b/tests/spec/oracle/module.res.oracle
@@ -17,7 +17,7 @@ module foo::bar {
   logic Foo::t inv(Foo::t x) ;
   
   logic Foo::t opN(Foo::t x, ℤ n) =
-    n ≥ 0? Foo::opN(x, n): Foo::opN(inv(x), -n);
+    n ≥ 0? Foo::opN(x, n): opN(inv(x), -n);
   
   }
  */
-- 
GitLab