mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-11-21 18:22:53 +01:00
[LV] Make use of PatternMatchers in getReductionPatternCost. NFC
Pulled out of D106166, this modifies getReductionPatternCost to use PatternMatchers, hopefully simplifying the code a little.
This commit is contained in:
parent
fd3020376e
commit
ddd62877c6
@ -7127,6 +7127,7 @@ LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I,
|
||||
|
||||
Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
|
||||
Instruction *I, ElementCount VF, Type *Ty, TTI::TargetCostKind CostKind) {
|
||||
using namespace llvm::PatternMatch;
|
||||
// Early exit for no inloop reductions
|
||||
if (InLoopReductionChains.empty() || VF.isScalar() || !isa<VectorType>(Ty))
|
||||
return None;
|
||||
@ -7145,13 +7146,12 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
|
||||
// it is not we return an invalid cost specifying the orignal cost method
|
||||
// should be used.
|
||||
Instruction *RetI = I;
|
||||
if ((RetI->getOpcode() == Instruction::SExt ||
|
||||
RetI->getOpcode() == Instruction::ZExt)) {
|
||||
if (match(RetI, m_ZExtOrSExt(m_Value()))) {
|
||||
if (!RetI->hasOneUser())
|
||||
return None;
|
||||
RetI = RetI->user_back();
|
||||
}
|
||||
if (RetI->getOpcode() == Instruction::Mul &&
|
||||
if (match(RetI, m_Mul(m_Value(), m_Value())) &&
|
||||
RetI->user_back()->getOpcode() == Instruction::Add) {
|
||||
if (!RetI->hasOneUser())
|
||||
return None;
|
||||
@ -7183,8 +7183,10 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
|
||||
|
||||
VectorTy = VectorType::get(I->getOperand(0)->getType(), VectorTy);
|
||||
|
||||
if (RedOp && (isa<SExtInst>(RedOp) || isa<ZExtInst>(RedOp)) &&
|
||||
Instruction *Op0, *Op1;
|
||||
if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value())) &&
|
||||
!TheLoop->isLoopInvariant(RedOp)) {
|
||||
// Matched reduce(ext(A))
|
||||
bool IsUnsigned = isa<ZExtInst>(RedOp);
|
||||
auto *ExtType = VectorType::get(RedOp->getOperand(0)->getType(), VectorTy);
|
||||
InstructionCost RedCost = TTI.getExtendedAddReductionCost(
|
||||
@ -7196,22 +7198,20 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
|
||||
TTI::CastContextHint::None, CostKind, RedOp);
|
||||
if (RedCost.isValid() && RedCost < BaseCost + ExtCost)
|
||||
return I == RetI ? RedCost : 0;
|
||||
} else if (RedOp && RedOp->getOpcode() == Instruction::Mul) {
|
||||
Instruction *Mul = RedOp;
|
||||
Instruction *Op0 = dyn_cast<Instruction>(Mul->getOperand(0));
|
||||
Instruction *Op1 = dyn_cast<Instruction>(Mul->getOperand(1));
|
||||
if (Op0 && Op1 && (isa<SExtInst>(Op0) || isa<ZExtInst>(Op0)) &&
|
||||
} else if (RedOp &&
|
||||
match(RedOp, m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) {
|
||||
if (match(Op0, m_ZExtOrSExt(m_Value())) &&
|
||||
Op0->getOpcode() == Op1->getOpcode() &&
|
||||
Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
|
||||
!TheLoop->isLoopInvariant(Op0) && !TheLoop->isLoopInvariant(Op1)) {
|
||||
bool IsUnsigned = isa<ZExtInst>(Op0);
|
||||
auto *ExtType = VectorType::get(Op0->getOperand(0)->getType(), VectorTy);
|
||||
// reduce(mul(ext, ext))
|
||||
// Matched reduce(mul(ext, ext))
|
||||
InstructionCost ExtCost =
|
||||
TTI.getCastInstrCost(Op0->getOpcode(), VectorTy, ExtType,
|
||||
TTI::CastContextHint::None, CostKind, Op0);
|
||||
InstructionCost MulCost =
|
||||
TTI.getArithmeticInstrCost(Mul->getOpcode(), VectorTy, CostKind);
|
||||
TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
|
||||
|
||||
InstructionCost RedCost = TTI.getExtendedAddReductionCost(
|
||||
/*IsMLA=*/true, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType,
|
||||
@ -7220,8 +7220,9 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
|
||||
if (RedCost.isValid() && RedCost < ExtCost * 2 + MulCost + BaseCost)
|
||||
return I == RetI ? RedCost : 0;
|
||||
} else {
|
||||
// Matched reduce(mul())
|
||||
InstructionCost MulCost =
|
||||
TTI.getArithmeticInstrCost(Mul->getOpcode(), VectorTy, CostKind);
|
||||
TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
|
||||
|
||||
InstructionCost RedCost = TTI.getExtendedAddReductionCost(
|
||||
/*IsMLA=*/true, true, RdxDesc.getRecurrenceType(), VectorTy,
|
||||
|
Loading…
Reference in New Issue
Block a user