From e71eaed450efb763b0acac31f9dc91c201f36c5f Mon Sep 17 00:00:00 2001 From: Matthew Simpson Date: Wed, 6 Dec 2017 21:22:54 +0000 Subject: [PATCH] [PGO] Make indirect call promotion a utility This patch factors out the main code transformation utilities in the pgo-driven indirect call promotion pass and places them in Transforms/Utils. The change is intended to be a non-functional change, letting non-pgo-driven passes share a common implementation with the existing pgo-driven pass. The common utilities are used to conditionally promote indirect call sites to direct call sites. They perform the underlying transformation, and do not consider profile information. The pgo-specific details (e.g., the computation of branch weight metadata) have been left in the indirect call promotion pass. Differential Revision: https://reviews.llvm.org/D40658 llvm-svn: 319963 --- include/llvm/Transforms/Instrumentation.h | 10 +- .../Transforms/Utils/CallPromotionUtils.h | 44 +++ lib/Transforms/IPO/SampleProfile.cpp | 7 +- .../Instrumentation/IndirectCallPromotion.cpp | 323 +---------------- lib/Transforms/Utils/CMakeLists.txt | 1 + lib/Transforms/Utils/CallPromotionUtils.cpp | 328 ++++++++++++++++++ 6 files changed, 394 insertions(+), 319 deletions(-) create mode 100644 include/llvm/Transforms/Utils/CallPromotionUtils.h create mode 100644 lib/Transforms/Utils/CallPromotionUtils.cpp diff --git a/include/llvm/Transforms/Instrumentation.h b/include/llvm/Transforms/Instrumentation.h index 0d76328a2f8..be074b0d84c 100644 --- a/include/llvm/Transforms/Instrumentation.h +++ b/include/llvm/Transforms/Instrumentation.h @@ -77,9 +77,12 @@ ModulePass *createPGOIndirectCallPromotionLegacyPass(bool InLTO = false, bool SamplePGO = false); FunctionPass *createPGOMemOPSizeOptLegacyPass(); -// Helper function to check if it is legal to promote indirect call \p Inst -// to a direct call of function \p F. Stores the reason in \p Reason. 
-bool isLegalToPromote(Instruction *Inst, Function *F, const char **Reason); +// The pgo-specific indirect call promotion function declared below is used by +// the pgo-driven indirect call promotion and sample profile passes. It's a +// wrapper around llvm::promoteCall, et al. that additionally computes !prof +// metadata. We place it in a pgo namespace so it's not confused with the +// generic utilities. +namespace pgo { // Helper function that transforms Inst (either an indirect-call instruction, or // an invoke instruction , to a conditional call to F. This is like: @@ -98,6 +101,7 @@ Instruction *promoteIndirectCall(Instruction *Inst, Function *F, uint64_t Count, uint64_t TotalCount, bool AttachProfToDirectCall, OptimizationRemarkEmitter *ORE); +} // namespace pgo /// Options for the frontend instrumentation based profiling pass. struct InstrProfOptions { diff --git a/include/llvm/Transforms/Utils/CallPromotionUtils.h b/include/llvm/Transforms/Utils/CallPromotionUtils.h new file mode 100644 index 00000000000..e0bf85781d8 --- /dev/null +++ b/include/llvm/Transforms/Utils/CallPromotionUtils.h @@ -0,0 +1,44 @@ +//===- CallPromotionUtils.h - Utilities for call promotion ------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file declares utilities useful for promoting indirect call sites to +// direct call sites. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H +#define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H + +#include "llvm/IR/CallSite.h" + +namespace llvm { + +/// Return true if the given indirect call site can be made to call \p Callee. 
+/// +/// This function ensures that the number and type of the call site's arguments +/// and return value match those of the given function. If the types do not +/// match exactly, they must at least be bitcast compatible. If \p FailureReason +/// is non-null and the indirect call cannot be promoted, the failure reason +/// will be stored in it. +bool isLegalToPromote(CallSite CS, Function *Callee, + const char **FailureReason = nullptr); + +/// Promote the given indirect call site to conditionally call \p Callee. +/// +/// This function creates an if-then-else structure at the location of the call +/// site. A clone of the call site is promoted to call \p Callee directly and +/// placed in the "then" block; the original indirect call site is moved into +/// the "else" block. The promoted clone (the new direct call) is returned. +/// If \p BranchWeights is non-null, it will be used to set !prof metadata on +/// the new conditional branch. +Instruction *promoteCallWithIfThenElse(CallSite CS, Function *Callee, + MDNode *BranchWeights = nullptr); + +} // end namespace llvm + +#endif // LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H diff --git a/lib/Transforms/IPO/SampleProfile.cpp b/lib/Transforms/IPO/SampleProfile.cpp index 8930e9b2b95..f0e781b9d92 100644 --- a/lib/Transforms/IPO/SampleProfile.cpp +++ b/lib/Transforms/IPO/SampleProfile.cpp @@ -69,6 +69,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/Utils/CallPromotionUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include #include @@ -823,10 +824,10 @@ bool SampleProfileLoader::inlineHotFunctions( if (R != SymbolMap.end() && R->getValue() && !R->getValue()->isDeclaration() && R->getValue()->getSubprogram() && - isLegalToPromote(I, R->getValue(), &Reason)) { + isLegalToPromote(CallSite(I), R->getValue(), &Reason)) { uint64_t C = FS->getEntrySamples(); - Instruction *DI = promoteIndirectCall( - I, R->getValue(), C, Sum, false, ORE); + Instruction *DI = + pgo::promoteIndirectCall(I, 
R->getValue(), C, Sum, false, ORE); Sum -= C; PromotedInsns.insert(I); // If profile mismatches, we should not attempt to inline DI. diff --git a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp index 8b9bbb49955..49b8a67a6c1 100644 --- a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp +++ b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp @@ -47,6 +47,7 @@ #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/PGOInstrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/CallPromotionUtils.h" #include #include #include @@ -214,49 +215,6 @@ public: } // end anonymous namespace -bool llvm::isLegalToPromote(Instruction *Inst, Function *F, - const char **Reason) { - // Check the return type. - Type *CallRetType = Inst->getType(); - if (!CallRetType->isVoidTy()) { - Type *FuncRetType = F->getReturnType(); - if (FuncRetType != CallRetType && - !CastInst::isBitCastable(FuncRetType, CallRetType)) { - if (Reason) - *Reason = "Return type mismatch"; - return false; - } - } - - // Check if the arguments are compatible with the parameters - FunctionType *DirectCalleeType = F->getFunctionType(); - unsigned ParamNum = DirectCalleeType->getFunctionNumParams(); - CallSite CS(Inst); - unsigned ArgNum = CS.arg_size(); - - if (ParamNum != ArgNum && !DirectCalleeType->isVarArg()) { - if (Reason) - *Reason = "The number of arguments mismatch"; - return false; - } - - for (unsigned I = 0; I < ParamNum; ++I) { - Type *PTy = DirectCalleeType->getFunctionParamType(I); - Type *ATy = CS.getArgument(I)->getType(); - if (PTy == ATy) - continue; - if (!CastInst::castIsValid(Instruction::BitCast, CS.getArgument(I), PTy)) { - if (Reason) - *Reason = "Argument type mismatch"; - return false; - } - } - - DEBUG(dbgs() << " #" << NumOfPGOICallPromotion << " Promote the icall to " - << F->getName() << "\n"); - return true; -} - // Indirect-call promotion 
heuristic. The direct targets are sorted based on // the count. Stop at the first target that is not promoted. std::vector @@ -317,7 +275,7 @@ ICallPromotionFunc::getPromotionCandidatesForCallSite( } const char *Reason = nullptr; - if (!isLegalToPromote(Inst, TargetFunction, &Reason)) { + if (!isLegalToPromote(CallSite(Inst), TargetFunction, &Reason)) { using namespace ore; ORE.emit([&]() { @@ -335,23 +293,11 @@ ICallPromotionFunc::getPromotionCandidatesForCallSite( return Ret; } -// Create a diamond structure for If_Then_Else. Also update the profile -// count. Do the fix-up for the invoke instruction. -static void createIfThenElse(Instruction *Inst, Function *DirectCallee, - uint64_t Count, uint64_t TotalCount, - BasicBlock **DirectCallBB, - BasicBlock **IndirectCallBB, - BasicBlock **MergeBB) { - CallSite CS(Inst); - Value *OrigCallee = CS.getCalledValue(); - - IRBuilder<> BBBuilder(Inst); - LLVMContext &Ctx = Inst->getContext(); - Value *BCI1 = - BBBuilder.CreateBitCast(OrigCallee, Type::getInt8PtrTy(Ctx), ""); - Value *BCI2 = - BBBuilder.CreateBitCast(DirectCallee, Type::getInt8PtrTy(Ctx), ""); - Value *PtrCmp = BBBuilder.CreateICmpEQ(BCI1, BCI2, ""); +Instruction *llvm::pgo::promoteIndirectCall(Instruction *Inst, + Function *DirectCallee, + uint64_t Count, uint64_t TotalCount, + bool AttachProfToDirectCall, + OptimizationRemarkEmitter *ORE) { uint64_t ElseCount = TotalCount - Count; uint64_t MaxCount = (Count >= ElseCount ? 
Count : ElseCount); @@ -359,231 +305,9 @@ static void createIfThenElse(Instruction *Inst, Function *DirectCallee, MDBuilder MDB(Inst->getContext()); MDNode *BranchWeights = MDB.createBranchWeights( scaleBranchCount(Count, Scale), scaleBranchCount(ElseCount, Scale)); - TerminatorInst *ThenTerm, *ElseTerm; - SplitBlockAndInsertIfThenElse(PtrCmp, Inst, &ThenTerm, &ElseTerm, - BranchWeights); - *DirectCallBB = ThenTerm->getParent(); - (*DirectCallBB)->setName("if.true.direct_targ"); - *IndirectCallBB = ElseTerm->getParent(); - (*IndirectCallBB)->setName("if.false.orig_indirect"); - *MergeBB = Inst->getParent(); - (*MergeBB)->setName("if.end.icp"); - // Special handing of Invoke instructions. - InvokeInst *II = dyn_cast(Inst); - if (!II) - return; - - // We don't need branch instructions for invoke. - ThenTerm->eraseFromParent(); - ElseTerm->eraseFromParent(); - - // Add jump from Merge BB to the NormalDest. This is needed for the newly - // created direct invoke stmt -- as its NormalDst will be fixed up to MergeBB. - BranchInst::Create(II->getNormalDest(), *MergeBB); -} - -// Find the PHI in BB that have the CallResult as the operand. -static bool getCallRetPHINode(BasicBlock *BB, Instruction *Inst) { - BasicBlock *From = Inst->getParent(); - for (auto &I : *BB) { - PHINode *PHI = dyn_cast(&I); - if (!PHI) - continue; - int IX = PHI->getBasicBlockIndex(From); - if (IX == -1) - continue; - Value *V = PHI->getIncomingValue(IX); - if (dyn_cast(V) == Inst) - return true; - } - return false; -} - -// This method fixes up PHI nodes in BB where BB is the UnwindDest of an -// invoke instruction. In BB, there may be PHIs with incoming block being -// OrigBB (the MergeBB after if-then-else splitting). After moving the invoke -// instructions to its own BB, OrigBB is no longer the predecessor block of BB. -// Instead two new predecessors are added: IndirectCallBB and DirectCallBB, -// so the PHI node's incoming BBs need to be fixed up accordingly. 
-static void fixupPHINodeForUnwind(Instruction *Inst, BasicBlock *BB, - BasicBlock *OrigBB, - BasicBlock *IndirectCallBB, - BasicBlock *DirectCallBB) { - for (auto &I : *BB) { - PHINode *PHI = dyn_cast(&I); - if (!PHI) - continue; - int IX = PHI->getBasicBlockIndex(OrigBB); - if (IX == -1) - continue; - Value *V = PHI->getIncomingValue(IX); - PHI->addIncoming(V, IndirectCallBB); - PHI->setIncomingBlock(IX, DirectCallBB); - } -} - -// This method fixes up PHI nodes in BB where BB is the NormalDest of an -// invoke instruction. In BB, there may be PHIs with incoming block being -// OrigBB (the MergeBB after if-then-else splitting). After moving the invoke -// instructions to its own BB, a new incoming edge will be added to the original -// NormalDstBB from the IndirectCallBB. -static void fixupPHINodeForNormalDest(Instruction *Inst, BasicBlock *BB, - BasicBlock *OrigBB, - BasicBlock *IndirectCallBB, - Instruction *NewInst) { - for (auto &I : *BB) { - PHINode *PHI = dyn_cast(&I); - if (!PHI) - continue; - int IX = PHI->getBasicBlockIndex(OrigBB); - if (IX == -1) - continue; - Value *V = PHI->getIncomingValue(IX); - if (dyn_cast(V) == Inst) { - PHI->setIncomingBlock(IX, IndirectCallBB); - PHI->addIncoming(NewInst, OrigBB); - continue; - } - PHI->addIncoming(V, IndirectCallBB); - } -} - -// Add a bitcast instruction to the direct-call return value if needed. 
-static Instruction *insertCallRetCast(const Instruction *Inst, - Instruction *DirectCallInst, - Function *DirectCallee) { - if (Inst->getType()->isVoidTy()) - return DirectCallInst; - - Type *CallRetType = Inst->getType(); - Type *FuncRetType = DirectCallee->getReturnType(); - if (FuncRetType == CallRetType) - return DirectCallInst; - - BasicBlock *InsertionBB; - if (CallInst *CI = dyn_cast(DirectCallInst)) - InsertionBB = CI->getParent(); - else - InsertionBB = (dyn_cast(DirectCallInst))->getNormalDest(); - - return (new BitCastInst(DirectCallInst, CallRetType, "", - InsertionBB->getTerminator())); -} - -// Create a DirectCall instruction in the DirectCallBB. -// Parameter Inst is the indirect-call (invoke) instruction. -// DirectCallee is the decl of the direct-call (invoke) target. -// DirecallBB is the BB that the direct-call (invoke) instruction is inserted. -// MergeBB is the bottom BB of the if-then-else-diamond after the -// transformation. For invoke instruction, the edges from DirectCallBB and -// IndirectCallBB to MergeBB are removed before this call (during -// createIfThenElse). Stores the pointer to the Instruction that cast -// the direct call in \p CastInst. -static Instruction *createDirectCallInst(const Instruction *Inst, - Function *DirectCallee, - BasicBlock *DirectCallBB, - BasicBlock *MergeBB, - Instruction *&CastInst) { - Instruction *NewInst = Inst->clone(); - if (CallInst *CI = dyn_cast(NewInst)) { - CI->setCalledFunction(DirectCallee); - CI->mutateFunctionType(DirectCallee->getFunctionType()); - } else { - // Must be an invoke instruction. Direct invoke's normal destination is - // fixed up to MergeBB. MergeBB is the place where return cast is inserted. - // Also since IndirectCallBB does not have an edge to MergeBB, there is no - // need to insert new PHIs into MergeBB. 
- InvokeInst *II = dyn_cast(NewInst); - assert(II); - II->setCalledFunction(DirectCallee); - II->mutateFunctionType(DirectCallee->getFunctionType()); - II->setNormalDest(MergeBB); - } - - DirectCallBB->getInstList().insert(DirectCallBB->getFirstInsertionPt(), - NewInst); - - // Clear the value profile data. - NewInst->setMetadata(LLVMContext::MD_prof, nullptr); - CallSite NewCS(NewInst); - FunctionType *DirectCalleeType = DirectCallee->getFunctionType(); - unsigned ParamNum = DirectCalleeType->getFunctionNumParams(); - for (unsigned I = 0; I < ParamNum; ++I) { - Type *ATy = NewCS.getArgument(I)->getType(); - Type *PTy = DirectCalleeType->getParamType(I); - if (ATy != PTy) { - BitCastInst *BI = new BitCastInst(NewCS.getArgument(I), PTy, "", NewInst); - NewCS.setArgument(I, BI); - } - } - - CastInst = insertCallRetCast(Inst, NewInst, DirectCallee); - return NewInst; -} - -// Create a PHI to unify the return values of calls. -static void insertCallRetPHI(Instruction *Inst, Instruction *CallResult, - Function *DirectCallee) { - if (Inst->getType()->isVoidTy()) - return; - - if (Inst->use_empty()) - return; - - BasicBlock *RetValBB = CallResult->getParent(); - - BasicBlock *PHIBB; - if (InvokeInst *II = dyn_cast(CallResult)) - RetValBB = II->getNormalDest(); - - PHIBB = RetValBB->getSingleSuccessor(); - if (getCallRetPHINode(PHIBB, Inst)) - return; - - PHINode *CallRetPHI = PHINode::Create(Inst->getType(), 0); - PHIBB->getInstList().push_front(CallRetPHI); - Inst->replaceAllUsesWith(CallRetPHI); - CallRetPHI->addIncoming(Inst, Inst->getParent()); - CallRetPHI->addIncoming(CallResult, RetValBB); -} - -// This function does the actual indirect-call promotion transformation: -// For an indirect-call like: -// Ret = (*Foo)(Args); -// It transforms to: -// if (Foo == DirectCallee) -// Ret1 = DirectCallee(Args); -// else -// Ret2 = (*Foo)(Args); -// Ret = phi(Ret1, Ret2); -// It adds type casts for the args do not match the parameters and the return -// value. 
Branch weights metadata also updated. -// If \p AttachProfToDirectCall is true, a prof metadata is attached to the -// new direct call to contain \p Count. This is used by SamplePGO inliner to -// check callsite hotness. -// Returns the promoted direct call instruction. -Instruction *llvm::promoteIndirectCall(Instruction *Inst, - Function *DirectCallee, uint64_t Count, - uint64_t TotalCount, - bool AttachProfToDirectCall, - OptimizationRemarkEmitter *ORE) { - assert(DirectCallee != nullptr); - BasicBlock *BB = Inst->getParent(); - // Just to suppress the non-debug build warning. - (void)BB; - DEBUG(dbgs() << "\n\n== Basic Block Before ==\n"); - DEBUG(dbgs() << *BB << "\n"); - - BasicBlock *DirectCallBB, *IndirectCallBB, *MergeBB; - createIfThenElse(Inst, DirectCallee, Count, TotalCount, &DirectCallBB, - &IndirectCallBB, &MergeBB); - - // If the return type of the NewInst is not the same as the Inst, a CastInst - // is needed for type casting. Otherwise CastInst is the same as NewInst. - Instruction *CastInst = nullptr; Instruction *NewInst = - createDirectCallInst(Inst, DirectCallee, DirectCallBB, MergeBB, CastInst); + promoteCallWithIfThenElse(CallSite(Inst), DirectCallee, BranchWeights); if (AttachProfToDirectCall) { SmallVector Weights; @@ -592,33 +316,6 @@ Instruction *llvm::promoteIndirectCall(Instruction *Inst, NewInst->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); } - // Move Inst from MergeBB to IndirectCallBB. - Inst->removeFromParent(); - IndirectCallBB->getInstList().insert(IndirectCallBB->getFirstInsertionPt(), - Inst); - - if (InvokeInst *II = dyn_cast(Inst)) { - // At this point, the original indirect invoke instruction has the original - // UnwindDest and NormalDest. For the direct invoke instruction, the - // NormalDest points to MergeBB, and MergeBB jumps to the original - // NormalDest. MergeBB might have a new bitcast instruction for the return - // value. The PHIs are with the original NormalDest. 
Since we now have two - // incoming edges to NormalDest and UnwindDest, we have to do some fixups. - // - // UnwindDest will not use the return value. So pass nullptr here. - fixupPHINodeForUnwind(Inst, II->getUnwindDest(), MergeBB, IndirectCallBB, - DirectCallBB); - // We don't need to update the operand from NormalDest for DirectCallBB. - // Pass nullptr here. - fixupPHINodeForNormalDest(Inst, II->getNormalDest(), MergeBB, - IndirectCallBB, CastInst); - } - - insertCallRetPHI(Inst, CastInst, DirectCallee); - - DEBUG(dbgs() << "\n== Basic Blocks After ==\n"); - DEBUG(dbgs() << *BB << *DirectCallBB << *IndirectCallBB << *MergeBB << "\n"); - using namespace ore; if (ORE) @@ -639,8 +336,8 @@ uint32_t ICallPromotionFunc::tryToPromote( for (auto &C : Candidates) { uint64_t Count = C.Count; - promoteIndirectCall(Inst, C.TargetFunction, Count, TotalCount, SamplePGO, - &ORE); + pgo::promoteIndirectCall(Inst, C.TargetFunction, Count, TotalCount, + SamplePGO, &ORE); assert(TotalCount >= Count); TotalCount -= Count; NumOfPGOICallPromotion++; diff --git a/lib/Transforms/Utils/CMakeLists.txt b/lib/Transforms/Utils/CMakeLists.txt index f3bf0d8c248..972e47f9270 100644 --- a/lib/Transforms/Utils/CMakeLists.txt +++ b/lib/Transforms/Utils/CMakeLists.txt @@ -5,6 +5,7 @@ add_llvm_library(LLVMTransformUtils BreakCriticalEdges.cpp BuildLibCalls.cpp BypassSlowDivision.cpp + CallPromotionUtils.cpp CloneFunction.cpp CloneModule.cpp CodeExtractor.cpp diff --git a/lib/Transforms/Utils/CallPromotionUtils.cpp b/lib/Transforms/Utils/CallPromotionUtils.cpp new file mode 100644 index 00000000000..eb3139ce429 --- /dev/null +++ b/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -0,0 +1,328 @@ +//===- CallPromotionUtils.cpp - Utilities for call promotion ----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This file implements utilities useful for promoting indirect call sites to +// direct call sites. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/CallPromotionUtils.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "call-promotion-utils" + +/// Fix-up phi nodes in an invoke instruction's normal destination. +/// +/// After versioning an invoke instruction, values coming from the original +/// block will now either be coming from the original block or the "else" block. +static void fixupPHINodeForNormalDest(InvokeInst *Invoke, BasicBlock *OrigBlock, + BasicBlock *ElseBlock, + Instruction *NewInst) { + for (auto &I : *Invoke->getNormalDest()) { + auto *Phi = dyn_cast(&I); + if (!Phi) + break; + int Idx = Phi->getBasicBlockIndex(OrigBlock); + if (Idx == -1) + continue; + Value *V = Phi->getIncomingValue(Idx); + if (dyn_cast(V) == Invoke) { + Phi->setIncomingBlock(Idx, ElseBlock); + Phi->addIncoming(NewInst, OrigBlock); + continue; + } + Phi->addIncoming(V, ElseBlock); + } +} + +/// Fix-up phi nodes in an invoke instruction's unwind destination. +/// +/// After versioning an invoke instruction, values coming from the original +/// block will now be coming from either the "then" block or the "else" block. 
+static void fixupPHINodeForUnwindDest(InvokeInst *Invoke, BasicBlock *OrigBlock, + BasicBlock *ThenBlock, + BasicBlock *ElseBlock) { + for (auto &I : *Invoke->getUnwindDest()) { + auto *Phi = dyn_cast(&I); + if (!Phi) + break; + int Idx = Phi->getBasicBlockIndex(OrigBlock); + if (Idx == -1) + continue; + auto *V = Phi->getIncomingValue(Idx); + Phi->setIncomingBlock(Idx, ThenBlock); + Phi->addIncoming(V, ElseBlock); + } +} + +/// Get the phi node having the returned value of a call or invoke instruction +/// as it's operand. +static bool getRetPhiNode(Instruction *Inst, BasicBlock *Block) { + BasicBlock *FromBlock = Inst->getParent(); + for (auto &I : *Block) { + PHINode *PHI = dyn_cast(&I); + if (!PHI) + break; + int Idx = PHI->getBasicBlockIndex(FromBlock); + if (Idx == -1) + continue; + auto *V = PHI->getIncomingValue(Idx); + if (V == Inst) + return true; + } + return false; +} + +/// Create a phi node for the returned value of a call or invoke instruction. +/// +/// After versioning a call or invoke instruction that returns a value, we have +/// to merge the value of the original and new instructions. We do this by +/// creating a phi node and replacing uses of the original instruction with this +/// phi node. 
+static void createRetPHINode(Instruction *OrigInst, Instruction *NewInst) { + + if (OrigInst->getType()->isVoidTy() || OrigInst->use_empty()) + return; + + BasicBlock *RetValBB = NewInst->getParent(); + if (auto *Invoke = dyn_cast(NewInst)) + RetValBB = Invoke->getNormalDest(); + BasicBlock *PhiBB = RetValBB->getSingleSuccessor(); + + if (getRetPhiNode(OrigInst, PhiBB)) + return; + + IRBuilder<> Builder(&PhiBB->front()); + PHINode *Phi = Builder.CreatePHI(OrigInst->getType(), 0); + SmallVector UsersToUpdate; + for (User *U : OrigInst->users()) + UsersToUpdate.push_back(U); + for (User *U : UsersToUpdate) + U->replaceUsesOfWith(OrigInst, Phi); + Phi->addIncoming(OrigInst, OrigInst->getParent()); + Phi->addIncoming(NewInst, RetValBB); +} + +/// Cast a call or invoke instruction to the given type. +/// +/// When promoting a call site, the return type of the call site might not match +/// that of the callee. If this is the case, we have to cast the returned value +/// to the correct type. The location of the cast depends on if we have a call +/// or invoke instruction. +Instruction *createRetBitCast(CallSite CS, Type *RetTy) { + + // Save the users of the calling instruction. These uses will be changed to + // use the bitcast after we create it. + SmallVector UsersToUpdate; + for (User *U : CS.getInstruction()->users()) + UsersToUpdate.push_back(U); + + // Determine an appropriate location to create the bitcast for the return + // value. The location depends on if we have a call or invoke instruction. + Instruction *InsertBefore = nullptr; + if (auto *Invoke = dyn_cast(CS.getInstruction())) + InsertBefore = &*Invoke->getNormalDest()->getFirstInsertionPt(); + else + InsertBefore = &*std::next(CS.getInstruction()->getIterator()); + + // Bitcast the return value to the correct type. + auto *Cast = CastInst::Create(Instruction::BitCast, CS.getInstruction(), + RetTy, "", InsertBefore); + + // Replace all the original uses of the calling instruction with the bitcast. 
+ for (User *U : UsersToUpdate) + U->replaceUsesOfWith(CS.getInstruction(), Cast); + + return Cast; +} + +/// Predicate and clone the given call site. +/// +/// This function creates an if-then-else structure at the location of the call +/// site. The "if" condition compares the call site's called value to the given +/// callee. The original call site is moved into the "else" block, and a clone +/// of the call site is placed in the "then" block. The cloned instruction is +/// returned. +static Instruction *versionCallSite(CallSite CS, Value *Callee, + MDNode *BranchWeights, + BasicBlock *&ThenBlock, + BasicBlock *&ElseBlock, + BasicBlock *&MergeBlock) { + + IRBuilder<> Builder(CS.getInstruction()); + Instruction *OrigInst = CS.getInstruction(); + + // Create the compare. The called value and callee must have the same type to + // be compared. + auto *LHS = + Builder.CreateBitCast(CS.getCalledValue(), Builder.getInt8PtrTy()); + auto *RHS = Builder.CreateBitCast(Callee, Builder.getInt8PtrTy()); + auto *Cond = Builder.CreateICmpEQ(LHS, RHS); + + // Create an if-then-else structure. The original instruction is moved into + // the "else" block, and a clone of the original instruction is placed in the + // "then" block. + TerminatorInst *ThenTerm = nullptr; + TerminatorInst *ElseTerm = nullptr; + SplitBlockAndInsertIfThenElse(Cond, CS.getInstruction(), &ThenTerm, &ElseTerm, + BranchWeights); + ThenBlock = ThenTerm->getParent(); + ElseBlock = ElseTerm->getParent(); + MergeBlock = OrigInst->getParent(); + + ThenBlock->setName("if.true.direct_targ"); + ElseBlock->setName("if.false.orig_indirect"); + MergeBlock->setName("if.end.icp"); + + Instruction *NewInst = OrigInst->clone(); + OrigInst->moveBefore(ElseTerm); + NewInst->insertBefore(ThenTerm); + + // If the original call site is an invoke instruction, we have extra work to + // do since invoke instructions are terminating. 
+ if (auto *OrigInvoke = dyn_cast(OrigInst)) { + auto *NewInvoke = cast(NewInst); + + // Invoke instructions are terminating, so we don't need the terminator + // instructions that were just created. + ThenTerm->eraseFromParent(); + ElseTerm->eraseFromParent(); + + // Branch from the "merge" block to the original normal destination. + Builder.SetInsertPoint(MergeBlock); + Builder.CreateBr(OrigInvoke->getNormalDest()); + + // Now set the normal destination of new the invoke instruction to be the + // "merge" block. + NewInvoke->setNormalDest(MergeBlock); + } + + return NewInst; +} + +bool llvm::isLegalToPromote(CallSite CS, Function *Callee, + const char **FailureReason) { + assert(!CS.getCalledFunction() && "Only indirect call sites can be promoted"); + + // Check the return type. The callee's return value type must be bitcast + // compatible with the call site's type. + Type *CallRetTy = CS.getInstruction()->getType(); + Type *FuncRetTy = Callee->getReturnType(); + if (CallRetTy != FuncRetTy) + if (!CastInst::isBitCastable(FuncRetTy, CallRetTy)) { + if (FailureReason) + *FailureReason = "Return type mismatch"; + return false; + } + + // The number of formal arguments of the callee. + unsigned NumParams = Callee->getFunctionType()->getNumParams(); + + // Check the number of arguments. The callee and call site must agree on the + // number of arguments. + if (CS.arg_size() != NumParams && !Callee->isVarArg()) { + if (FailureReason) + *FailureReason = "The number of arguments mismatch"; + return false; + } + + // Check the argument types. The callee's formal argument types must be + // bitcast compatible with the corresponding actual argument types of the call + // site. 
+ for (unsigned I = 0; I < NumParams; ++I) { + Type *FormalTy = Callee->getFunctionType()->getFunctionParamType(I); + Type *ActualTy = CS.getArgument(I)->getType(); + if (FormalTy == ActualTy) + continue; + if (!CastInst::isBitCastable(ActualTy, FormalTy)) { + if (FailureReason) + *FailureReason = "Argument type mismatch"; + return false; + } + } + + return true; +} + +static void promoteCall(CallSite CS, Function *Callee, Instruction *&Cast) { + assert(!CS.getCalledFunction() && "Only indirect call sites can be promoted"); + + // Set the called function of the call site to be the given callee. + CS.setCalledFunction(Callee); + + // Since the call site will no longer be direct, we must clear metadata that + // is only appropriate for indirect calls. This includes !prof and !callees + // metadata. + CS.getInstruction()->setMetadata(LLVMContext::MD_prof, nullptr); + CS.getInstruction()->setMetadata(LLVMContext::MD_callees, nullptr); + + // If the function type of the call site matches that of the callee, no + // additional work is required. + if (CS.getFunctionType() == Callee->getFunctionType()) + return; + + // Save the return types of the call site and callee. + Type *CallSiteRetTy = CS.getInstruction()->getType(); + Type *CalleeRetTy = Callee->getReturnType(); + + // Change the function type of the call site the match that of the callee. + CS.mutateFunctionType(Callee->getFunctionType()); + + // Inspect the arguments of the call site. If an argument's type doesn't + // match the corresponding formal argument's type in the callee, bitcast it + // to the correct type. 
+ for (Use &U : CS.args()) { + unsigned ArgNo = CS.getArgumentNo(&U); + Type *FormalTy = Callee->getFunctionType()->getParamType(ArgNo); + Type *ActualTy = U.get()->getType(); + if (FormalTy != ActualTy) { + auto *Cast = CastInst::Create(Instruction::BitCast, U.get(), FormalTy, "", + CS.getInstruction()); + CS.setArgument(ArgNo, Cast); + } + } + + // If the return type of the call site doesn't match that of the callee, cast + // the returned value to the appropriate type. + if (!CallSiteRetTy->isVoidTy() && CallSiteRetTy != CalleeRetTy) + Cast = createRetBitCast(CS, CallSiteRetTy); +} + +Instruction *llvm::promoteCallWithIfThenElse(CallSite CS, Function *Callee, + MDNode *BranchWeights) { + + // Version the indirect call site. If the called value is equal to the given + // callee, 'NewInst' will be executed, otherwise the original call site will + // be executed. + BasicBlock *ThenBlock, *ElseBlock, *MergeBlock; + Instruction *NewInst = versionCallSite(CS, Callee, BranchWeights, ThenBlock, + ElseBlock, MergeBlock); + + // Promote 'NewInst' so that it directly calls the desired function. + Instruction *Cast = NewInst; + promoteCall(CallSite(NewInst), Callee, Cast); + + // If the original call site is an invoke instruction, we have to fix-up phi + // nodes in the invoke's normal and unwind destinations. + if (auto *OrigInvoke = dyn_cast(CS.getInstruction())) { + fixupPHINodeForNormalDest(OrigInvoke, MergeBlock, ElseBlock, Cast); + fixupPHINodeForUnwindDest(OrigInvoke, MergeBlock, ThenBlock, ElseBlock); + } + + // Create a phi node for the returned value of the call site. + createRetPHINode(CS.getInstruction(), Cast ? Cast : NewInst); + + // Return the new direct call. + return NewInst; +} + +#undef DEBUG_TYPE