From e5ca2592d42147231b91e8f3b6ee35767c8b4124 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Fri, 7 May 2021 15:36:55 +0000 Subject: [PATCH] [LoopNest] Consider loop nest with inner loop guard using outer loop induction variable to be perfect This patch allow more conditional branches to be considered as loop guard, and so more loop nests can be considered perfect. Reviewed By: bmahjour, sidbav Differential Revision: https://reviews.llvm.org/D94717 --- include/llvm/Analysis/LoopNestAnalysis.h | 8 +- lib/Analysis/LoopInfo.cpp | 17 +- lib/Analysis/LoopNestAnalysis.cpp | 16 +- .../LoopNestAnalysis/imperfectnest.ll | 67 -------- test/Analysis/LoopNestAnalysis/perfectnest.ll | 145 ++++++++++++++++++ unittests/Analysis/LoopInfoTest.cpp | 48 ++++++ 6 files changed, 220 insertions(+), 81 deletions(-) diff --git a/include/llvm/Analysis/LoopNestAnalysis.h b/include/llvm/Analysis/LoopNestAnalysis.h index ace17547444..e045419f8d5 100644 --- a/include/llvm/Analysis/LoopNestAnalysis.h +++ b/include/llvm/Analysis/LoopNestAnalysis.h @@ -61,10 +61,12 @@ public: static unsigned getMaxPerfectDepth(const Loop &Root, ScalarEvolution &SE); /// Recursivelly traverse all empty 'single successor' basic blocks of \p From - /// (if there are any). Return the last basic block found or \p End if it was - /// reached during the search. + /// (if there are any). When \p CheckUniquePred is set to true, check if + /// each of the empty single successors has a unique predecessor. Return + /// the last basic block found or \p End if it was reached during the search. static const BasicBlock &skipEmptyBlockUntil(const BasicBlock *From, - const BasicBlock *End); + const BasicBlock *End, + bool CheckUniquePred = false); /// Return the outermost loop in the loop nest. Loop &getOutermostLoop() const { return *Loops.front(); } diff --git a/lib/Analysis/LoopInfo.cpp b/lib/Analysis/LoopInfo.cpp index adb2bdb184c..b2d7edb3566 100644 --- a/lib/Analysis/LoopInfo.cpp +++ b/lib/Analysis/LoopInfo.cpp @@ -20,6 +20,7 @@ #include "llvm/Analysis/IVDescriptors.h" #include "llvm/Analysis/LoopInfoImpl.h" #include "llvm/Analysis/LoopIterator.h" +#include "llvm/Analysis/LoopNestAnalysis.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" @@ -380,10 +381,6 @@ BranchInst *Loop::getLoopGuardBranch() const { if (!ExitFromLatch) return nullptr; - BasicBlock *ExitFromLatchSucc = ExitFromLatch->getUniqueSuccessor(); - if (!ExitFromLatchSucc) - return nullptr; - BasicBlock *GuardBB = Preheader->getUniquePredecessor(); if (!GuardBB) return nullptr; @@ -397,7 +394,17 @@ BranchInst *Loop::getLoopGuardBranch() const { BasicBlock *GuardOtherSucc = (GuardBI->getSuccessor(0) == Preheader) ? GuardBI->getSuccessor(1) : GuardBI->getSuccessor(0); - return (GuardOtherSucc == ExitFromLatchSucc) ? GuardBI : nullptr; + + // Check if ExitFromLatch (or any BasicBlock which is an empty unique + // successor of ExitFromLatch) is equal to GuardOtherSucc. If + // skipEmptyBlockUntil returns GuardOtherSucc, then the guard branch for the + // loop is GuardBI (return GuardBI), otherwise return nullptr. + if (&LoopNest::skipEmptyBlockUntil(ExitFromLatch, GuardOtherSucc, + /*CheckUniquePred=*/true) == + GuardOtherSucc) + return GuardBI; + else + return nullptr; } bool Loop::isCanonical(ScalarEvolution &SE) const { diff --git a/lib/Analysis/LoopNestAnalysis.cpp b/lib/Analysis/LoopNestAnalysis.cpp index ee74d4b0d04..2649ed60f76 100644 --- a/lib/Analysis/LoopNestAnalysis.cpp +++ b/lib/Analysis/LoopNestAnalysis.cpp @@ -206,7 +206,8 @@ unsigned LoopNest::getMaxPerfectDepth(const Loop &Root, ScalarEvolution &SE) { } const BasicBlock &LoopNest::skipEmptyBlockUntil(const BasicBlock *From, - const BasicBlock *End) { + const BasicBlock *End, + bool CheckUniquePred) { assert(From && "Expecting valid From"); assert(End && "Expecting valid End"); @@ -220,8 +221,9 @@ const BasicBlock &LoopNest::skipEmptyBlockUntil(const BasicBlock *From, // Visited is used to avoid running into an infinite loop. SmallPtrSet Visited; const BasicBlock *BB = From->getUniqueSuccessor(); - const BasicBlock *PredBB = BB; - while (BB && BB != End && IsEmpty(BB) && !Visited.count(BB)) { + const BasicBlock *PredBB = From; + while (BB && BB != End && IsEmpty(BB) && !Visited.count(BB) && + (!CheckUniquePred || BB->getUniquePredecessor())) { Visited.insert(BB); PredBB = BB; BB = BB->getUniqueSuccessor(); @@ -335,9 +337,11 @@ static bool checkLoopsStructure(const Loop &OuterLoop, const Loop &InnerLoop, // Ensure the inner loop exit block lead to the outer loop latch possibly // through empty blocks. - const BasicBlock &SuccInner = - LoopNest::skipEmptyBlockUntil(InnerLoop.getExitBlock(), OuterLoopLatch); - if (&SuccInner != OuterLoopLatch && &SuccInner != ExtraPhiBlock) { + if ((!ExtraPhiBlock || + &LoopNest::skipEmptyBlockUntil(InnerLoop.getExitBlock(), + ExtraPhiBlock) != ExtraPhiBlock) && + (&LoopNest::skipEmptyBlockUntil(InnerLoop.getExitBlock(), + OuterLoopLatch) != OuterLoopLatch)) { DEBUG_WITH_TYPE( VerboseDebug, dbgs() << "Inner loop exit block " << *InnerLoopExit diff --git a/test/Analysis/LoopNestAnalysis/imperfectnest.ll b/test/Analysis/LoopNestAnalysis/imperfectnest.ll index 4c8066ec587..77b361bc6ba 100644 --- a/test/Analysis/LoopNestAnalysis/imperfectnest.ll +++ b/test/Analysis/LoopNestAnalysis/imperfectnest.ll @@ -424,70 +424,3 @@ for.cond.for.end13_crit_edge: for.end13: ret void } - -; Test an imperfect loop nest of the form: -; for (int i = 0; i < nx; ++i) -; if (i > 5) { // user branch -; for (int j = 1; j <= 5; j+=2) -; y[j][i] = x[i][j] + j; -; } - -define void @imperf_nest_6(i32** %y, i32** %x, i32 signext %nx, i32 signext %ny) { -; CHECK-LABEL: IsPerfect=false, Depth=2, OutermostLoop: imperf_nest_6_loop_i, Loops: ( imperf_nest_6_loop_i imperf_nest_6_loop_j ) -entry: - %cmp2 = icmp slt i32 0, %nx - br i1 %cmp2, label %imperf_nest_6_loop_i.lr.ph, label %for.end13 - -imperf_nest_6_loop_i.lr.ph: - br label %imperf_nest_6_loop_i - -imperf_nest_6_loop_i: - %i.0 = phi i32 [ 0, %imperf_nest_6_loop_i.lr.ph ], [ %inc12, %for.inc11 ] - %cmp1 = icmp sgt i32 %i.0, 5 - br i1 %cmp1, label %imperf_nest_6_loop_j.lr.ph, label %if.end - -imperf_nest_6_loop_j.lr.ph: - br label %imperf_nest_6_loop_j - -imperf_nest_6_loop_j: - %j.0 = phi i32 [ 1, %imperf_nest_6_loop_j.lr.ph ], [ %inc, %for.inc ] - %idxprom = sext i32 %i.0 to i64 - %arrayidx = getelementptr inbounds i32*, i32** %x, i64 %idxprom - %0 = load i32*, i32** %arrayidx, align 8 - %idxprom5 = sext i32 %j.0 to i64 - %arrayidx6 = getelementptr inbounds i32, i32* %0, i64 %idxprom5 - %1 = load i32, i32* %arrayidx6, align 4 - %add = add nsw i32 %1, %j.0 - %idxprom7 = sext i32 %j.0 to i64 - %arrayidx8 = getelementptr inbounds i32*, i32** %y, i64 %idxprom7 - %2 = load i32*, i32** %arrayidx8, align 8 - %idxprom9 = sext i32 %i.0 to i64 - %arrayidx10 = getelementptr inbounds i32, i32* %2, i64 %idxprom9 - store i32 %add, i32* %arrayidx10, align 4 - br label %for.inc - -for.inc: - %inc = add nsw i32 %j.0, 2 - %cmp3 = icmp sle i32 %inc, 5 - br i1 %cmp3, label %imperf_nest_6_loop_j, label %for.cond2.for.end_crit_edge - -for.cond2.for.end_crit_edge: - br label %for.end - -for.end: - br label %if.end - -if.end: - br label %for.inc11 - -for.inc11: - %inc12 = add nsw i32 %i.0, 1 - %cmp = icmp slt i32 %inc12, %nx - br i1 %cmp, label %imperf_nest_6_loop_i, label %for.cond.for.end13_crit_edge - -for.cond.for.end13_crit_edge: - br label %for.end13 - -for.end13: - ret void -} diff --git a/test/Analysis/LoopNestAnalysis/perfectnest.ll b/test/Analysis/LoopNestAnalysis/perfectnest.ll index 7593d6f1748..f8b0e6ad2c8 100644 --- a/test/Analysis/LoopNestAnalysis/perfectnest.ll +++ b/test/Analysis/LoopNestAnalysis/perfectnest.ll @@ -322,3 +322,148 @@ for.end7: %x.addr.0.lcssa = phi i32 [ %split7, %for.cond.for.end7_crit_edge ], [ %x, %entry ] ret i32 %x.addr.0.lcssa } + +; Test a perfect loop nest of the form: +; for (int i = 0; i < nx; ++i) +; if (i < ny) { // guard branch for the j-loop +; for (int j=i; j < ny; j+=1) +; y[j][i] = x[i][j] + j; +; } +define double @perf_nest_guard_branch(i32** %y, i32** %x, i32 signext %nx, i32 signext %ny) { +; CHECK-LABEL: IsPerfect=true, Depth=1, OutermostLoop: test6Loop2, Loops: ( test6Loop2 ) +; CHECK-LABEL: IsPerfect=true, Depth=2, OutermostLoop: test6Loop1, Loops: ( test6Loop1 test6Loop2 ) +entry: + %cmp2 = icmp slt i32 0, %nx + br i1 %cmp2, label %test6Loop1.lr.ph, label %for.end13 + +test6Loop1.lr.ph: ; preds = %entry + br label %test6Loop1 + +test6Loop1: ; preds = %test6Loop1.lr.ph, %for.inc11 + %i.0 = phi i32 [ 0, %test6Loop1.lr.ph ], [ %inc12, %for.inc11 ] + %cmp1 = icmp slt i32 %i.0, %ny + br i1 %cmp1, label %test6Loop2.lr.ph, label %if.end + +test6Loop2.lr.ph: ; preds = %if.then + br label %test6Loop2 + +test6Loop2: ; preds = %test6Loop2.lr.ph, %for.inc + %j.0 = phi i32 [ %i.0, %test6Loop2.lr.ph ], [ %inc, %for.inc ] + %idxprom = sext i32 %i.0 to i64 + %arrayidx = getelementptr inbounds i32*, i32** %x, i64 %idxprom + %0 = load i32*, i32** %arrayidx, align 8 + %idxprom5 = sext i32 %j.0 to i64 + %arrayidx6 = getelementptr inbounds i32, i32* %0, i64 %idxprom5 + %1 = load i32, i32* %arrayidx6, align 4 + %add = add nsw i32 %1, %j.0 + %idxprom7 = sext i32 %j.0 to i64 + %arrayidx8 = getelementptr inbounds i32*, i32** %y, i64 %idxprom7 + %2 = load i32*, i32** %arrayidx8, align 8 + %idxprom9 = sext i32 %i.0 to i64 + %arrayidx10 = getelementptr inbounds i32, i32* %2, i64 %idxprom9 + store i32 %add, i32* %arrayidx10, align 4 + br label %for.inc + +for.inc: ; preds = %test6Loop2 + %inc = add nsw i32 %j.0, 1 + %cmp3 = icmp slt i32 %inc, %ny + br i1 %cmp3, label %test6Loop2, label %for.cond2.for.end_crit_edge + +for.cond2.for.end_crit_edge: ; preds = %for.inc + br label %for.end + +for.end: ; preds = %for.cond2.for.end_crit_edge, %if.then + br label %if.end + +if.end: ; preds = %for.end, %test6Loop1 + br label %for.inc11 + +for.inc11: ; preds = %if.end + %inc12 = add nsw i32 %i.0, 1 + %cmp = icmp slt i32 %inc12, %nx + br i1 %cmp, label %test6Loop1, label %for.cond.for.end13_crit_edge + +for.cond.for.end13_crit_edge: ; preds = %for.inc11 + br label %for.end13 + +for.end13: ; preds = %for.cond.for.end13_crit_edge, %entry + %arrayidx14 = getelementptr inbounds i32*, i32** %y, i64 0 + %3 = load i32*, i32** %arrayidx14, align 8 + %arrayidx15 = getelementptr inbounds i32, i32* %3, i64 0 + %4 = load i32, i32* %arrayidx15, align 4 + %conv = sitofp i32 %4 to double + ret double %conv +} + +; Test a perfect loop nest of the form: +; for (int i = 0; i < nx; ++i) +; if (i < ny) { // guard branch for the j-loop +; for (int j=i; j < ny; j+=1) +; y[j][i] = x[i][j] + j; +; } + +define double @test6(i32** %y, i32** %x, i32 signext %nx, i32 signext %ny) { +; CHECK-LABEL: IsPerfect=true, Depth=1, OutermostLoop: test6Loop2, Loops: ( test6Loop2 ) +; CHECK-LABEL: IsPerfect=true, Depth=2, OutermostLoop: test6Loop1, Loops: ( test6Loop1 test6Loop2 ) +entry: + %cmp2 = icmp slt i32 0, %nx + br i1 %cmp2, label %test6Loop1.lr.ph, label %for.end13 + +test6Loop1.lr.ph: ; preds = %entry + br label %test6Loop1 + +test6Loop1: ; preds = %test6Loop1.lr.ph, %for.inc11 + %i.0 = phi i32 [ 0, %test6Loop1.lr.ph ], [ %inc12, %for.inc11 ] + %cmp1 = icmp slt i32 %i.0, %ny + br i1 %cmp1, label %test6Loop2.lr.ph, label %if.end + +test6Loop2.lr.ph: ; preds = %if.then + br label %test6Loop2 + +test6Loop2: ; preds = %test6Loop2.lr.ph, %for.inc + %j.0 = phi i32 [ %i.0, %test6Loop2.lr.ph ], [ %inc, %for.inc ] + %idxprom = sext i32 %i.0 to i64 + %arrayidx = getelementptr inbounds i32*, i32** %x, i64 %idxprom + %0 = load i32*, i32** %arrayidx, align 8 + %idxprom5 = sext i32 %j.0 to i64 + %arrayidx6 = getelementptr inbounds i32, i32* %0, i64 %idxprom5 + %1 = load i32, i32* %arrayidx6, align 4 + %add = add nsw i32 %1, %j.0 + %idxprom7 = sext i32 %j.0 to i64 + %arrayidx8 = getelementptr inbounds i32*, i32** %y, i64 %idxprom7 + %2 = load i32*, i32** %arrayidx8, align 8 + %idxprom9 = sext i32 %i.0 to i64 + %arrayidx10 = getelementptr inbounds i32, i32* %2, i64 %idxprom9 + store i32 %add, i32* %arrayidx10, align 4 + br label %for.inc + +for.inc: ; preds = %test6Loop2 + %inc = add nsw i32 %j.0, 1 + %cmp3 = icmp slt i32 %inc, %ny + br i1 %cmp3, label %test6Loop2, label %for.cond2.for.end_crit_edge + +for.cond2.for.end_crit_edge: ; preds = %for.inc + br label %for.end + +for.end: ; preds = %for.cond2.for.end_crit_edge, %if.then + br label %if.end + +if.end: ; preds = %for.end, %test6Loop1 + br label %for.inc11 + +for.inc11: ; preds = %if.end + %inc12 = add nsw i32 %i.0, 1 + %cmp = icmp slt i32 %inc12, %nx + br i1 %cmp, label %test6Loop1, label %for.cond.for.end13_crit_edge + +for.cond.for.end13_crit_edge: ; preds = %for.inc11 + br label %for.end13 + +for.end13: ; preds = %for.cond.for.end13_crit_edge, %entry + %arrayidx14 = getelementptr inbounds i32*, i32** %y, i64 0 + %3 = load i32*, i32** %arrayidx14, align 8 + %arrayidx15 = getelementptr inbounds i32, i32* %3, i64 0 + %4 = load i32, i32* %arrayidx15, align 4 + %conv = sitofp i32 %4 to double + ret double %conv +} diff --git a/unittests/Analysis/LoopInfoTest.cpp b/unittests/Analysis/LoopInfoTest.cpp index bb518904e81..db6484f6928 100644 --- a/unittests/Analysis/LoopInfoTest.cpp +++ b/unittests/Analysis/LoopInfoTest.cpp @@ -1500,3 +1500,51 @@ TEST(LoopInfoTest, LoopNotRotated) { EXPECT_FALSE(L->isRotatedForm()); }); } + +TEST(LoopInfoTest, LoopUserBranch) { + const char *ModuleStr = + "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n" + "define void @foo(i32* %B, i64 signext %nx, i1 %cond) {\n" + "entry:\n" + " br i1 %cond, label %bb, label %guard\n" + "guard:\n" + " %cmp.guard = icmp slt i64 0, %nx\n" + " br i1 %cmp.guard, label %for.i.preheader, label %for.end\n" + "for.i.preheader:\n" + " br label %for.i\n" + "for.i:\n" + " %i = phi i64 [ 0, %for.i.preheader ], [ %inc13, %for.i ]\n" + " %Bi = getelementptr inbounds i32, i32* %B, i64 %i\n" + " store i32 0, i32* %Bi, align 4\n" + " %inc13 = add nsw i64 %i, 1\n" + " %cmp = icmp slt i64 %inc13, %nx\n" + " br i1 %cmp, label %for.i, label %for.i.exit\n" + "for.i.exit:\n" + " br label %bb\n" + "bb:\n" + " br label %for.end\n" + "for.end:\n" + " ret void\n" + "}\n"; + + // Parse the module. + LLVMContext Context; + std::unique_ptr M = makeLLVMModule(Context, ModuleStr); + + runWithLoopInfo(*M, "foo", [&](Function &F, LoopInfo &LI) { + Function::iterator FI = F.begin(); + FI = ++FI; + BasicBlock *Guard = &*FI; + assert(Guard->getName() == "guard"); + + FI = ++FI; + BasicBlock *Header = &*(++FI); + assert(Header->getName() == "for.i"); + + Loop *L = LI.getLoopFor(Header); + EXPECT_NE(L, nullptr); + + // L should not have a guard branch + EXPECT_EQ(L->getLoopGuardBranch(), nullptr); + }); +}