diff --git a/include/llvm/IR/PatternMatch.h b/include/llvm/IR/PatternMatch.h index 05fa37b1812..82e7d0af958 100644 --- a/include/llvm/IR/PatternMatch.h +++ b/include/llvm/IR/PatternMatch.h @@ -88,6 +88,11 @@ inline class_match m_Undef() { return class_match(); } /// Match an arbitrary Constant and ignore it. inline class_match m_Constant() { return class_match(); } +/// Match an arbitrary basic block value and ignore it. +inline class_match m_BasicBlock() { + return class_match(); +} + /// Inverting matcher template struct match_unless { Ty M; @@ -563,6 +568,12 @@ inline bind_ty m_Constant(Constant *&C) { return C; } /// Match a ConstantFP, capturing the value if we match. inline bind_ty m_ConstantFP(ConstantFP *&C) { return C; } +/// Match a basic block value, capturing it if we match. +inline bind_ty m_BasicBlock(BasicBlock *&V) { return V; } +inline bind_ty m_BasicBlock(const BasicBlock *&V) { + return V; +} + /// Match a specified Value*. struct specificval_ty { const Value *Val; @@ -656,6 +667,32 @@ inline specific_intval m_SpecificInt(uint64_t V) { return specific_intval(V); } /// ConstantInts wider than 64-bits. inline bind_const_intval_ty m_ConstantInt(uint64_t &V) { return V; } +/// Match a specified basic block value. +struct specific_bbval { + BasicBlock *Val; + + specific_bbval(BasicBlock *Val) : Val(Val) {} + + template bool match(ITy *V) { + const auto *BB = dyn_cast(V); + return BB && BB == Val; + } +}; + +/// Match a specific basic block value. +inline specific_bbval m_SpecificBB(BasicBlock *BB) { + return specific_bbval(BB); +} + +/// A commutative-friendly version of m_Specific(). +inline deferredval_ty m_Deferred(BasicBlock *const &BB) { + return BB; +} +inline deferredval_ty +m_Deferred(const BasicBlock *const &BB) { + return BB; +} + //===----------------------------------------------------------------------===// // Matcher for any binary operator. // @@ -1345,19 +1382,23 @@ struct brc_match { template bool match(OpTy *V) { if (auto *BI = dyn_cast(V)) - if (BI->isConditional() && Cond.match(BI->getCondition())) { - T = BI->getSuccessor(0); - F = BI->getSuccessor(1); - return true; - } + if (BI->isConditional() && Cond.match(BI->getCondition())) + return T.match(BI->getSuccessor(0)) && F.match(BI->getSuccessor(1)); return false; } }; template -inline brc_match +inline brc_match, bind_ty> m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F) { - return brc_match(C, T, F); + return brc_match, bind_ty>( + C, m_BasicBlock(T), m_BasicBlock(F)); +} + +template +inline brc_match +m_Br(const Cond_t &C, const TrueBlock_t &T, const FalseBlock_t &F) { + return brc_match(C, T, F); } //===----------------------------------------------------------------------===// diff --git a/lib/Analysis/ScalarEvolutionExpander.cpp b/lib/Analysis/ScalarEvolutionExpander.cpp index 251717f64e0..8070882c150 100644 --- a/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/lib/Analysis/ScalarEvolutionExpander.cpp @@ -2105,11 +2105,10 @@ SCEVExpander::getRelatedExistingExpansion(const SCEV *S, const Instruction *At, for (BasicBlock *BB : ExitingBlocks) { ICmpInst::Predicate Pred; Instruction *LHS, *RHS; - BasicBlock *TrueBB, *FalseBB; if (!match(BB->getTerminator(), m_Br(m_ICmp(Pred, m_Instruction(LHS), m_Instruction(RHS)), - TrueBB, FalseBB))) + m_BasicBlock(), m_BasicBlock()))) continue; if (SE.getSCEV(LHS) == S && SE.DT.dominates(LHS, At)) diff --git a/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index 5e7587b2264..6e2867f5708 100644 --- a/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -121,14 +121,13 @@ static bool foldGuardedRotateToFunnelShift(Instruction &I) { BasicBlock *GuardBB = Phi.getIncomingBlock(RotSrc == P1); BasicBlock *RotBB = Phi.getIncomingBlock(RotSrc != P1); Instruction *TermI = GuardBB->getTerminator(); - BasicBlock *TrueBB, *FalseBB; ICmpInst::Predicate Pred; - if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(RotAmt), m_ZeroInt()), TrueBB, - FalseBB))) + BasicBlock *PhiBB = Phi.getParent(); + if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(RotAmt), m_ZeroInt()), + m_SpecificBB(PhiBB), m_SpecificBB(RotBB)))) return false; - BasicBlock *PhiBB = Phi.getParent(); - if (Pred != CmpInst::ICMP_EQ || TrueBB != PhiBB || FalseBB != RotBB) + if (Pred != CmpInst::ICMP_EQ) return false; // We matched a variation of this IR pattern: diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp index 41cadb387e5..75774980996 100644 --- a/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -2557,9 +2557,7 @@ Instruction *InstCombiner::visitReturnInst(ReturnInst &RI) { Instruction *InstCombiner::visitBranchInst(BranchInst &BI) { // Change br (not X), label True, label False to: br X, label False, True Value *X = nullptr; - BasicBlock *TrueDest; - BasicBlock *FalseDest; - if (match(&BI, m_Br(m_Not(m_Value(X)), TrueDest, FalseDest)) && + if (match(&BI, m_Br(m_Not(m_Value(X)), m_BasicBlock(), m_BasicBlock())) && !isa(X)) { // Swap Destinations and condition... BI.setCondition(X); @@ -2577,8 +2575,8 @@ Instruction *InstCombiner::visitBranchInst(BranchInst &BI) { // Canonicalize, for example, icmp_ne -> icmp_eq or fcmp_one -> fcmp_oeq. CmpInst::Predicate Pred; - if (match(&BI, m_Br(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), TrueDest, - FalseDest)) && + if (match(&BI, m_Br(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), + m_BasicBlock(), m_BasicBlock())) && !isCanonicalPredicate(Pred)) { // Swap destinations and condition. CmpInst *Cond = cast(BI.getCondition()); diff --git a/unittests/IR/PatternMatch.cpp b/unittests/IR/PatternMatch.cpp index fe8c518fa82..d52d4fe980f 100644 --- a/unittests/IR/PatternMatch.cpp +++ b/unittests/IR/PatternMatch.cpp @@ -1045,6 +1045,34 @@ TEST_F(PatternMatchTest, FloatingPointFNeg) { EXPECT_FALSE(match(V3, m_FNeg(m_Value(Match)))); } +TEST_F(PatternMatchTest, CondBranchTest) { + BasicBlock *TrueBB = BasicBlock::Create(Ctx, "TrueBB", F); + BasicBlock *FalseBB = BasicBlock::Create(Ctx, "FalseBB", F); + Value *Br1 = IRB.CreateCondBr(IRB.getTrue(), TrueBB, FalseBB); + + EXPECT_TRUE(match(Br1, m_Br(m_Value(), m_BasicBlock(), m_BasicBlock()))); + + BasicBlock *A, *B; + EXPECT_TRUE(match(Br1, m_Br(m_Value(), m_BasicBlock(A), m_BasicBlock(B)))); + EXPECT_EQ(TrueBB, A); + EXPECT_EQ(FalseBB, B); + + EXPECT_FALSE( + match(Br1, m_Br(m_Value(), m_SpecificBB(FalseBB), m_BasicBlock()))); + EXPECT_FALSE( + match(Br1, m_Br(m_Value(), m_BasicBlock(), m_SpecificBB(TrueBB)))); + EXPECT_FALSE( + match(Br1, m_Br(m_Value(), m_SpecificBB(FalseBB), m_BasicBlock(TrueBB)))); + EXPECT_TRUE( + match(Br1, m_Br(m_Value(), m_SpecificBB(TrueBB), m_BasicBlock(FalseBB)))); + + // Check we can use m_Deferred with branches. + EXPECT_FALSE(match(Br1, m_Br(m_Value(), m_BasicBlock(A), m_Deferred(A)))); + Value *Br2 = IRB.CreateCondBr(IRB.getTrue(), TrueBB, TrueBB); + A = nullptr; + EXPECT_TRUE(match(Br2, m_Br(m_Value(), m_BasicBlock(A), m_Deferred(A)))); +} + template struct MutableConstTest : PatternMatchTest { }; typedef ::testing::Types,