diff --git a/include/llvm/IR/Instructions.h b/include/llvm/IR/Instructions.h index 3e0c6d803d2..82833658c41 100644 --- a/include/llvm/IR/Instructions.h +++ b/include/llvm/IR/Instructions.h @@ -3435,6 +3435,52 @@ public: } }; +/// A wrapper class to simplify modification of SwitchInst cases along with +/// their prof branch_weights metadata. +class SwitchInstProfUpdateWrapper { + SwitchInst &SI; + Optional > Weights; + bool Changed = false; + +protected: + static MDNode *getProfBranchWeightsMD(const SwitchInst &SI); + + MDNode *buildProfBranchWeightsMD(); + + Optional > getProfBranchWeights(); + +public: + using CaseWeightOpt = Optional; + SwitchInst *operator->() { return &SI; } + SwitchInst &operator*() { return SI; } + operator SwitchInst *() { return &SI; } + + SwitchInstProfUpdateWrapper(SwitchInst &SI) + : SI(SI), Weights(getProfBranchWeights()) {} + + ~SwitchInstProfUpdateWrapper() { + if (Changed) + SI.setMetadata(LLVMContext::MD_prof, buildProfBranchWeightsMD()); + } + + /// Delegate the call to the underlying SwitchInst::removeCase() and remove + /// correspondent branch weight. + SwitchInst::CaseIt removeCase(SwitchInst::CaseIt I); + + /// Delegate the call to the underlying SwitchInst::addCase() and set the + /// specified branch weight for the added case. + void addCase(ConstantInt *OnVal, BasicBlock *Dest, CaseWeightOpt W); + + /// Delegate the call to the underlying SwitchInst::eraseFromParent() and mark + /// this object to not touch the underlying SwitchInst in destructor. + SymbolTableList::iterator eraseFromParent(); + + void setSuccessorWeight(unsigned idx, CaseWeightOpt W); + CaseWeightOpt getSuccessorWeight(unsigned idx); + + static CaseWeightOpt getSuccessorWeight(const SwitchInst &SI, unsigned idx); +}; + template <> struct OperandTraits : public HungoffOperandTraits<2> { }; diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp index 9dc753e960c..8812df35e26 100644 --- a/lib/IR/Instructions.cpp +++ b/lib/IR/Instructions.cpp @@ -3870,6 +3870,126 @@ void SwitchInst::growOperands() { growHungoffUses(ReservedSpace); } +MDNode * +SwitchInstProfUpdateWrapper::getProfBranchWeightsMD(const SwitchInst &SI) { + if (MDNode *ProfileData = SI.getMetadata(LLVMContext::MD_prof)) + if (auto *MDName = dyn_cast(ProfileData->getOperand(0))) + if (MDName->getString() == "branch_weights") + return ProfileData; + return nullptr; +} + +MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() { + assert(Changed && "called only if metadata has changed"); + + if (!Weights) + return nullptr; + + assert(SI.getNumSuccessors() == Weights->size() && + "num of prof branch_weights must accord with num of successors"); + + bool AllZeroes = + all_of(Weights.getValue(), [](uint32_t W) { return W == 0; }); + + if (AllZeroes || Weights.getValue().size() < 2) + return nullptr; + + return MDBuilder(SI.getParent()->getContext()).createBranchWeights(*Weights); +} + +Optional > +SwitchInstProfUpdateWrapper::getProfBranchWeights() { + MDNode *ProfileData = getProfBranchWeightsMD(SI); + if (!ProfileData) + return None; + + SmallVector Weights; + for (unsigned CI = 1, CE = SI.getNumSuccessors(); CI <= CE; ++CI) { + ConstantInt *C = mdconst::extract(ProfileData->getOperand(CI)); + uint32_t CW = C->getValue().getZExtValue(); + Weights.push_back(CW); + } + return Weights; +} + +SwitchInst::CaseIt +SwitchInstProfUpdateWrapper::removeCase(SwitchInst::CaseIt I) { + if (Weights) { + assert(SI.getNumSuccessors() == Weights->size() && + "num of prof branch_weights must accord with num of successors"); + Changed = true; + // Copy the last case to the place of the removed one and shrink. + // This is tightly coupled with the way SwitchInst::removeCase() removes + // the cases in SwitchInst::removeCase(CaseIt). + Weights.getValue()[I->getCaseIndex() + 1] = Weights.getValue().back(); + Weights.getValue().pop_back(); + } + return SI.removeCase(I); +} + +void SwitchInstProfUpdateWrapper::addCase( + ConstantInt *OnVal, BasicBlock *Dest, + SwitchInstProfUpdateWrapper::CaseWeightOpt W) { + SI.addCase(OnVal, Dest); + + if (!Weights && W && *W) { + Changed = true; + Weights = SmallVector(SI.getNumSuccessors(), 0); + Weights.getValue()[SI.getNumSuccessors() - 1] = *W; + } else if (Weights) { + Changed = true; + Weights.getValue().push_back(W ? *W : 0); + } + if (Weights) + assert(SI.getNumSuccessors() == Weights->size() && + "num of prof branch_weights must accord with num of successors"); +} + +SymbolTableList::iterator +SwitchInstProfUpdateWrapper::eraseFromParent() { + // Instruction is erased. Mark as unchanged to not touch it in the destructor. + Changed = false; + + if (Weights) + Weights->resize(0); + return SI.eraseFromParent(); +} + +SwitchInstProfUpdateWrapper::CaseWeightOpt +SwitchInstProfUpdateWrapper::getSuccessorWeight(unsigned idx) { + if (!Weights) + return None; + return Weights.getValue()[idx]; +} + +void SwitchInstProfUpdateWrapper::setSuccessorWeight( + unsigned idx, SwitchInstProfUpdateWrapper::CaseWeightOpt W) { + if (!W) + return; + + if (!Weights && *W) + Weights = SmallVector(SI.getNumSuccessors(), 0); + + if (Weights) { + auto &OldW = Weights.getValue()[idx]; + if (*W != OldW) { + Changed = true; + OldW = *W; + } + } +} + +SwitchInstProfUpdateWrapper::CaseWeightOpt +SwitchInstProfUpdateWrapper::getSuccessorWeight(const SwitchInst &SI, + unsigned idx) { + if (MDNode *ProfileData = getProfBranchWeightsMD(SI)) + return mdconst::extract(ProfileData->getOperand(idx + 1)) + ->getValue() + .getZExtValue(); + + return None; +} + //===----------------------------------------------------------------------===// // IndirectBrInst Implementation //===----------------------------------------------------------------------===//