diff --git a/include/llvm/IR/Instruction.h b/include/llvm/IR/Instruction.h index 87457b291e7..9dbe2ec0e90 100644 --- a/include/llvm/IR/Instruction.h +++ b/include/llvm/IR/Instruction.h @@ -252,6 +252,9 @@ public: /// Returns false if no metadata was found. bool extractProfTotalWeight(uint64_t &TotalVal) const; + /// Updates branch_weights metadata by scaling it by \p S / \p T. + void updateProfWeight(uint64_t S, uint64_t T); + /// Set the debug location information for this instruction. void setDebugLoc(DebugLoc Loc) { DbgLoc = std::move(Loc); } diff --git a/lib/IR/Instruction.cpp b/lib/IR/Instruction.cpp index 2ca496e2dc7..fc453d4f8d7 100644 --- a/lib/IR/Instruction.cpp +++ b/lib/IR/Instruction.cpp @@ -17,6 +17,7 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Operator.h" #include "llvm/IR/Type.h" using namespace llvm; @@ -629,3 +630,25 @@ Instruction *Instruction::clone() const { New->copyMetadata(*this); return New; } + +void Instruction::updateProfWeight(uint64_t S, uint64_t T) { + auto *ProfileData = getMetadata(LLVMContext::MD_prof); + if (ProfileData == nullptr) + return; + + auto *ProfDataName = dyn_cast(ProfileData->getOperand(0)); + if (!ProfDataName || !ProfDataName->getString().equals("branch_weights")) + return; + + SmallVector Weights; + for (unsigned i = 1; i < ProfileData->getNumOperands(); i++) { + // Using APInt::div may be expensive, but most cases should fit in 64 bits. + APInt Val(128, mdconst::dyn_extract(ProfileData->getOperand(i)) + ->getValue() + .getZExtValue()); + Val *= APInt(128, S); + Weights.push_back(Val.udiv(APInt(128, T)).getLimitedValue()); + } + MDBuilder MDB(getContext()); + setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); +} diff --git a/lib/Transforms/Utils/InlineFunction.cpp b/lib/Transforms/Utils/InlineFunction.cpp index a90ef3e5054..0f08d7f66fd 100644 --- a/lib/Transforms/Utils/InlineFunction.cpp +++ b/lib/Transforms/Utils/InlineFunction.cpp @@ -25,6 +25,7 @@ #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/CallSite.h" @@ -1425,28 +1426,55 @@ static void updateCallerBFI(BasicBlock *CallSiteBlock, ClonedBBs); } +/// Update the branch metadata for cloned call instructions. +static void updateCallProfile(Function *Callee, const ValueToValueMapTy &VMap, + const Optional &CalleeEntryCount, + const Instruction *TheCall) { + if (!CalleeEntryCount.hasValue() || CalleeEntryCount.getValue() < 1) + return; + Optional CallSiteCount = + ProfileSummaryInfo::getProfileCount(TheCall, nullptr); + uint64_t CallCount = + std::min(CallSiteCount.hasValue() ? CallSiteCount.getValue() : 0, + CalleeEntryCount.getValue()); + + for (auto const &Entry : VMap) + if (isa(Entry.first) && &*Entry.second != nullptr && + isa(Entry.second)) + cast(Entry.second) + ->updateProfWeight(CallCount, CalleeEntryCount.getValue()); + for (BasicBlock &BB : *Callee) + // No need to update the callsite if it is pruned during inlining. + if (VMap.count(&BB)) + for (Instruction &I : BB) + if (CallInst *CI = dyn_cast(&I)) + CI->updateProfWeight(CalleeEntryCount.getValue() - CallCount, + CalleeEntryCount.getValue()); +} + /// Update the entry count of callee after inlining. /// /// The callsite's block count is subtracted from the callee's function entry /// count. -static void updateCalleeCount(BlockFrequencyInfo &CallerBFI, BasicBlock *CallBB, - Function *Callee) { +static void updateCalleeCount(BlockFrequencyInfo *CallerBFI, BasicBlock *CallBB, + Instruction *CallInst, Function *Callee) { // If the callee has a original count of N, and the estimated count of // callsite is M, the new callee count is set to N - M. M is estimated from // the caller's entry count, its entry block frequency and the block frequency // of the callsite. Optional CalleeCount = Callee->getEntryCount(); - if (!CalleeCount) + if (!CalleeCount.hasValue()) return; - Optional CallSiteCount = CallerBFI.getBlockProfileCount(CallBB); - if (!CallSiteCount) + Optional CallCount = + ProfileSummaryInfo::getProfileCount(CallInst, CallerBFI); + if (!CallCount.hasValue()) return; // Since CallSiteCount is an estimate, it could exceed the original callee // count and has to be set to 0. - if (CallSiteCount.getValue() > CalleeCount.getValue()) + if (CallCount.getValue() > CalleeCount.getValue()) Callee->setEntryCount(0); else - Callee->setEntryCount(CalleeCount.getValue() - CallSiteCount.getValue()); + Callee->setEntryCount(CalleeCount.getValue() - CallCount.getValue()); } /// This function inlines the called function into the basic block of the @@ -1636,13 +1664,14 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // Remember the first block that is newly cloned over. FirstNewBlock = LastBlock; ++FirstNewBlock; - if (IFI.CallerBFI != nullptr && IFI.CalleeBFI != nullptr) { + if (IFI.CallerBFI != nullptr && IFI.CalleeBFI != nullptr) // Update the BFI of blocks cloned into the caller. updateCallerBFI(OrigBB, VMap, IFI.CallerBFI, IFI.CalleeBFI, CalledFunc->front()); - // Update the profile count of callee. - updateCalleeCount(*IFI.CallerBFI, OrigBB, CalledFunc); - } + + updateCallProfile(CalledFunc, VMap, CalledFunc->getEntryCount(), TheCall); + // Update the profile count of callee. + updateCalleeCount(IFI.CallerBFI, OrigBB, TheCall, CalledFunc); // Inject byval arguments initialization. for (std::pair &Init : ByValInit) diff --git a/test/Transforms/Inline/prof-update.ll b/test/Transforms/Inline/prof-update.ll new file mode 100644 index 00000000000..38fcc7e4599 --- /dev/null +++ b/test/Transforms/Inline/prof-update.ll @@ -0,0 +1,39 @@ +; RUN: opt < %s -inline -S | FileCheck %s +; Checks if inliner updates branch_weights annotation for call instructions. + +declare void @ext(); +declare void @ext1(); + +; CHECK: define void @callee(i32 %n) !prof ![[ENTRY_COUNT:[0-9]*]] +define void @callee(i32 %n) !prof !1 { + %cond = icmp sle i32 %n, 10 + br i1 %cond, label %cond_true, label %cond_false +cond_true: +; ext1 is optimized away, thus not updated. +; CHECK: call void @ext1(), !prof ![[COUNT_CALLEE1:[0-9]*]] + call void @ext1(), !prof !2 + ret void +cond_false: +; ext is cloned and updated. +; CHECK: call void @ext(), !prof ![[COUNT_CALLEE:[0-9]*]] + call void @ext(), !prof !2 + ret void +} + +; CHECK: define void @caller() +define void @caller() { +; CHECK: call void @ext(), !prof ![[COUNT_CALLER:[0-9]*]] + call void @callee(i32 15), !prof !3 + ret void +} + +!llvm.module.flags = !{!0} +!0 = !{i32 1, !"MaxFunctionCount", i32 2000} +!1 = !{!"function_entry_count", i64 1000} +!2 = !{!"branch_weights", i64 2000} +!3 = !{!"branch_weights", i64 400} +attributes #0 = { alwaysinline } +; CHECK: ![[ENTRY_COUNT]] = !{!"function_entry_count", i64 600} +; CHECK: ![[COUNT_CALLEE1]] = !{!"branch_weights", i64 2000} +; CHECK: ![[COUNT_CALLEE]] = !{!"branch_weights", i32 1200} +; CHECK: ![[COUNT_CALLER]] = !{!"branch_weights", i32 800}