From 4583103460a26a760de65945b7d7a716c768620b Mon Sep 17 00:00:00 2001 From: Roman Lebedev Date: Sun, 20 Oct 2019 19:38:50 +0000 Subject: [PATCH] [InstCombine] Shift amount reassociation in shifty sign bit test (PR43595) Summary: This problem consists of several parts: * Basic sign bit extraction - `trunc? (?shr %x, (bitwidth(x)-1))`. This is trivial, and easy to do, we have a fold for it. * Shift amount reassociation - if we have two identical shifts, and we can simplify-add their shift amounts together, then we likely can just perform them as a single shift. But this is finicky, has one-use restrictions, and shift opcodes must be identical. But there is a super-pattern where both of these work together. to produce sign bit test from two shifts + comparison. We do indeed already handle this in most cases. But since we get that fold transitively, it has one-use restrictions. And what's worse, in this case the right-shifts aren't required to be identical, and we can't handle that transitively: If the total shift amount is bitwidth-1, only a sign bit will remain in the output value. But if we look at this from the perspective of two shifts, we can't fold - we can't possibly know what bit pattern we'd produce via two shifts, it will be *some* kind of a mask produced from original sign bit, but we just can't tell it's shape: https://rise4fun.com/Alive/cM0 https://rise4fun.com/Alive/9IN But it will *only* contain sign bit and zeros. So from the perspective of sign bit test, we're good: https://rise4fun.com/Alive/FRz https://rise4fun.com/Alive/qBU Superb! So the simplest solution is to extend `reassociateShiftAmtsOfTwoSameDirectionShifts()` to also have a sudo-analysis mode that will ignore extra-uses, and will only check whether a) those are two right shifts and b) they end up with bitwidth(x)-1 shift amount and return either the original value that we sign-checking, or null. This does not have any functionality change for the existing `reassociateShiftAmtsOfTwoSameDirectionShifts()`. All that being said, as disscussed in the review, this yet again increases usage of instsimplify in instcombine as utility. Some day that may need to be reevaluated. https://bugs.llvm.org/show_bug.cgi?id=43595 Reviewers: spatel, efriedma, vsk Reviewed By: spatel Subscribers: xbolva00, hiraditya, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D68930 llvm-svn: 375371 --- .../InstCombine/InstCombineCompares.cpp | 29 ++++++---- .../InstCombine/InstCombineInternal.h | 4 ++ .../InstCombine/InstCombineShifts.cpp | 53 +++++++++++++------ ...-test-via-right-shifting-all-other-bits.ll | 12 ++--- 4 files changed, 66 insertions(+), 32 deletions(-) diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index ee51bc03312..601295339ad 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1358,19 +1358,28 @@ Instruction *InstCombiner::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) { /// Fold equality-comparison between zero and any (maybe truncated) right-shift /// by one-less-than-bitwidth into a sign test on the original value. -Instruction *foldSignBitTest(ICmpInst &I) { +Instruction *InstCombiner::foldSignBitTest(ICmpInst &I) { + Instruction *Val; ICmpInst::Predicate Pred; - Value *X; - Constant *C; - if (!I.isEquality() || - !match(&I, m_ICmp(Pred, m_TruncOrSelf(m_Shr(m_Value(X), m_Constant(C))), - m_Zero()))) + if (!I.isEquality() || !match(&I, m_ICmp(Pred, m_Instruction(Val), m_Zero()))) return nullptr; - Type *XTy = X->getType(); - unsigned XBitWidth = XTy->getScalarSizeInBits(); - if (!match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, - APInt(XBitWidth, XBitWidth - 1)))) + Value *X; + Type *XTy; + + Constant *C; + if (match(Val, m_TruncOrSelf(m_Shr(m_Value(X), m_Constant(C))))) { + XTy = X->getType(); + unsigned XBitWidth = XTy->getScalarSizeInBits(); + if (!match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, + APInt(XBitWidth, XBitWidth - 1)))) + return nullptr; + } else if (isa(Val) && + (X = reassociateShiftAmtsOfTwoSameDirectionShifts( + cast(Val), SQ.getWithInstruction(Val), + /*AnalyzeForSignBitExtraction=*/true))) { + XTy = X->getType(); + } else return nullptr; return ICmpInst::Create(Instruction::ICmp, diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index e04cd346b6f..4519dc0bf37 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -390,6 +390,9 @@ public: Instruction *visitOr(BinaryOperator &I); Instruction *visitXor(BinaryOperator &I); Instruction *visitShl(BinaryOperator &I); + Value *reassociateShiftAmtsOfTwoSameDirectionShifts( + BinaryOperator *Sh0, const SimplifyQuery &SQ, + bool AnalyzeForSignBitExtraction = false); Instruction *foldVariableSignZeroExtensionOfVariableHighBitExtract( BinaryOperator &OldAShr); Instruction *visitAShr(BinaryOperator &I); @@ -912,6 +915,7 @@ private: Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ); Instruction *foldICmpEquality(ICmpInst &Cmp); Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I); + Instruction *foldSignBitTest(ICmpInst &I); Instruction *foldICmpWithZero(ICmpInst &Cmp); Value *foldUnsignedMultiplicationOverflowCheck(ICmpInst &Cmp); diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index cc0e35e4a9c..64294838644 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -25,10 +25,12 @@ using namespace PatternMatch; // we should rewrite it as // x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) // This is valid for any shift, but they must be identical. -static Instruction * -reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0, - const SimplifyQuery &SQ, - InstCombiner::BuilderTy &Builder) { +// +// AnalyzeForSignBitExtraction indicates that we will only analyze whether this +// pattern has any 2 right-shifts that sum to 1 less than original bit width. +Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts( + BinaryOperator *Sh0, const SimplifyQuery &SQ, + bool AnalyzeForSignBitExtraction) { // Look for a shift of some instruction, ignore zext of shift amount if any. Instruction *Sh0Op0; Value *ShAmt0; @@ -56,14 +58,25 @@ reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0, if (ShAmt0->getType() != ShAmt1->getType()) return nullptr; - // The shift opcodes must be identical. + // We are only looking for signbit extraction if we have two right shifts. + bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) && + match(Sh1, m_Shr(m_Value(), m_Value())); + // ... and if it's not two right-shifts, we know the answer already. + if (AnalyzeForSignBitExtraction && !HadTwoRightShifts) + return nullptr; + + // The shift opcodes must be identical, unless we are just checking whether + // this pattern can be interpreted as a sign-bit-extraction. Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode(); - if (ShiftOpcode != Sh1->getOpcode()) + bool IdenticalShOpcodes = Sh0->getOpcode() == Sh1->getOpcode(); + if (!IdenticalShOpcodes && !AnalyzeForSignBitExtraction) return nullptr; // If we saw truncation, we'll need to produce extra instruction, - // and for that one of the operands of the shift must be one-use. - if (Trunc && !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value()))) + // and for that one of the operands of the shift must be one-use, + // unless of course we don't actually plan to produce any instructions here. + if (Trunc && !AnalyzeForSignBitExtraction && + !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value()))) return nullptr; // Can we fold (ShAmt0+ShAmt1) ? @@ -80,14 +93,22 @@ reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0, return nullptr; // FIXME: could perform constant-folding. // If there was a truncation, and we have a right-shift, we can only fold if - // we are left with the original sign bit. + // we are left with the original sign bit. Likewise, if we were just checking + // that this is a sighbit extraction, this is the place to check it. // FIXME: zero shift amount is also legal here, but we can't *easily* check // more than one predicate so it's not really worth it. - if (Trunc && ShiftOpcode != Instruction::BinaryOps::Shl && - !match(NewShAmt, - m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, - APInt(NewShAmtBitWidth, XBitWidth - 1)))) - return nullptr; + if (HadTwoRightShifts && (Trunc || AnalyzeForSignBitExtraction)) { + // If it's not a sign bit extraction, then we're done. + if (!match(NewShAmt, + m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, + APInt(NewShAmtBitWidth, XBitWidth - 1)))) + return nullptr; + // If it is, and that was the question, return the base value. + if (AnalyzeForSignBitExtraction) + return X; + } + + assert(IdenticalShOpcodes && "Should not get here with different shifts."); // All good, we can do this fold. NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType()); @@ -287,8 +308,8 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I)) return Res; - if (Instruction *NewShift = - reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ, Builder)) + if (auto *NewShift = cast_or_null( + reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ))) return NewShift; // (C1 shift (A add C2)) -> (C1 shift C2) shift A) diff --git a/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll b/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll index 0e5848e7303..8e89a0649eb 100644 --- a/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll +++ b/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll @@ -45,7 +45,7 @@ define i1 @highest_bit_test_via_lshr_with_truncation(i64 %data, i32 %nbits) { ; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED_NARROW]]) ; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]]) ; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]]) -; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0 +; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i64 [[DATA]], 0 ; CHECK-NEXT: ret i1 [[ISNEG]] ; %num_low_bits_to_skip = sub i32 64, %nbits @@ -107,7 +107,7 @@ define i1 @highest_bit_test_via_ashr_with_truncation(i64 %data, i32 %nbits) { ; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED_NARROW]]) ; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]]) ; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]]) -; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0 +; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i64 [[DATA]], 0 ; CHECK-NEXT: ret i1 [[ISNEG]] ; %num_low_bits_to_skip = sub i32 64, %nbits @@ -138,7 +138,7 @@ define i1 @highest_bit_test_via_lshr_ashr(i32 %data, i32 %nbits) { ; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED]]) ; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]]) ; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]]) -; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0 +; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i32 [[DATA]], 0 ; CHECK-NEXT: ret i1 [[ISNEG]] ; %num_low_bits_to_skip = sub i32 32, %nbits @@ -169,7 +169,7 @@ define i1 @highest_bit_test_via_lshr_ashe_with_truncation(i64 %data, i32 %nbits) ; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED_NARROW]]) ; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]]) ; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]]) -; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0 +; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i64 [[DATA]], 0 ; CHECK-NEXT: ret i1 [[ISNEG]] ; %num_low_bits_to_skip = sub i32 64, %nbits @@ -200,7 +200,7 @@ define i1 @highest_bit_test_via_ashr_lshr(i32 %data, i32 %nbits) { ; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED]]) ; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]]) ; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]]) -; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0 +; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i32 [[DATA]], 0 ; CHECK-NEXT: ret i1 [[ISNEG]] ; %num_low_bits_to_skip = sub i32 32, %nbits @@ -231,7 +231,7 @@ define i1 @highest_bit_test_via_ashr_lshr_with_truncation(i64 %data, i32 %nbits) ; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED_NARROW]]) ; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]]) ; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]]) -; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0 +; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i64 [[DATA]], 0 ; CHECK-NEXT: ret i1 [[ISNEG]] ; %num_low_bits_to_skip = sub i32 64, %nbits