From 5d1d0821a9ba1385aadb605acc30dc4f8f769ac3 Mon Sep 17 00:00:00 2001
From: Virgile Prevosto <virgile.prevosto@m4x.org>
Date: Wed, 11 Apr 2018 17:12:29 +0200
Subject: [PATCH] [kernel] working -inline-calls

Tested with function having arguments and recursive functions
---
 .../ast_data/kernel_function.ml               |  60 ++--
 .../ast_data/kernel_function.mli              |   5 +
 src/kernel_services/ast_queries/cil.ml        |   1 +
 src/kernel_services/ast_queries/cil.mli       |   4 +
 .../ast_transformations/inline.ml             | 284 ++++++++----------
 tests/syntax/inline_calls.i                   |   9 +-
 tests/syntax/oracle/inline_calls.res.oracle   |  83 ++++-
 7 files changed, 246 insertions(+), 200 deletions(-)

diff --git a/src/kernel_services/ast_data/kernel_function.ml b/src/kernel_services/ast_data/kernel_function.ml
index b0ff28676c0..a1572643bd6 100644
--- a/src/kernel_services/ast_data/kernel_function.ml
+++ b/src/kernel_services/ast_data/kernel_function.ml
@@ -299,32 +299,44 @@ let find_first_stmt kf = match get_stmts kf with
 
 let () = Globals.find_first_stmt := find_first_stmt
 
-exception Found_label of stmt ref
-let find_label kf label =
+let label_table kf =
   match kf.fundec with
-  | Declaration _ -> raise Not_found
+  | Declaration _ -> Datatype.String.Map.empty
   | Definition (fundec,_) ->
-      let label_finder = object
-        inherit Cil.nopCilVisitor
-        method! vstmt s = begin
-          if List.exists
-            (fun lbl -> match lbl with
-             | Label (s,_,_) -> s = label
-             | Case _ -> false
-             | Default _ -> label="default")
-            s.labels then raise (Found_label (ref s));
-          Cil.DoChildren
-        end
-        method! vexpr _ = Cil.SkipChildren
-        method! vtype _ = Cil.SkipChildren
-        method! vinst _ = Cil.SkipChildren
-      end
-      in
-      try
-        ignore (Cil.visitCilFunction label_finder fundec);
-        (* Ok: this is not a code label *)
-        raise Not_found
-      with Found_label s -> s
+    let label_finder = object(self)
+      inherit Cil.nopCilVisitor
+      val mutable labels = Datatype.String.Map.empty
+      method all_labels = labels
+      method new_label stmt lbl =
+        match lbl with
+        | Label (l,_,_) ->
+          labels <- Datatype.String.Map.add l (ref stmt) labels
+        | Case _ -> ()
+        | Default _ ->
+          (* Kept for compatibility with old implementation of find_label,
+             but looks quite suspicious. *)
+          labels <- Datatype.String.Map.add "default" (ref stmt) labels
+
+      method! vstmt s =
+        List.iter (self#new_label s) s.labels;
+        Cil.DoChildren
+
+      method! vexpr _ = Cil.SkipChildren
+      method! vtype _ = Cil.SkipChildren
+      method! vinst _ = Cil.SkipChildren
+    end
+    in
+    ignore (Cil.visitCilFunction (label_finder:>Cil.cilVisitor) fundec);
+    label_finder#all_labels
+
+let find_all_labels kf =
+  let labels = label_table kf in
+  Datatype.String.(
+    Map.fold (fun lab _ acc -> Set.add lab acc) labels Set.empty)
+
+let find_label kf label =
+  let labels = label_table kf in
+  Datatype.String.Map.find label labels
 
 let get_called fct = match fct.enode with
   | Lval (Var vkf, NoOffset) -> 
diff --git a/src/kernel_services/ast_data/kernel_function.mli b/src/kernel_services/ast_data/kernel_function.mli
index 3a1708fbd63..1494940a75d 100644
--- a/src/kernel_services/ast_data/kernel_function.mli
+++ b/src/kernel_services/ast_data/kernel_function.mli
@@ -60,6 +60,11 @@ val find_label : t -> string -> stmt ref
   (** Find a given label in a kernel function.
       @raise Not_found if the label does not exist in the given function. *)
 
+val find_all_labels: t -> Datatype.String.Set.t
+  (** returns all labels present in a given function.
+      @since Frama-C+dev
+  *)
+
 val clear_sid_info: unit -> unit
 (** removes any information related to statements in kernel functions.
     ({i.e.} the table used by the function below).
diff --git a/src/kernel_services/ast_queries/cil.ml b/src/kernel_services/ast_queries/cil.ml
index a4eaf8df883..ead4d3b8f06 100644
--- a/src/kernel_services/ast_queries/cil.ml
+++ b/src/kernel_services/ast_queries/cil.ml
@@ -6196,6 +6196,7 @@ let need_cast ?(force=false) oldt newt =
  (* Make a local variable and add it to a function *)
  let makeLocalVar fdec ?scope ?(temp=false) ?(insert = true) name typ =
    let vi = makeLocal ~temp fdec name typ in
+   refresh_local_name fdec vi;
    if insert then
      begin
        fdec.slocals <- fdec.slocals @ [vi];
diff --git a/src/kernel_services/ast_queries/cil.mli b/src/kernel_services/ast_queries/cil.mli
index 42bf2d7429b..2ef0b85129e 100644
--- a/src/kernel_services/ast_queries/cil.mli
+++ b/src/kernel_services/ast_queries/cil.mli
@@ -663,6 +663,10 @@ val makeFormalVar: fundec -> ?where:string -> string -> typ -> varinfo
     Make sure you know what you are doing if you set [insert=false].
     [temp] is passed to {!Cil.makeVarinfo}.
     The variable is attached to the toplevel block if [scope] is not specified.
+    If the name passed as argument already exists within the function,
+    a fresh name will be generated for the varinfo.
+
+    @modify Frama-C+dev the name of the variable is guaranteed to be fresh.
 *)
 val makeLocalVar:
   fundec -> ?scope:block -> ?temp:bool -> ?insert:bool
diff --git a/src/kernel_services/ast_transformations/inline.ml b/src/kernel_services/ast_transformations/inline.ml
index a063888339b..1f73a5603af 100644
--- a/src/kernel_services/ast_transformations/inline.ml
+++ b/src/kernel_services/ast_transformations/inline.ml
@@ -34,149 +34,75 @@ module InlineCalls =
     let help = "inline calls to functions f1, ..., fn"
   end)
 
-module InlineCounters =
-  State_builder.Hashtbl
-    (Cil_datatype.Varinfo.Hashtbl)
-    (Datatype.Int)
-    (struct
-      let name = "Inline.InlineCounters"
-      let dependencies = [Ast.self]
-      let size = 1
-    end)
-
-let get_and_incr_inline_counter vi =
-  let c = try InlineCounters.find vi with Not_found -> 0 in
-  InlineCounters.replace vi (c+1);
-  c+1
-
-let inline_call loc caller callee block stmt kind return args =
+let inline_call loc caller callee return args =
   let caller_fd = Kernel_function.get_definition caller in
-  let callee_fd = Kernel_function.get_definition callee in
-  let args = match kind with
-    | None | Some Plain_func -> args
-    | Some Constructor -> Cil.mkAddrOf ~loc (Extlib.the return) :: args
+  let caller_labels = ref (Kernel_function.find_all_labels caller) in
+  let fresh_label lab =
+    let (_,lab) =
+      Extlib.make_unique_name
+        (fun x -> Datatype.String.Set.mem x !caller_labels)
+        ~sep:"_" ~start:0 lab
+    in
+    caller_labels:= Datatype.String.Set.add lab !caller_labels; lab
   in
-  let ret_val = ref None in
   let o = object(self)
     inherit Visitor.frama_c_refresh (Project.current ())
 
-    val mutable toplevel_block = true;
-    method private mk_inline_args =
-      let mk_local vi =
-        let vi' = Cil.makeLocalVar caller_fd ~temp:true vi.vname vi.vtype in
-        Cil.set_varinfo self#behavior vi vi';
-        Cil.set_orig_varinfo self#behavior vi' vi;
-        vi'
-      in
-      (* Formals become locals of the new block. They are initialized to the
-         corresponding argument. *)
-      try
-        List.fold_left2
-          (fun (vars, inits) vi exp ->
-             let vi' = mk_local vi in
-             let init =
-               Cil.mkStmtOneInstr ~valid_sid:true
-                 (Set ((Var vi', NoOffset), exp, exp.eloc))
-             in
-             vi' :: vars, init :: inits)
-          ([], [])
-          callee_fd.sformals
-          args
-      with Invalid_argument _ ->
-        Kernel.fatal "inliner: undetected variadic function call"
+    method! vvdec _ =
+      Cil.DoChildrenPost (fun vi -> Cil.refresh_local_name caller_fd vi; vi)
 
-    method !vblock _ =
-      Cil.DoChildrenPost
-        (fun blk ->
-           if toplevel_block then begin
-             (* Adds initialization of formals only for the main block. *)
-             let vars, initializers = self#mk_inline_args in
-             blk.blocals <- vars @ blk.blocals;
-             blk.bstmts <- initializers @ blk.bstmts;
-             toplevel_block <- false;
-           end;
-           (* ensure there's no name collision between
-              variables of callee and caller *)
-           List.iter (Cil.refresh_local_name caller_fd) blk.blocals;
-           caller_fd.slocals <- caller_fd.slocals @ blk.blocals;
-           blk)
+    method! vvrbl v =
+      if v.vglob then
+        Cil.ChangeTo (Cil.get_original_varinfo self#behavior v)
+      else Cil.DoChildren
 
-  method !vstmt_aux _ =
-    Cil.DoChildrenPost
-      (fun stmt ->
-         (* Replace return by an assignment; or remove it if useless *)
-         match stmt.skind with
-         | Return(exp, loc) ->
-           let instr = match return, exp, kind with
-             | None, _, _ | _, None, _ | _, _, Some Constructor ->
-               (* ignore this return *) Skip loc
-             | Some ret, Some exp, None -> Set(ret, exp, loc)
-             | Some _, Some exp, Some Plain_func ->
-               let rv =
-                 match !ret_val with
-                 | None ->
-                   let vi =
-                     Cil.makeTempVar
-                       caller_fd
-                       ~insert:false ~name:"__inline_res"
-                       ~descr:"exported result of inlined function"
-                       (Cil.typeOf exp)
-                   in
-                   ret_val := Some vi; vi
-                 | Some vi -> vi
-               in
-               Set (Cil.var rv, exp, loc)
+    method! vfunc _ =
+      Cil.DoChildrenPost
+        (fun fd ->
+           caller_fd.slocals <-
+             caller_fd.slocals @ fd.sformals @ fd.slocals;
+           let add_init vi arg =
+             vi.vdefined <- true;
+             Cil.mkStmtOneInstr
+               (Local_init (vi,AssignInit (SingleInit arg),loc))
            in
-           stmt.skind <- Instr instr;
-           stmt
-         | _ -> stmt)
-  end
-  in
-  o#set_current_kf callee;
-  let refresh_var vi = ignore (Visitor.visitFramacVarDecl o vi) in
-  List.iter refresh_var (Kernel_function.get_locals callee);
-  Cil.set_kernel_function o#behavior callee caller;
-  Cil.set_orig_kernel_function o#behavior caller callee;
-  let blk = Visitor.visitFramacBlock o callee_fd.sbody in
-  match !ret_val with
-  | None -> stmt.skind <- Block blk
-  | Some vi ->
-    caller_fd.slocals <- vi :: caller_fd.slocals;
-    block.blocals <- vi :: block.blocals;
-    let stmt_inlined = Cil.mkStmt ~valid_sid:true (Block blk) in
-    let instr_res =
-      match kind, return with
-      | None, None -> Skip loc
-      | None, Some lv -> Set (lv, Cil.evar ~loc vi, loc)
-      | Some Plain_func, Some (Var res, NoOffset) ->
-        Local_init(res, AssignInit (SingleInit (Cil.evar ~loc vi)), loc)
-      | _ -> Kernel.fatal "Unexpected result of inlining function"
-    in
-    let block =
-      Cil.mkBlock [stmt_inlined; Cil.mkStmtOneInstr ~valid_sid:true instr_res]
-    in
-    let block = Cil.transient_block block in
-    stmt.skind <- Block block
+           let inits = List.map2 add_init fd.sformals args in
+           fd.sbody.blocals <- fd.sformals @ fd.sbody.blocals;
+           fd.sbody.bstmts <- inits @ fd.sbody.bstmts;
+           fd)
 
-let is_specified_enough kf =
-  (* ok if at least one "requires" and one "ensures" *)
-  let got iter bhv =
-    try
-      iter (fun _ _p -> raise Exit) kf bhv.b_name;
-      false
-    with Exit ->
-      true
+    method !vstmt_aux _ =
+      Cil.DoChildrenPost
+        (fun stmt ->
+           stmt.labels <-
+             List.map
+               (function
+                 | Label (s,l,f) -> Label (fresh_label s,l,f)
+                 | (Case _ | Default _) as lab -> lab)
+               stmt.labels;
+           (* Replace return by an assignment; or remove it if useless *)
+           (match stmt.skind with
+            | Return(exp, loc) ->
+              let skind =
+                match return, exp with
+                | None, None  -> Instr (Skip loc)
+                | None, Some exp ->
+                  (* Keep the expression in case it could lead to an alarm *)
+                  (Cil.mkPureExpr ~fundec:caller_fd exp).skind
+                | Some ret, Some exp -> Instr (Set(ret, exp, loc))
+                | Some _, None ->
+                  Kernel.fatal
+                    "trying to assign the result of a void returning function"
+              in
+              stmt.skind <- skind
+            | _ -> ());
+           stmt)
+  end
   in
-  let req, ens, ass =
-    Annotations.fold_behaviors
-      (fun _ bhv (req, ens, ass) ->
-         req || got Annotations.iter_requires bhv,
-         ens || got Annotations.iter_ensures bhv,
-         ass || got Annotations.iter_assigns bhv)
-      kf
-      (false, false, false)
+  let callee_fd =
+    Visitor.visitFramacFunction o (Kernel_function.get_definition callee)
   in
-  req && ens && ass
+  callee_fd.sbody
 
 let is_variadic_function vi = match vi.vtype with
   | TFun(_, _, is_v, _) -> is_v
@@ -185,39 +111,83 @@ let is_variadic_function vi = match vi.vtype with
 let inliner functions_to_inline = object (self)
   inherit Visitor.frama_c_inplace
 
-  val blocks = Stack.create ()
+  val call_stack = Stack.create ()
 
-  method! vblock b =
-    Stack.push b blocks;
-    Cil.DoChildrenPost (fun b -> ignore (Stack.pop blocks); b)
+  method private recursive_call_limit kf =
+    let nb_calls =
+      Stack.fold
+        (fun res kf' -> if Cil_datatype.Kf.equal kf kf' then res + 1 else res)
+        0 call_stack
+    in
+    nb_calls >= 1 (* TODO: be more flexible. *)
 
   (* inline the given [stmt], which must be a call, in the given [caller] *)
   method private inline stmt init_kind return f args =
-    if is_variadic_function f then begin
-      Kernel.warning ~current:true ~once:true
-        "variadic function call in loop detected: \
-         results might be unprecise"
-    end;
     let callee =
       try Globals.Functions.get f
       with Not_found ->
           Kernel.fatal
             ~current:true "Expecting a function, got %a" Printer.pp_varinfo f
     in
-    if Kernel_function.Set.mem callee functions_to_inline then
-      if not (is_specified_enough callee) then begin
-        let counter = get_and_incr_inline_counter f in
-        let max_inline = (*TODO*) 1 in
-        if counter <= max_inline then
-          inline_call
-            (Cil_datatype.Stmt.loc stmt)
-            (Extlib.the self#current_kf)
-            callee (Stack.top blocks) stmt init_kind return args
-        else
-          Kernel.feedback ~dkey ~once:true
-            "reached inline limit for '%s'" f.vname
-      end;
-    Cil.DoChildren
+    if Kernel_function.Set.mem callee functions_to_inline &&
+       not (self#recursive_call_limit callee)
+    then begin
+      if is_variadic_function f then begin
+        Kernel.warning ~current:true ~once:true
+        "Ignoring inlining option for variadic function %a"
+        Printer.pp_varinfo f;
+        Cil.DoChildren
+      end else
+        begin
+          Stack.push callee call_stack;
+          let loc = Cil_datatype.Stmt.loc stmt in
+          let needs_assign, return_aux, args =
+            match init_kind, return with
+            | None, _ -> false, return, args
+            | Some Plain_func, Some lv ->
+              let t = Cil.typeOfLval lv in
+              let scope = Kernel_function.find_enclosing_block stmt in
+              let v =
+                Cil.makeLocalVar
+                  (Extlib.the self#current_func) ~scope ~temp:true "__inline_tmp" t
+              in
+              true, Some (Cil.var v), args
+            | Some Constructor, Some (Var r, NoOffset) ->
+              (* Inlining will prevent r to be syntactically seen as initialized
+                 or const: *)
+              r.vdefined <- false;
+              r.vtype <- Cil.typeRemoveAttributes ["const"] r.vtype;
+              false, None, (Cil.mkAddrOf loc (Cil.var r)) :: args
+            | Some _, _ ->
+              Kernel.fatal "Attempt to initialize an inexistent varinfo"
+          in
+          let block =
+            inline_call
+              (Cil_datatype.Stmt.loc stmt)
+              (Extlib.the self#current_kf)
+              callee return_aux args
+          in
+          let skind =
+            if needs_assign then begin
+              match return_aux, return with
+              | Some (Var aux, NoOffset), Some (Var r, NoOffset) ->
+                let b =
+                  Cil.mkBlockNonScoping [
+                    Cil.mkStmt (Block block);
+                    Cil.mkStmtOneInstr
+                      (Local_init
+                         (r, AssignInit (SingleInit (Cil.evar ~loc aux)),loc))]
+                in
+                Block b
+              | _ ->
+                Kernel.fatal
+                  "Unexpected lval during inlining of a local initializer"
+            end else Block block
+          in
+          stmt.skind <- skind;
+          Cil.DoChildrenPost (fun stmt -> ignore (Stack.pop call_stack); stmt);
+        end
+    end else Cil.DoChildren
 
   method !vstmt_aux stmt =
     match stmt.skind with
@@ -232,10 +202,6 @@ let inliner functions_to_inline = object (self)
       self#inline stmt (Some kind) (Some (Cil.var v)) f args
     | _ -> Cil.DoChildren
 
-  method !vfunc _kf =
-    InlineCounters.clear ();
-    Cil.DoChildren
-
 end
 
 let inline_calls ast =
diff --git a/tests/syntax/inline_calls.i b/tests/syntax/inline_calls.i
index e980dc59de9..f3c16e14f74 100644
--- a/tests/syntax/inline_calls.i
+++ b/tests/syntax/inline_calls.i
@@ -1,5 +1,5 @@
 /* run.config
-   STDOPT: +"-inline-calls f,g,h,i,in_f"
+   STDOPT: +"-inline-calls f,g,h,i,in_f,rec"
 
  */
 
@@ -22,10 +22,17 @@ int h() {
 }
 
 int i() {
+  /*@ assert i:\true; */
   return 0;
 }
 
+int rec(int x) {
+  if (x < 0) return x;
+  return rec(x-1);
+}
+
 int main() {
   int local_init = i();
+  int t = rec(local_init);
   return h();
 }
diff --git a/tests/syntax/oracle/inline_calls.res.oracle b/tests/syntax/oracle/inline_calls.res.oracle
index 15d31c66c97..cc21d9941d8 100644
--- a/tests/syntax/oracle/inline_calls.res.oracle
+++ b/tests/syntax/oracle/inline_calls.res.oracle
@@ -74,24 +74,75 @@ int h(void)
 int i(void)
 {
   int __retres;
+  /*@ assert i: \true; */ ;
   __retres = 0;
   return __retres;
 }
 
+int rec(int x)
+{
+  int __retres;
+  int tmp;
+  if (x < 0) {
+    __retres = x;
+    goto return_label;
+  }
+  {
+    int __retres_6;
+    int tmp_5;
+    int x_7 = x - 1;
+    if (x_7 < 0) {
+      __retres_6 = x_7;
+      goto return_label_0;
+    }
+    tmp_5 = rec(x_7 - 1);
+    __retres_6 = tmp_5;
+    return_label_0: tmp = __retres_6;
+  }
+  __retres = tmp;
+  return_label: return __retres;
+}
+
 int main(void)
 {
-  int __inline_res;
-  int tmp_0;
+  int __inline_tmp_8;
+  int __inline_tmp;
+  int tmp_1;
   {
     int __retres;
+    /*@ assert i: \true; */ ;
     __retres = 0;
-    __inline_res = __retres;
+    __inline_tmp = __retres;
   }
-  int local_init = __inline_res;
+  int local_init = __inline_tmp;
   {
+    int __retres_10;
     int tmp;
+    int x = local_init;
+    if (x < 0) {
+      __retres_10 = x;
+      goto return_label;
+    }
+    {
+      int __retres_6;
+      int tmp_5;
+      int x_7 = x - 1;
+      if (x_7 < 0) {
+        __retres_6 = x_7;
+        goto return_label_0;
+      }
+      tmp_5 = rec(x_7 - 1);
+      __retres_6 = tmp_5;
+      return_label_0: tmp = __retres_6;
+    }
+    __retres_10 = tmp;
+    return_label: __inline_tmp_8 = __retres_10;
+  }
+  int t = __inline_tmp_8;
+  {
+    int tmp_11;
     {
-      int __retres_7;
+      int __retres_12;
       if (v) {
         int tmp_3;
         {
@@ -99,24 +150,24 @@ int main(void)
           __retres_5 = 2;
           tmp_3 = __retres_5;
         }
-        __retres_7 = tmp_3;
-        goto return_label;
+        __retres_12 = tmp_3;
+        goto return_label_1;
       }
       else {
-        int tmp_0_6;
+        int tmp_0;
         {
-          int __retres_6;
-          __retres_6 = 3;
-          tmp_0_6 = __retres_6;
+          int __retres_6_13;
+          __retres_6_13 = 3;
+          tmp_0 = __retres_6_13;
         }
-        __retres_7 = tmp_0_6;
-        goto return_label;
+        __retres_12 = tmp_0;
+        goto return_label_1;
       }
-      return_label: tmp = __retres_7;
+      return_label_1: tmp_11 = __retres_12;
     }
-    tmp_0 = tmp;
+    tmp_1 = tmp_11;
   }
-  return tmp_0;
+  return tmp_1;
 }
 
 
-- 
GitLab