From ddd62877c61700701c215a2f1d11a4a7bb1a4b57 Mon Sep 17 00:00:00 2001 From: David Green Date: Wed, 21 Jul 2021 11:34:30 +0100 Subject: [PATCH] [LV] Make use of PatternMatchers in getReductionPatternCost. NFC Pulled out of D106166, this modifies getReductionPatternCost to use PatternMatchers, hopefully simplifying the code a little. --- lib/Transforms/Vectorize/LoopVectorize.cpp | 25 +++++++++++----------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/lib/Transforms/Vectorize/LoopVectorize.cpp b/lib/Transforms/Vectorize/LoopVectorize.cpp index 4d87897fc23..001f1e45e48 100644 --- a/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -7127,6 +7127,7 @@ LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I, Optional LoopVectorizationCostModel::getReductionPatternCost( Instruction *I, ElementCount VF, Type *Ty, TTI::TargetCostKind CostKind) { + using namespace llvm::PatternMatch; // Early exit for no inloop reductions if (InLoopReductionChains.empty() || VF.isScalar() || !isa(Ty)) return None; @@ -7145,13 +7146,12 @@ Optional LoopVectorizationCostModel::getReductionPatternCost( // it is not we return an invalid cost specifying the orignal cost method // should be used. Instruction *RetI = I; - if ((RetI->getOpcode() == Instruction::SExt || - RetI->getOpcode() == Instruction::ZExt)) { + if (match(RetI, m_ZExtOrSExt(m_Value()))) { if (!RetI->hasOneUser()) return None; RetI = RetI->user_back(); } - if (RetI->getOpcode() == Instruction::Mul && + if (match(RetI, m_Mul(m_Value(), m_Value())) && RetI->user_back()->getOpcode() == Instruction::Add) { if (!RetI->hasOneUser()) return None; @@ -7183,8 +7183,10 @@ Optional LoopVectorizationCostModel::getReductionPatternCost( VectorTy = VectorType::get(I->getOperand(0)->getType(), VectorTy); - if (RedOp && (isa(RedOp) || isa(RedOp)) && + Instruction *Op0, *Op1; + if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value())) && !TheLoop->isLoopInvariant(RedOp)) { + // Matched reduce(ext(A)) bool IsUnsigned = isa(RedOp); auto *ExtType = VectorType::get(RedOp->getOperand(0)->getType(), VectorTy); InstructionCost RedCost = TTI.getExtendedAddReductionCost( @@ -7196,22 +7198,20 @@ Optional LoopVectorizationCostModel::getReductionPatternCost( TTI::CastContextHint::None, CostKind, RedOp); if (RedCost.isValid() && RedCost < BaseCost + ExtCost) return I == RetI ? RedCost : 0; - } else if (RedOp && RedOp->getOpcode() == Instruction::Mul) { - Instruction *Mul = RedOp; - Instruction *Op0 = dyn_cast(Mul->getOperand(0)); - Instruction *Op1 = dyn_cast(Mul->getOperand(1)); - if (Op0 && Op1 && (isa(Op0) || isa(Op0)) && + } else if (RedOp && + match(RedOp, m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) { + if (match(Op0, m_ZExtOrSExt(m_Value())) && Op0->getOpcode() == Op1->getOpcode() && Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() && !TheLoop->isLoopInvariant(Op0) && !TheLoop->isLoopInvariant(Op1)) { bool IsUnsigned = isa(Op0); auto *ExtType = VectorType::get(Op0->getOperand(0)->getType(), VectorTy); - // reduce(mul(ext, ext)) + // Matched reduce(mul(ext, ext)) InstructionCost ExtCost = TTI.getCastInstrCost(Op0->getOpcode(), VectorTy, ExtType, TTI::CastContextHint::None, CostKind, Op0); InstructionCost MulCost = - TTI.getArithmeticInstrCost(Mul->getOpcode(), VectorTy, CostKind); + TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind); InstructionCost RedCost = TTI.getExtendedAddReductionCost( /*IsMLA=*/true, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, @@ -7220,8 +7220,9 @@ Optional LoopVectorizationCostModel::getReductionPatternCost( if (RedCost.isValid() && RedCost < ExtCost * 2 + MulCost + BaseCost) return I == RetI ? RedCost : 0; } else { + // Matched reduce(mul()) InstructionCost MulCost = - TTI.getArithmeticInstrCost(Mul->getOpcode(), VectorTy, CostKind); + TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind); InstructionCost RedCost = TTI.getExtendedAddReductionCost( /*IsMLA=*/true, true, RdxDesc.getRecurrenceType(), VectorTy,