diff --git a/include/llvm/Analysis/TargetTransformInfo.h b/include/llvm/Analysis/TargetTransformInfo.h index f2a07006da4..68fbf640994 100644 --- a/include/llvm/Analysis/TargetTransformInfo.h +++ b/include/llvm/Analysis/TargetTransformInfo.h @@ -216,9 +216,23 @@ public: /// other context they may not be folded. This routine can distinguish such /// cases. /// + /// \p Operands is a list of operands which can be a result of transformations + /// of the current operands. The number of the operands on the list must equal + /// to the number of the current operands the IR user has. Their order on the + /// list must be the same as the order of the current operands the IR user + /// has. + /// /// The returned cost is defined in terms of \c TargetCostConstants, see its /// comments for a detailed explanation of the cost values. - int getUserCost(const User *U) const; + int getUserCost(const User *U, ArrayRef Operands) const; + + /// \brief This is a helper function which calls the two-argument getUserCost + /// with \p Operands which are the current operands U has. + int getUserCost(const User *U) const { + SmallVector Operands(U->value_op_begin(), + U->value_op_end()); + return getUserCost(U, Operands); + } /// \brief Return true if branch divergence exists. /// @@ -824,7 +838,8 @@ public: ArrayRef Arguments) = 0; virtual unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI, unsigned &JTSize) = 0; - virtual int getUserCost(const User *U) = 0; + virtual int + getUserCost(const User *U, ArrayRef Operands) = 0; virtual bool hasBranchDivergence() = 0; virtual bool isSourceOfDivergence(const Value *V) = 0; virtual bool isAlwaysUniform(const Value *V) = 0; @@ -1000,7 +1015,9 @@ public: ArrayRef Arguments) override { return Impl.getIntrinsicCost(IID, RetTy, Arguments); } - int getUserCost(const User *U) override { return Impl.getUserCost(U); } + int getUserCost(const User *U, ArrayRef Operands) override { + return Impl.getUserCost(U, Operands); + } bool hasBranchDivergence() override { return Impl.hasBranchDivergence(); } bool isSourceOfDivergence(const Value *V) override { return Impl.isSourceOfDivergence(V); diff --git a/include/llvm/Analysis/TargetTransformInfoImpl.h b/include/llvm/Analysis/TargetTransformInfoImpl.h index defd9e73bc0..0246fc1c02c 100644 --- a/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -685,14 +685,14 @@ public: return static_cast(this)->getIntrinsicCost(IID, RetTy, ParamTys); } - unsigned getUserCost(const User *U) { + unsigned getUserCost(const User *U, ArrayRef Operands) { if (isa(U)) return TTI::TCC_Free; // Model all PHI nodes as free. if (const GEPOperator *GEP = dyn_cast(U)) { - SmallVector Indices(GEP->idx_begin(), GEP->idx_end()); - return static_cast(this)->getGEPCost( - GEP->getSourceElementType(), GEP->getPointerOperand(), Indices); + return static_cast(this)->getGEPCost(GEP->getSourceElementType(), + GEP->getPointerOperand(), + Operands.drop_front()); } if (auto CS = ImmutableCallSite(U)) { diff --git a/lib/Analysis/TargetTransformInfo.cpp b/lib/Analysis/TargetTransformInfo.cpp index 91160954bd9..f938a9a5206 100644 --- a/lib/Analysis/TargetTransformInfo.cpp +++ b/lib/Analysis/TargetTransformInfo.cpp @@ -89,8 +89,9 @@ TargetTransformInfo::getEstimatedNumberOfCaseClusters(const SwitchInst &SI, return TTIImpl->getEstimatedNumberOfCaseClusters(SI, JTSize); } -int TargetTransformInfo::getUserCost(const User *U) const { - int Cost = TTIImpl->getUserCost(U); +int TargetTransformInfo::getUserCost(const User *U, + ArrayRef Operands) const { + int Cost = TTIImpl->getUserCost(U, Operands); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } diff --git a/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp index d3848104d2b..4135d0cec70 100644 --- a/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp +++ b/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp @@ -46,8 +46,9 @@ unsigned HexagonTTIImpl::getCacheLineSize() const { return getST()->getL1CacheLineSize(); } -int HexagonTTIImpl::getUserCost(const User *U) { - auto isCastFoldedIntoLoad = [] (const CastInst *CI) -> bool { +int HexagonTTIImpl::getUserCost(const User *U, + ArrayRef Operands) { + auto isCastFoldedIntoLoad = [](const CastInst *CI) -> bool { if (!CI->isIntegerCast()) return false; const LoadInst *LI = dyn_cast(CI->getOperand(0)); @@ -67,5 +68,5 @@ int HexagonTTIImpl::getUserCost(const User *U) { if (const CastInst *CI = dyn_cast(U)) if (isCastFoldedIntoLoad(CI)) return TargetTransformInfo::TCC_Free; - return BaseT::getUserCost(U); + return BaseT::getUserCost(U, Operands); } diff --git a/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/lib/Target/Hexagon/HexagonTargetTransformInfo.h index 08faa2acd90..f08a2731057 100644 --- a/lib/Target/Hexagon/HexagonTargetTransformInfo.h +++ b/lib/Target/Hexagon/HexagonTargetTransformInfo.h @@ -62,7 +62,7 @@ public: /// @} - int getUserCost(const User *U); + int getUserCost(const User *U, ArrayRef Operands); }; } // end namespace llvm