1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-22 02:33:06 +01:00

[LV] Make use of PatternMatchers in getReductionPatternCost. NFC

Pulled out of D106166, this modifies getReductionPatternCost to use
PatternMatchers, hopefully simplifying the code a little.
This commit is contained in:
David Green 2021-07-21 11:34:30 +01:00
parent fd3020376e
commit ddd62877c6

View File

@ -7127,6 +7127,7 @@ LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I,
Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost( Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
Instruction *I, ElementCount VF, Type *Ty, TTI::TargetCostKind CostKind) { Instruction *I, ElementCount VF, Type *Ty, TTI::TargetCostKind CostKind) {
using namespace llvm::PatternMatch;
// Early exit for no inloop reductions // Early exit for no inloop reductions
if (InLoopReductionChains.empty() || VF.isScalar() || !isa<VectorType>(Ty)) if (InLoopReductionChains.empty() || VF.isScalar() || !isa<VectorType>(Ty))
return None; return None;
@ -7145,13 +7146,12 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
// it is not we return an invalid cost specifying the orignal cost method // it is not we return an invalid cost specifying the orignal cost method
// should be used. // should be used.
Instruction *RetI = I; Instruction *RetI = I;
if ((RetI->getOpcode() == Instruction::SExt || if (match(RetI, m_ZExtOrSExt(m_Value()))) {
RetI->getOpcode() == Instruction::ZExt)) {
if (!RetI->hasOneUser()) if (!RetI->hasOneUser())
return None; return None;
RetI = RetI->user_back(); RetI = RetI->user_back();
} }
if (RetI->getOpcode() == Instruction::Mul && if (match(RetI, m_Mul(m_Value(), m_Value())) &&
RetI->user_back()->getOpcode() == Instruction::Add) { RetI->user_back()->getOpcode() == Instruction::Add) {
if (!RetI->hasOneUser()) if (!RetI->hasOneUser())
return None; return None;
@ -7183,8 +7183,10 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
VectorTy = VectorType::get(I->getOperand(0)->getType(), VectorTy); VectorTy = VectorType::get(I->getOperand(0)->getType(), VectorTy);
if (RedOp && (isa<SExtInst>(RedOp) || isa<ZExtInst>(RedOp)) && Instruction *Op0, *Op1;
if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value())) &&
!TheLoop->isLoopInvariant(RedOp)) { !TheLoop->isLoopInvariant(RedOp)) {
// Matched reduce(ext(A))
bool IsUnsigned = isa<ZExtInst>(RedOp); bool IsUnsigned = isa<ZExtInst>(RedOp);
auto *ExtType = VectorType::get(RedOp->getOperand(0)->getType(), VectorTy); auto *ExtType = VectorType::get(RedOp->getOperand(0)->getType(), VectorTy);
InstructionCost RedCost = TTI.getExtendedAddReductionCost( InstructionCost RedCost = TTI.getExtendedAddReductionCost(
@ -7196,22 +7198,20 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
TTI::CastContextHint::None, CostKind, RedOp); TTI::CastContextHint::None, CostKind, RedOp);
if (RedCost.isValid() && RedCost < BaseCost + ExtCost) if (RedCost.isValid() && RedCost < BaseCost + ExtCost)
return I == RetI ? RedCost : 0; return I == RetI ? RedCost : 0;
} else if (RedOp && RedOp->getOpcode() == Instruction::Mul) { } else if (RedOp &&
Instruction *Mul = RedOp; match(RedOp, m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) {
Instruction *Op0 = dyn_cast<Instruction>(Mul->getOperand(0)); if (match(Op0, m_ZExtOrSExt(m_Value())) &&
Instruction *Op1 = dyn_cast<Instruction>(Mul->getOperand(1));
if (Op0 && Op1 && (isa<SExtInst>(Op0) || isa<ZExtInst>(Op0)) &&
Op0->getOpcode() == Op1->getOpcode() && Op0->getOpcode() == Op1->getOpcode() &&
Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() && Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
!TheLoop->isLoopInvariant(Op0) && !TheLoop->isLoopInvariant(Op1)) { !TheLoop->isLoopInvariant(Op0) && !TheLoop->isLoopInvariant(Op1)) {
bool IsUnsigned = isa<ZExtInst>(Op0); bool IsUnsigned = isa<ZExtInst>(Op0);
auto *ExtType = VectorType::get(Op0->getOperand(0)->getType(), VectorTy); auto *ExtType = VectorType::get(Op0->getOperand(0)->getType(), VectorTy);
// reduce(mul(ext, ext)) // Matched reduce(mul(ext, ext))
InstructionCost ExtCost = InstructionCost ExtCost =
TTI.getCastInstrCost(Op0->getOpcode(), VectorTy, ExtType, TTI.getCastInstrCost(Op0->getOpcode(), VectorTy, ExtType,
TTI::CastContextHint::None, CostKind, Op0); TTI::CastContextHint::None, CostKind, Op0);
InstructionCost MulCost = InstructionCost MulCost =
TTI.getArithmeticInstrCost(Mul->getOpcode(), VectorTy, CostKind); TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
InstructionCost RedCost = TTI.getExtendedAddReductionCost( InstructionCost RedCost = TTI.getExtendedAddReductionCost(
/*IsMLA=*/true, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, /*IsMLA=*/true, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType,
@ -7220,8 +7220,9 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
if (RedCost.isValid() && RedCost < ExtCost * 2 + MulCost + BaseCost) if (RedCost.isValid() && RedCost < ExtCost * 2 + MulCost + BaseCost)
return I == RetI ? RedCost : 0; return I == RetI ? RedCost : 0;
} else { } else {
// Matched reduce(mul())
InstructionCost MulCost = InstructionCost MulCost =
TTI.getArithmeticInstrCost(Mul->getOpcode(), VectorTy, CostKind); TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
InstructionCost RedCost = TTI.getExtendedAddReductionCost( InstructionCost RedCost = TTI.getExtendedAddReductionCost(
/*IsMLA=*/true, true, RdxDesc.getRecurrenceType(), VectorTy, /*IsMLA=*/true, true, RdxDesc.getRecurrenceType(), VectorTy,