mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-11-25 12:12:47 +01:00
[LoopFlatten] Use SCEV and Loop APIs to identify increment and trip count
Replace pattern-matching with existing SCEV and Loop APIs as a more robust way of identifying the loop increment and trip count. Also rename 'Limit' as 'TripCount' to be consistent with terminology. Differential Revision: https://reviews.llvm.org/D106580
This commit is contained in:
parent
ee2584c4c2
commit
565fcd6a48
@ -63,7 +63,7 @@ static cl::opt<bool>
|
|||||||
AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden,
|
AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden,
|
||||||
cl::init(false),
|
cl::init(false),
|
||||||
cl::desc("Assume that the product of the two iteration "
|
cl::desc("Assume that the product of the two iteration "
|
||||||
"limits will never overflow"));
|
"trip counts will never overflow"));
|
||||||
|
|
||||||
static cl::opt<bool>
|
static cl::opt<bool>
|
||||||
WidenIV("loop-flatten-widen-iv", cl::Hidden,
|
WidenIV("loop-flatten-widen-iv", cl::Hidden,
|
||||||
@ -74,10 +74,12 @@ static cl::opt<bool>
|
|||||||
struct FlattenInfo {
|
struct FlattenInfo {
|
||||||
Loop *OuterLoop = nullptr;
|
Loop *OuterLoop = nullptr;
|
||||||
Loop *InnerLoop = nullptr;
|
Loop *InnerLoop = nullptr;
|
||||||
|
// These PHINodes correspond to loop induction variables, which are expected
|
||||||
|
// to start at zero and increment by one on each loop.
|
||||||
PHINode *InnerInductionPHI = nullptr;
|
PHINode *InnerInductionPHI = nullptr;
|
||||||
PHINode *OuterInductionPHI = nullptr;
|
PHINode *OuterInductionPHI = nullptr;
|
||||||
Value *InnerLimit = nullptr;
|
Value *InnerTripCount = nullptr;
|
||||||
Value *OuterLimit = nullptr;
|
Value *OuterTripCount = nullptr;
|
||||||
BinaryOperator *InnerIncrement = nullptr;
|
BinaryOperator *InnerIncrement = nullptr;
|
||||||
BinaryOperator *OuterIncrement = nullptr;
|
BinaryOperator *OuterIncrement = nullptr;
|
||||||
BranchInst *InnerBranch = nullptr;
|
BranchInst *InnerBranch = nullptr;
|
||||||
@ -91,12 +93,12 @@ struct FlattenInfo {
|
|||||||
FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {};
|
FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {};
|
||||||
};
|
};
|
||||||
|
|
||||||
// Finds the induction variable, increment and limit for a simple loop that we
|
// Finds the induction variable, increment and trip count for a simple loop that
|
||||||
// can flatten.
|
// we can flatten.
|
||||||
static bool findLoopComponents(
|
static bool findLoopComponents(
|
||||||
Loop *L, SmallPtrSetImpl<Instruction *> &IterationInstructions,
|
Loop *L, SmallPtrSetImpl<Instruction *> &IterationInstructions,
|
||||||
PHINode *&InductionPHI, Value *&Limit, BinaryOperator *&Increment,
|
PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment,
|
||||||
BranchInst *&BackBranch, ScalarEvolution *SE) {
|
BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) {
|
||||||
LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n");
|
LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n");
|
||||||
|
|
||||||
if (!L->isLoopSimplifyForm()) {
|
if (!L->isLoopSimplifyForm()) {
|
||||||
@ -104,6 +106,13 @@ static bool findLoopComponents(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Currently, to simplify the implementation, the Loop induction variable must
|
||||||
|
// start at zero and increment with a step size of one.
|
||||||
|
if (!L->isCanonical(*SE)) {
|
||||||
|
LLVM_DEBUG(dbgs() << "Loop is not canonical\n");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// There must be exactly one exiting block, and it must be the same as the
|
// There must be exactly one exiting block, and it must be the same as the
|
||||||
// latch.
|
// latch.
|
||||||
BasicBlock *Latch = L->getLoopLatch();
|
BasicBlock *Latch = L->getLoopLatch();
|
||||||
@ -144,40 +153,44 @@ static bool findLoopComponents(
|
|||||||
IterationInstructions.insert(Compare);
|
IterationInstructions.insert(Compare);
|
||||||
LLVM_DEBUG(dbgs() << "Found comparison: "; Compare->dump());
|
LLVM_DEBUG(dbgs() << "Found comparison: "; Compare->dump());
|
||||||
|
|
||||||
// Find increment and limit from the compare
|
// Find increment and trip count.
|
||||||
Increment = nullptr;
|
// There are exactly 2 incoming values to the induction phi; one from the
|
||||||
if (match(Compare->getOperand(0),
|
// pre-header and one from the latch. The incoming latch value is the
|
||||||
m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) {
|
// increment variable.
|
||||||
Increment = dyn_cast<BinaryOperator>(Compare->getOperand(0));
|
Increment =
|
||||||
Limit = Compare->getOperand(1);
|
dyn_cast<BinaryOperator>(InductionPHI->getIncomingValueForBlock(Latch));
|
||||||
} else if (Compare->getUnsignedPredicate() == CmpInst::ICMP_NE &&
|
if (Increment->hasNUsesOrMore(3)) {
|
||||||
match(Compare->getOperand(1),
|
LLVM_DEBUG(dbgs() << "Could not find valid increment\n");
|
||||||
m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) {
|
|
||||||
Increment = dyn_cast<BinaryOperator>(Compare->getOperand(1));
|
|
||||||
Limit = Compare->getOperand(0);
|
|
||||||
}
|
|
||||||
if (!Increment || Increment->hasNUsesOrMore(3)) {
|
|
||||||
LLVM_DEBUG(dbgs() << "Cound not find valid increment\n");
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
// The trip count is the RHS of the compare. If this doesn't match the trip
|
||||||
|
// count computed by SCEV then this is either because the trip count variable
|
||||||
|
// has been widened (then leave the trip count as it is), or because it is a
|
||||||
|
// constant and another transformation has changed the compare, e.g.
|
||||||
|
// icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, then we don't flatten
|
||||||
|
// the loop (yet).
|
||||||
|
TripCount = Compare->getOperand(1);
|
||||||
|
const SCEV *SCEVTripCount =
|
||||||
|
SE->getTripCountFromExitCount(SE->getBackedgeTakenCount(L));
|
||||||
|
if (SE->getSCEV(TripCount) != SCEVTripCount) {
|
||||||
|
if (!IsWidened) {
|
||||||
|
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto TripCountInst = dyn_cast<Instruction>(TripCount);
|
||||||
|
if (!TripCountInst) {
|
||||||
|
LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
|
||||||
|
SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
|
||||||
|
LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
IterationInstructions.insert(Increment);
|
IterationInstructions.insert(Increment);
|
||||||
LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump());
|
LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump());
|
||||||
LLVM_DEBUG(dbgs() << "Found limit: "; Limit->dump());
|
LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump());
|
||||||
|
|
||||||
assert(InductionPHI->getNumIncomingValues() == 2);
|
|
||||||
|
|
||||||
if (InductionPHI->getIncomingValueForBlock(Latch) != Increment) {
|
|
||||||
LLVM_DEBUG(
|
|
||||||
dbgs() << "Incoming value from latch is not the increment inst\n");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto *CI = dyn_cast<ConstantInt>(
|
|
||||||
InductionPHI->getIncomingValueForBlock(L->getLoopPreheader()));
|
|
||||||
if (!CI || !CI->isZero()) {
|
|
||||||
LLVM_DEBUG(dbgs() << "PHI value is not zero: "; CI->dump());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
|
LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
|
||||||
return true;
|
return true;
|
||||||
@ -300,7 +313,7 @@ checkOuterLoopInsts(FlattenInfo &FI,
|
|||||||
// Multiplies of the outer iteration variable and inner iteration
|
// Multiplies of the outer iteration variable and inner iteration
|
||||||
// count will be optimised out.
|
// count will be optimised out.
|
||||||
if (match(&I, m_c_Mul(m_Specific(FI.OuterInductionPHI),
|
if (match(&I, m_c_Mul(m_Specific(FI.OuterInductionPHI),
|
||||||
m_Specific(FI.InnerLimit))))
|
m_Specific(FI.InnerTripCount))))
|
||||||
continue;
|
continue;
|
||||||
InstructionCost Cost =
|
InstructionCost Cost =
|
||||||
TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
|
TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
|
||||||
@ -325,16 +338,16 @@ checkOuterLoopInsts(FlattenInfo &FI,
|
|||||||
static bool checkIVUsers(FlattenInfo &FI) {
|
static bool checkIVUsers(FlattenInfo &FI) {
|
||||||
// We require all uses of both induction variables to match this pattern:
|
// We require all uses of both induction variables to match this pattern:
|
||||||
//
|
//
|
||||||
// (OuterPHI * InnerLimit) + InnerPHI
|
// (OuterPHI * InnerTripCount) + InnerPHI
|
||||||
//
|
//
|
||||||
// Any uses of the induction variables not matching that pattern would
|
// Any uses of the induction variables not matching that pattern would
|
||||||
// require a div/mod to reconstruct in the flattened loop, so the
|
// require a div/mod to reconstruct in the flattened loop, so the
|
||||||
// transformation wouldn't be profitable.
|
// transformation wouldn't be profitable.
|
||||||
|
|
||||||
Value *InnerLimit = FI.InnerLimit;
|
Value *InnerTripCount = FI.InnerTripCount;
|
||||||
if (FI.Widened &&
|
if (FI.Widened &&
|
||||||
(isa<SExtInst>(InnerLimit) || isa<ZExtInst>(InnerLimit)))
|
(isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
|
||||||
InnerLimit = cast<Instruction>(InnerLimit)->getOperand(0);
|
InnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);
|
||||||
|
|
||||||
// Check that all uses of the inner loop's induction variable match the
|
// Check that all uses of the inner loop's induction variable match the
|
||||||
// expected pattern, recording the uses of the outer IV.
|
// expected pattern, recording the uses of the outer IV.
|
||||||
@ -368,7 +381,7 @@ static bool checkIVUsers(FlattenInfo &FI) {
|
|||||||
m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)),
|
m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)),
|
||||||
m_Value(MatchedItCount)));
|
m_Value(MatchedItCount)));
|
||||||
|
|
||||||
if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerLimit) {
|
if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
|
||||||
LLVM_DEBUG(dbgs() << "Use is optimisable\n");
|
LLVM_DEBUG(dbgs() << "Use is optimisable\n");
|
||||||
ValidOuterPHIUses.insert(MatchedMul);
|
ValidOuterPHIUses.insert(MatchedMul);
|
||||||
FI.LinearIVUses.insert(U);
|
FI.LinearIVUses.insert(U);
|
||||||
@ -417,7 +430,7 @@ static bool checkIVUsers(FlattenInfo &FI) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return an OverflowResult dependent on whether overflow of the multiplication of
|
// Return an OverflowResult dependent on whether overflow of the multiplication of
|
||||||
// InnerLimit and OuterLimit can be assumed not to happen.
|
// InnerTripCount and OuterTripCount can be assumed not to happen.
|
||||||
static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
|
static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
|
||||||
AssumptionCache *AC) {
|
AssumptionCache *AC) {
|
||||||
Function *F = FI.OuterLoop->getHeader()->getParent();
|
Function *F = FI.OuterLoop->getHeader()->getParent();
|
||||||
@ -430,7 +443,7 @@ static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
|
|||||||
// Check if the multiply could not overflow due to known ranges of the
|
// Check if the multiply could not overflow due to known ranges of the
|
||||||
// input values.
|
// input values.
|
||||||
OverflowResult OR = computeOverflowForUnsignedMul(
|
OverflowResult OR = computeOverflowForUnsignedMul(
|
||||||
FI.InnerLimit, FI.OuterLimit, DL, AC,
|
FI.InnerTripCount, FI.OuterTripCount, DL, AC,
|
||||||
FI.OuterLoop->getLoopPreheader()->getTerminator(), DT);
|
FI.OuterLoop->getLoopPreheader()->getTerminator(), DT);
|
||||||
if (OR != OverflowResult::MayOverflow)
|
if (OR != OverflowResult::MayOverflow)
|
||||||
return OR;
|
return OR;
|
||||||
@ -461,21 +474,23 @@ static bool CanFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
|
|||||||
ScalarEvolution *SE, AssumptionCache *AC,
|
ScalarEvolution *SE, AssumptionCache *AC,
|
||||||
const TargetTransformInfo *TTI) {
|
const TargetTransformInfo *TTI) {
|
||||||
SmallPtrSet<Instruction *, 8> IterationInstructions;
|
SmallPtrSet<Instruction *, 8> IterationInstructions;
|
||||||
if (!findLoopComponents(FI.InnerLoop, IterationInstructions, FI.InnerInductionPHI,
|
if (!findLoopComponents(FI.InnerLoop, IterationInstructions,
|
||||||
FI.InnerLimit, FI.InnerIncrement, FI.InnerBranch, SE))
|
FI.InnerInductionPHI, FI.InnerTripCount,
|
||||||
|
FI.InnerIncrement, FI.InnerBranch, SE, FI.Widened))
|
||||||
return false;
|
return false;
|
||||||
if (!findLoopComponents(FI.OuterLoop, IterationInstructions, FI.OuterInductionPHI,
|
if (!findLoopComponents(FI.OuterLoop, IterationInstructions,
|
||||||
FI.OuterLimit, FI.OuterIncrement, FI.OuterBranch, SE))
|
FI.OuterInductionPHI, FI.OuterTripCount,
|
||||||
|
FI.OuterIncrement, FI.OuterBranch, SE, FI.Widened))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// Both of the loop limit values must be invariant in the outer loop
|
// Both of the loop trip count values must be invariant in the outer loop
|
||||||
// (non-instructions are all inherently invariant).
|
// (non-instructions are all inherently invariant).
|
||||||
if (!FI.OuterLoop->isLoopInvariant(FI.InnerLimit)) {
|
if (!FI.OuterLoop->isLoopInvariant(FI.InnerTripCount)) {
|
||||||
LLVM_DEBUG(dbgs() << "inner loop limit not invariant\n");
|
LLVM_DEBUG(dbgs() << "inner loop trip count not invariant\n");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!FI.OuterLoop->isLoopInvariant(FI.OuterLimit)) {
|
if (!FI.OuterLoop->isLoopInvariant(FI.OuterTripCount)) {
|
||||||
LLVM_DEBUG(dbgs() << "outer loop limit not invariant\n");
|
LLVM_DEBUG(dbgs() << "outer loop trip count not invariant\n");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -515,8 +530,8 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
|
|||||||
ORE.emit(Remark);
|
ORE.emit(Remark);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value *NewTripCount =
|
Value *NewTripCount = BinaryOperator::CreateMul(
|
||||||
BinaryOperator::CreateMul(FI.InnerLimit, FI.OuterLimit, "flatten.tripcount",
|
FI.InnerTripCount, FI.OuterTripCount, "flatten.tripcount",
|
||||||
FI.OuterLoop->getLoopPreheader()->getTerminator());
|
FI.OuterLoop->getLoopPreheader()->getTerminator());
|
||||||
LLVM_DEBUG(dbgs() << "Created new trip count in preheader: ";
|
LLVM_DEBUG(dbgs() << "Created new trip count in preheader: ";
|
||||||
NewTripCount->dump());
|
NewTripCount->dump());
|
||||||
@ -581,7 +596,7 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
|
|||||||
|
|
||||||
// If both induction types are less than the maximum legal integer width,
|
// If both induction types are less than the maximum legal integer width,
|
||||||
// promote both to the widest type available so we know calculating
|
// promote both to the widest type available so we know calculating
|
||||||
// (OuterLimit * InnerLimit) as the new trip count is safe.
|
// (OuterTripCount * InnerTripCount) as the new trip count is safe.
|
||||||
if (InnerType != OuterType ||
|
if (InnerType != OuterType ||
|
||||||
InnerType->getScalarSizeInBits() >= MaxLegalSize ||
|
InnerType->getScalarSizeInBits() >= MaxLegalSize ||
|
||||||
MaxLegalType->getScalarSizeInBits() < InnerType->getScalarSizeInBits() * 2) {
|
MaxLegalType->getScalarSizeInBits() < InnerType->getScalarSizeInBits() * 2) {
|
||||||
|
@ -341,6 +341,37 @@ for.end8: ; preds = %for.inc6
|
|||||||
ret i32 10
|
ret i32 10
|
||||||
}
|
}
|
||||||
|
|
||||||
|
; When the loop trip count is a constant (e.g. 20) and the step size is
|
||||||
|
; 1, InstCombine causes the transformation icmp ult i32 %inc, 20 ->
|
||||||
|
; icmp ult i32 %j, 19. In this case a valid trip count is not found so
|
||||||
|
; the loop is not flattened.
|
||||||
|
define i32 @test9(i32* nocapture %A) {
|
||||||
|
entry:
|
||||||
|
br label %for.cond1.preheader
|
||||||
|
|
||||||
|
for.cond1.preheader:
|
||||||
|
%i.017 = phi i32 [ 0, %entry ], [ %inc6, %for.cond.cleanup3 ]
|
||||||
|
%mul = mul i32 %i.017, 20
|
||||||
|
br label %for.body4
|
||||||
|
|
||||||
|
for.cond.cleanup3:
|
||||||
|
%inc6 = add i32 %i.017, 1
|
||||||
|
%cmp = icmp ult i32 %inc6, 11
|
||||||
|
br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup
|
||||||
|
|
||||||
|
for.body4:
|
||||||
|
%j.016 = phi i32 [ 0, %for.cond1.preheader ], [ %inc, %for.body4 ]
|
||||||
|
%add = add i32 %j.016, %mul
|
||||||
|
%arrayidx = getelementptr inbounds i32, i32* %A, i32 %add
|
||||||
|
store i32 30, i32* %arrayidx, align 4
|
||||||
|
%inc = add nuw nsw i32 %j.016, 1
|
||||||
|
%cmp2 = icmp ult i32 %j.016, 19
|
||||||
|
br i1 %cmp2, label %for.body4, label %for.cond.cleanup3
|
||||||
|
|
||||||
|
for.cond.cleanup:
|
||||||
|
%0 = load i32, i32* %A, align 4
|
||||||
|
ret i32 %0
|
||||||
|
}
|
||||||
|
|
||||||
; Outer loop conditional phi
|
; Outer loop conditional phi
|
||||||
define i32 @e() {
|
define i32 @e() {
|
||||||
|
Loading…
Reference in New Issue
Block a user