diff --git a/include/llvm/IR/PatternMatch.h b/include/llvm/IR/PatternMatch.h index 8b1c6b2f350..dcca931b16d 100644 --- a/include/llvm/IR/PatternMatch.h +++ b/include/llvm/IR/PatternMatch.h @@ -489,6 +489,22 @@ struct specificval_ty { /// Match if we have a specific specified value. inline specificval_ty m_Specific(const Value *V) { return V; } +/// Stores a reference to the Value *, not the Value * itself, +/// thus can be used in commutative matchers. +template struct deferredval_ty { + Class *const &Val; + + deferredval_ty(Class *const &V) : Val(V) {} + + template bool match(ITy *const V) { return V == Val; } +}; + +/// A commutative-friendly version of m_Specific(). +inline deferredval_ty m_Deferred(Value *const &V) { return V; } +inline deferredval_ty m_Deferred(const Value *const &V) { + return V; +} + /// Match a specified floating point value or vector of all elements of /// that value. struct specific_fpval { @@ -562,13 +578,15 @@ struct AnyBinaryOp_match { LHS_t L; RHS_t R; + // The evaluation order is always stable, regardless of Commutability. + // The LHS is always matched first. AnyBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} template bool match(OpTy *V) { if (auto *I = dyn_cast(V)) return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && R.match(I->getOperand(0)) && - L.match(I->getOperand(1))); + (Commutable && L.match(I->getOperand(1)) && + R.match(I->getOperand(0))); return false; } }; @@ -588,20 +606,22 @@ struct BinaryOp_match { LHS_t L; RHS_t R; + // The evaluation order is always stable, regardless of Commutability. + // The LHS is always matched first. BinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} template bool match(OpTy *V) { if (V->getValueID() == Value::InstructionVal + Opcode) { auto *I = cast(V); return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && R.match(I->getOperand(0)) && - L.match(I->getOperand(1))); + (Commutable && L.match(I->getOperand(1)) && + R.match(I->getOperand(0))); } if (auto *CE = dyn_cast(V)) return CE->getOpcode() == Opcode && ((L.match(CE->getOperand(0)) && R.match(CE->getOperand(1))) || - (Commutable && R.match(CE->getOperand(0)) && - L.match(CE->getOperand(1)))); + (Commutable && L.match(CE->getOperand(1)) && + R.match(CE->getOperand(0)))); return false; } }; @@ -926,14 +946,16 @@ struct CmpClass_match { LHS_t L; RHS_t R; + // The evaluation order is always stable, regardless of Commutability. + // The LHS is always matched first. CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) : Predicate(Pred), L(LHS), R(RHS) {} template bool match(OpTy *V) { if (auto *I = dyn_cast(V)) if ((L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && R.match(I->getOperand(0)) && - L.match(I->getOperand(1)))) { + (Commutable && L.match(I->getOperand(1)) && + R.match(I->getOperand(0)))) { Predicate = I->getPredicate(); return true; } @@ -1251,6 +1273,8 @@ struct MaxMin_match { LHS_t L; RHS_t R; + // The evaluation order is always stable, regardless of Commutability. + // The LHS is always matched first. MaxMin_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} template bool match(OpTy *V) { @@ -1277,7 +1301,7 @@ struct MaxMin_match { return false; // It does! Bind the operands. return (L.match(LHS) && R.match(RHS)) || - (Commutable && R.match(LHS) && L.match(RHS)); + (Commutable && L.match(RHS) && R.match(LHS)); } }; diff --git a/lib/Analysis/ValueTracking.cpp b/lib/Analysis/ValueTracking.cpp index ac29cde01fa..a799f7c05b2 100644 --- a/lib/Analysis/ValueTracking.cpp +++ b/lib/Analysis/ValueTracking.cpp @@ -982,12 +982,9 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, // matching the form add(x, add(x, y)) where y is odd. // TODO: This could be generalized to clearing any bit set in y where the // following bit is known to be unset in y. - Value *Y = nullptr; + Value *X = nullptr, *Y = nullptr; if (!Known.Zero[0] && !Known.One[0] && - (match(I->getOperand(0), m_Add(m_Specific(I->getOperand(1)), - m_Value(Y))) || - match(I->getOperand(1), m_Add(m_Specific(I->getOperand(0)), - m_Value(Y))))) { + match(I, m_c_BinOp(m_Value(X), m_Add(m_Deferred(X), m_Value(Y))))) { Known2.resetAll(); computeKnownBits(Y, Known2, Depth + 1, Q); if (Known2.countMinTrailingOnes() > 0) diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index b7e115efea1..274bde0da89 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1363,26 +1363,15 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } // (add (xor A, B) (and A, B)) --> (or A, B) - if (match(LHS, m_Xor(m_Value(A), m_Value(B))) && - match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) - return BinaryOperator::CreateOr(A, B); - // (add (and A, B) (xor A, B)) --> (or A, B) - if (match(RHS, m_Xor(m_Value(A), m_Value(B))) && - match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) + if (match(&I, m_c_BinOp(m_Xor(m_Value(A), m_Value(B)), + m_c_And(m_Deferred(A), m_Deferred(B))))) return BinaryOperator::CreateOr(A, B); // (add (or A, B) (and A, B)) --> (add A, B) - if (match(LHS, m_Or(m_Value(A), m_Value(B))) && - match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) { - I.setOperand(0, A); - I.setOperand(1, B); - return &I; - } - // (add (and A, B) (or A, B)) --> (add A, B) - if (match(RHS, m_Or(m_Value(A), m_Value(B))) && - match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) { + if (match(&I, m_c_BinOp(m_Or(m_Value(A), m_Value(B)), + m_c_And(m_Deferred(A), m_Deferred(B))))) { I.setOperand(0, A); I.setOperand(1, B); return &I; diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 1d90d2363aa..f4aec9112e7 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -1288,8 +1288,8 @@ static Instruction *foldAndToXor(BinaryOperator &I, // Operand complexity canonicalization guarantees that the 'or' is Op0. // (A | B) & ~(A & B) --> A ^ B // (A | B) & ~(B & A) --> A ^ B - if (match(Op0, m_Or(m_Value(A), m_Value(B))) && - match(Op1, m_Not(m_c_And(m_Specific(A), m_Specific(B))))) + if (match(&I, m_BinOp(m_Or(m_Value(A), m_Value(B)), + m_Not(m_c_And(m_Deferred(A), m_Deferred(B)))))) return BinaryOperator::CreateXor(A, B); // (A | ~B) & (~A | B) --> ~(A ^ B) @@ -1297,8 +1297,8 @@ static Instruction *foldAndToXor(BinaryOperator &I, // (~B | A) & (~A | B) --> ~(A ^ B) // (~B | A) & (B | ~A) --> ~(A ^ B) if (Op0->hasOneUse() || Op1->hasOneUse()) - if (match(Op0, m_c_Or(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_c_Or(m_Not(m_Specific(A)), m_Specific(B)))) + if (match(&I, m_BinOp(m_c_Or(m_Value(A), m_Not(m_Value(B))), + m_c_Or(m_Not(m_Deferred(A)), m_Deferred(B))))) return BinaryOperator::CreateNot(Builder.CreateXor(A, B)); return nullptr; @@ -2294,10 +2294,8 @@ static Instruction *foldXorToXor(BinaryOperator &I, // (A & B) ^ (B | A) -> A ^ B // (A | B) ^ (A & B) -> A ^ B // (A | B) ^ (B & A) -> A ^ B - if ((match(Op0, m_And(m_Value(A), m_Value(B))) && - match(Op1, m_c_Or(m_Specific(A), m_Specific(B)))) || - (match(Op0, m_Or(m_Value(A), m_Value(B))) && - match(Op1, m_c_And(m_Specific(A), m_Specific(B))))) { + if (match(&I, m_c_Xor(m_And(m_Value(A), m_Value(B)), + m_c_Or(m_Deferred(A), m_Deferred(B))))) { I.setOperand(0, A); I.setOperand(1, B); return &I; @@ -2307,10 +2305,8 @@ static Instruction *foldXorToXor(BinaryOperator &I, // (~B | A) ^ (~A | B) -> A ^ B // (~A | B) ^ (A | ~B) -> A ^ B // (B | ~A) ^ (A | ~B) -> A ^ B - if ((match(Op0, m_Or(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_c_Or(m_Not(m_Specific(A)), m_Specific(B)))) || - (match(Op0, m_Or(m_Not(m_Value(A)), m_Value(B))) && - match(Op1, m_c_Or(m_Specific(A), m_Not(m_Specific(B)))))) { + if (match(&I, m_Xor(m_c_Or(m_Value(A), m_Not(m_Value(B))), + m_c_Or(m_Not(m_Deferred(A)), m_Deferred(B))))) { I.setOperand(0, A); I.setOperand(1, B); return &I; @@ -2320,10 +2316,8 @@ static Instruction *foldXorToXor(BinaryOperator &I, // (~B & A) ^ (~A & B) -> A ^ B // (~A & B) ^ (A & ~B) -> A ^ B // (B & ~A) ^ (A & ~B) -> A ^ B - if ((match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_c_And(m_Not(m_Specific(A)), m_Specific(B)))) || - (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) && - match(Op1, m_c_And(m_Specific(A), m_Not(m_Specific(B)))))) { + if (match(&I, m_Xor(m_c_And(m_Value(A), m_Not(m_Value(B))), + m_c_And(m_Not(m_Deferred(A)), m_Deferred(B))))) { I.setOperand(0, A); I.setOperand(1, B); return &I; diff --git a/test/Transforms/InstCombine/and-or-not.ll b/test/Transforms/InstCombine/and-or-not.ll index 2d3fb425d3e..7bd4ad7b3bb 100644 --- a/test/Transforms/InstCombine/and-or-not.ll +++ b/test/Transforms/InstCombine/and-or-not.ll @@ -333,7 +333,7 @@ define i32 @xor_to_xor3(i32 %a, i32 %b) { define i32 @xor_to_xor4(i32 %a, i32 %b) { ; CHECK-LABEL: @xor_to_xor4( -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[B:%.*]], [[A:%.*]] ; CHECK-NEXT: ret i32 [[XOR]] ; %or = or i32 %a, %b @@ -389,7 +389,7 @@ define i32 @xor_to_xor7(float %fa, float %fb) { ; CHECK-LABEL: @xor_to_xor7( ; CHECK-NEXT: [[A:%.*]] = fptosi float [[FA:%.*]] to i32 ; CHECK-NEXT: [[B:%.*]] = fptosi float [[FB:%.*]] to i32 -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[A]], [[B]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[B]], [[A]] ; CHECK-NEXT: ret i32 [[XOR]] ; %a = fptosi float %fa to i32 @@ -408,7 +408,7 @@ define i32 @xor_to_xor8(float %fa, float %fb) { ; CHECK-LABEL: @xor_to_xor8( ; CHECK-NEXT: [[A:%.*]] = fptosi float [[FA:%.*]] to i32 ; CHECK-NEXT: [[B:%.*]] = fptosi float [[FB:%.*]] to i32 -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[A]], [[B]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[B]], [[A]] ; CHECK-NEXT: ret i32 [[XOR]] ; %a = fptosi float %fa to i32 @@ -465,7 +465,7 @@ define i32 @xor_to_xor11(float %fa, float %fb) { ; CHECK-LABEL: @xor_to_xor11( ; CHECK-NEXT: [[A:%.*]] = fptosi float [[FA:%.*]] to i32 ; CHECK-NEXT: [[B:%.*]] = fptosi float [[FB:%.*]] to i32 -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[A]], [[B]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[B]], [[A]] ; CHECK-NEXT: ret i32 [[XOR]] ; %a = fptosi float %fa to i32 @@ -484,7 +484,7 @@ define i32 @xor_to_xor12(float %fa, float %fb) { ; CHECK-LABEL: @xor_to_xor12( ; CHECK-NEXT: [[A:%.*]] = fptosi float [[FA:%.*]] to i32 ; CHECK-NEXT: [[B:%.*]] = fptosi float [[FB:%.*]] to i32 -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[A]], [[B]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[B]], [[A]] ; CHECK-NEXT: ret i32 [[XOR]] ; %a = fptosi float %fa to i32 diff --git a/test/Transforms/InstCombine/or-xor.ll b/test/Transforms/InstCombine/or-xor.ll index 7d163d58600..ab5f2f8ef53 100644 --- a/test/Transforms/InstCombine/or-xor.ll +++ b/test/Transforms/InstCombine/or-xor.ll @@ -188,7 +188,7 @@ define i32 @test13(i32 %x, i32 %y) { ; ((x | ~y) ^ (~x | y)) -> x ^ y define i32 @test14(i32 %x, i32 %y) { ; CHECK-LABEL: @test14( -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: ret i32 [[XOR]] ; %noty = xor i32 %y, -1 @@ -201,7 +201,7 @@ define i32 @test14(i32 %x, i32 %y) { define i32 @test14_commuted(i32 %x, i32 %y) { ; CHECK-LABEL: @test14_commuted( -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: ret i32 [[XOR]] ; %noty = xor i32 %y, -1 @@ -215,7 +215,7 @@ define i32 @test14_commuted(i32 %x, i32 %y) { ; ((x & ~y) ^ (~x & y)) -> x ^ y define i32 @test15(i32 %x, i32 %y) { ; CHECK-LABEL: @test15( -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: ret i32 [[XOR]] ; %noty = xor i32 %y, -1 @@ -228,7 +228,7 @@ define i32 @test15(i32 %x, i32 %y) { define i32 @test15_commuted(i32 %x, i32 %y) { ; CHECK-LABEL: @test15_commuted( -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: ret i32 [[XOR]] ; %noty = xor i32 %y, -1 diff --git a/unittests/IR/PatternMatch.cpp b/unittests/IR/PatternMatch.cpp index 6bdcf4de7e2..2c6da22298f 100644 --- a/unittests/IR/PatternMatch.cpp +++ b/unittests/IR/PatternMatch.cpp @@ -65,6 +65,56 @@ TEST_F(PatternMatchTest, OneUse) { EXPECT_FALSE(m_OneUse(m_Value()).match(Leaf)); } +TEST_F(PatternMatchTest, CommutativeDeferredValue) { + Value *X = IRB.getInt32(1); + Value *Y = IRB.getInt32(2); + + { + Value *tX = X; + EXPECT_TRUE(match(X, m_Deferred(tX))); + EXPECT_FALSE(match(Y, m_Deferred(tX))); + } + { + const Value *tX = X; + EXPECT_TRUE(match(X, m_Deferred(tX))); + EXPECT_FALSE(match(Y, m_Deferred(tX))); + } + { + Value *const tX = X; + EXPECT_TRUE(match(X, m_Deferred(tX))); + EXPECT_FALSE(match(Y, m_Deferred(tX))); + } + { + const Value *const tX = X; + EXPECT_TRUE(match(X, m_Deferred(tX))); + EXPECT_FALSE(match(Y, m_Deferred(tX))); + } + + { + Value *tX = nullptr; + EXPECT_TRUE(match(IRB.CreateAnd(X, X), m_And(m_Value(tX), m_Deferred(tX)))); + EXPECT_EQ(tX, X); + } + { + Value *tX = nullptr; + EXPECT_FALSE( + match(IRB.CreateAnd(X, Y), m_c_And(m_Value(tX), m_Deferred(tX)))); + } + + auto checkMatch = [X, Y](Value *Pattern) { + Value *tX = nullptr, *tY = nullptr; + EXPECT_TRUE(match( + Pattern, m_c_And(m_Value(tX), m_c_And(m_Deferred(tX), m_Value(tY))))); + EXPECT_EQ(tX, X); + EXPECT_EQ(tY, Y); + }; + + checkMatch(IRB.CreateAnd(X, IRB.CreateAnd(X, Y))); + checkMatch(IRB.CreateAnd(X, IRB.CreateAnd(Y, X))); + checkMatch(IRB.CreateAnd(IRB.CreateAnd(X, Y), X)); + checkMatch(IRB.CreateAnd(IRB.CreateAnd(Y, X), X)); +} + TEST_F(PatternMatchTest, FloatingPointOrderedMin) { Type *FltTy = IRB.getFloatTy(); Value *L = ConstantFP::get(FltTy, 1.0);