diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 2364202e5b6..4b3a58126d5 100644
--- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1148,6 +1148,73 @@ static Instruction *foldOrToXor(BinaryOperator &I,
   return nullptr;
 }
 
+/// Return true if a constant shift amount is always less than the specified
+/// bit-width. If not, the shift could create poison in the narrower type.
+static bool canNarrowShiftAmt(Constant *C, unsigned BitWidth) {
+  if (auto *ScalarC = dyn_cast<ConstantInt>(C))
+    return ScalarC->getZExtValue() < BitWidth;
+
+  if (C->getType()->isVectorTy()) {
+    // Check each element of a constant vector.
+    unsigned NumElts = C->getType()->getVectorNumElements();
+    for (unsigned i = 0; i != NumElts; ++i) {
+      Constant *Elt = C->getAggregateElement(i);
+      if (!Elt)
+        return false;
+      if (isa<UndefValue>(Elt))
+        continue;
+      auto *CI = dyn_cast<ConstantInt>(Elt);
+      if (!CI || CI->getZExtValue() >= BitWidth)
+        return false;
+    }
+    return true;
+  }
+
+  // The constant is a constant expression or unknown.
+  return false;
+}
+
+/// Try to use narrower ops (sink zext ops) for an 'and' with binop operand and
+/// a common zext operand: and (binop (zext X), C), (zext X).
+Instruction *InstCombiner::narrowMaskedBinOp(BinaryOperator &And) {
+  // This transform could also apply to {or, and, xor}, but there are better
+  // folds for those cases, so we don't expect those patterns here. AShr is not
+  // handled because it should always be transformed to LShr in this sequence.
+  // The subtract transform is different because it has a constant on the left.
+  // Add/mul commute the constant to RHS; sub with constant RHS becomes add.
+  Value *Op0 = And.getOperand(0), *Op1 = And.getOperand(1);
+  Constant *C;
+  if (!match(Op0, m_OneUse(m_Add(m_Specific(Op1), m_Constant(C)))) &&
+      !match(Op0, m_OneUse(m_Mul(m_Specific(Op1), m_Constant(C)))) &&
+      !match(Op0, m_OneUse(m_LShr(m_Specific(Op1), m_Constant(C)))) &&
+      !match(Op0, m_OneUse(m_Shl(m_Specific(Op1), m_Constant(C)))) &&
+      !match(Op0, m_OneUse(m_Sub(m_Constant(C), m_Specific(Op1)))))
+    return nullptr;
+
+  Value *X;
+  if (!match(Op1, m_ZExt(m_Value(X))) || Op1->getNumUses() > 2)
+    return nullptr;
+
+  Type *Ty = And.getType();
+  if (!isa<VectorType>(Ty) && !shouldChangeType(Ty, X->getType()))
+    return nullptr;
+
+  // If we're narrowing a shift, the shift amount must be safe (less than the
+  // width) in the narrower type. If the shift amount is greater, instsimplify
+  // usually handles that case, but we can't guarantee/assert it.
+  Instruction::BinaryOps Opc = cast<BinaryOperator>(Op0)->getOpcode();
+  if (Opc == Instruction::LShr || Opc == Instruction::Shl)
+    if (!canNarrowShiftAmt(C, X->getType()->getScalarSizeInBits()))
+      return nullptr;
+
+  // and (sub C, (zext X)), (zext X) --> zext (and (sub C', X), X)
+  // and (binop (zext X), C), (zext X) --> zext (and (binop X, C'), X)
+  Value *NewC = ConstantExpr::getTrunc(C, X->getType());
+  Value *NewBO = Opc == Instruction::Sub ? Builder.CreateBinOp(Opc, NewC, X)
+                                         : Builder.CreateBinOp(Opc, X, NewC);
+  return new ZExtInst(Builder.CreateAnd(NewBO, X), Ty);
+}
+
 // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
 // here. We should standardize that construct where it is needed or choose some
 // other way to ensure that commutated variants of patterns are not missed.
@@ -1289,6 +1356,9 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
     }
   }
 
+  if (Instruction *Z = narrowMaskedBinOp(I))
+    return Z;
+
   if (isa<Constant>(Op1))
     if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I))
       return FoldedLogic;
diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h
index f1f66d86cb7..3dfaad3722d 100644
--- a/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -462,6 +462,7 @@ private:
   Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask);
   Instruction *foldCastedBitwiseLogic(BinaryOperator &I);
   Instruction *narrowBinOp(TruncInst &Trunc);
+  Instruction *narrowMaskedBinOp(BinaryOperator &And);
   Instruction *narrowRotate(TruncInst &Trunc);
   Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN);
diff --git a/test/Transforms/InstCombine/and-narrow.ll b/test/Transforms/InstCombine/and-narrow.ll
index f5d96b1dc4d..3f801cf268a 100644
--- a/test/Transforms/InstCombine/and-narrow.ll
+++ b/test/Transforms/InstCombine/and-narrow.ll
@@ -5,11 +5,17 @@
 ; PR35792 - https://bugs.llvm.org/show_bug.cgi?id=35792
 
 define i16 @zext_add(i8 %x) {
-; ALL-LABEL: @zext_add(
-; ALL-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i16
-; ALL-NEXT: [[B:%.*]] = add nuw nsw i16 [[Z]], 44
-; ALL-NEXT: [[R:%.*]] = and i16 [[B]], [[Z]]
-; ALL-NEXT: ret i16 [[R]]
+; LEGAL8-LABEL: @zext_add(
+; LEGAL8-NEXT: [[TMP1:%.*]] = add i8 [[X:%.*]], 44
+; LEGAL8-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], [[X]]
+; LEGAL8-NEXT: [[R:%.*]] = zext i8 [[TMP2]] to i16
+; LEGAL8-NEXT: ret i16 [[R]]
+;
+; LEGAL16-LABEL: @zext_add(
+; LEGAL16-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i16
+; LEGAL16-NEXT: [[B:%.*]] = add nuw nsw i16 [[Z]], 44
+; LEGAL16-NEXT: [[R:%.*]] = and i16 [[B]], [[Z]]
+; LEGAL16-NEXT: ret i16 [[R]]
 ;
   %z = zext i8 %x to i16
   %b = add i16 %z, 44
@@ -18,11 +24,17 @@
 define i16 @zext_sub(i8 %x) {
-; ALL-LABEL: @zext_sub(
-; ALL-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i16
-; ALL-NEXT: [[B:%.*]] = sub nsw i16 251, [[Z]]
-; ALL-NEXT: [[R:%.*]] = and i16 [[B]], [[Z]]
-; ALL-NEXT: ret i16 [[R]]
+; LEGAL8-LABEL: @zext_sub(
+; LEGAL8-NEXT: [[TMP1:%.*]] = sub i8 -5, [[X:%.*]]
+; LEGAL8-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], [[X]]
+; LEGAL8-NEXT: [[R:%.*]] = zext i8 [[TMP2]] to i16
+; LEGAL8-NEXT: ret i16 [[R]]
+;
+; LEGAL16-LABEL: @zext_sub(
+; LEGAL16-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i16
+; LEGAL16-NEXT: [[B:%.*]] = sub nsw i16 251, [[Z]]
+; LEGAL16-NEXT: [[R:%.*]] = and i16 [[B]], [[Z]]
+; LEGAL16-NEXT: ret i16 [[R]]
 ;
   %z = zext i8 %x to i16
   %b = sub i16 -5, %z
@@ -31,11 +43,17 @@ define i16 @zext_sub(i8 %x) {
-; ALL-LABEL: @zext_mul(
-; ALL-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i16
-; ALL-NEXT: [[B:%.*]] = mul nuw nsw i16 [[Z]], 3
-; ALL-NEXT: [[R:%.*]] = and i16 [[B]], [[Z]]
-; ALL-NEXT: ret i16 [[R]]
+; LEGAL8-LABEL: @zext_mul(
+; LEGAL8-NEXT: [[TMP1:%.*]] = mul i8 [[X:%.*]], 3
+; LEGAL8-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], [[X]]
+; LEGAL8-NEXT: [[R:%.*]] = zext i8 [[TMP2]] to i16
+; LEGAL8-NEXT: ret i16 [[R]]
+;
+; LEGAL16-LABEL: @zext_mul(
+; LEGAL16-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i16
+; LEGAL16-NEXT: [[B:%.*]] = mul nuw nsw i16 [[Z]], 3
+; LEGAL16-NEXT: [[R:%.*]] = and i16 [[B]], [[Z]]
+; LEGAL16-NEXT: ret i16 [[R]]
 ;
   %z = zext i8 %x to i16
   %b = mul i16 %z, 3
@@ -44,11 +62,17 @@ define i16 @zext_mul(i8 %x) {
-; ALL-LABEL: @zext_lshr(
-; ALL-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i16
-; ALL-NEXT: [[B:%.*]] = lshr i16 [[Z]], 4
-; ALL-NEXT: [[R:%.*]] = and i16 [[B]], [[Z]]
-; ALL-NEXT: ret i16 [[R]]
+; LEGAL8-LABEL: @zext_lshr(
+; LEGAL8-NEXT: [[TMP1:%.*]] = lshr i8 [[X:%.*]], 4
+; LEGAL8-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], [[X]]
+; LEGAL8-NEXT: [[R:%.*]] = zext i8 [[TMP2]] to i16
+; LEGAL8-NEXT: ret i16 [[R]]
+;
+; LEGAL16-LABEL: @zext_lshr(
+; LEGAL16-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i16
+; LEGAL16-NEXT: [[B:%.*]] = lshr i16 [[Z]], 4
+; LEGAL16-NEXT: [[R:%.*]] = and i16 [[B]], [[Z]]
+; LEGAL16-NEXT: ret i16 [[R]]
 ;
   %z = zext i8 %x to i16
   %b = lshr i16 %z, 4
@@ -57,11 +81,17 @@ define i16 @zext_lshr(i8 %x) {
-; ALL-LABEL: @zext_ashr(
-; ALL-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i16
-; ALL-NEXT: [[TMP1:%.*]] = lshr i16 [[Z]], 2
-; ALL-NEXT: [[R:%.*]] = and i16 [[TMP1]], [[Z]]
-; ALL-NEXT: ret i16 [[R]]
+; LEGAL8-LABEL: @zext_ashr(
+; LEGAL8-NEXT: [[TMP1:%.*]] = lshr i8 [[X:%.*]], 2
+; LEGAL8-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], [[X]]
+; LEGAL8-NEXT: [[R:%.*]] = zext i8 [[TMP2]] to i16
+; LEGAL8-NEXT: ret i16 [[R]]
+;
+; LEGAL16-LABEL: @zext_ashr(
+; LEGAL16-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i16
+; LEGAL16-NEXT: [[TMP1:%.*]] = lshr i16 [[Z]], 2
+; LEGAL16-NEXT: [[R:%.*]] = and i16 [[TMP1]], [[Z]]
+; LEGAL16-NEXT: ret i16 [[R]]
 ;
   %z = zext i8 %x to i16
   %b = ashr i16 %z, 2
@@ -70,11 +100,17 @@ define i16 @zext_ashr(i8 %x) {
-; ALL-LABEL: @zext_shl(
-; ALL-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i16
-; ALL-NEXT: [[B:%.*]] = shl nuw nsw i16 [[Z]], 3
-; ALL-NEXT: [[R:%.*]] = and i16 [[B]], [[Z]]
-; ALL-NEXT: ret i16 [[R]]
+; LEGAL8-LABEL: @zext_shl(
+; LEGAL8-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 3
+; LEGAL8-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], [[X]]
+; LEGAL8-NEXT: [[R:%.*]] = zext i8 [[TMP2]] to i16
+; LEGAL8-NEXT: ret i16 [[R]]
+;
+; LEGAL16-LABEL: @zext_shl(
+; LEGAL16-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i16
+; LEGAL16-NEXT: [[B:%.*]] = shl nuw nsw i16 [[Z]], 3
+; LEGAL16-NEXT: [[R:%.*]] = and i16 [[B]], [[Z]]
+; LEGAL16-NEXT: ret i16 [[R]]
 ;
   %z = zext i8 %x to i16
   %b = shl i16 %z, 3
@@ -84,9 +120,9 @@
 define <2 x i16> @zext_add_vec(<2 x i8> %x) {
 ; ALL-LABEL: @zext_add_vec(
-; ALL-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i16>
-; ALL-NEXT: [[B:%.*]] = add nuw nsw <2 x i16> [[Z]],
-; ALL-NEXT: [[R:%.*]] = and <2 x i16> [[B]], [[Z]]
+; ALL-NEXT: [[TMP1:%.*]] = add <2 x i8> [[X:%.*]],
+; ALL-NEXT: [[TMP2:%.*]] = and <2 x i8> [[TMP1]], [[X]]
+; ALL-NEXT: [[R:%.*]] = zext <2 x i8> [[TMP2]] to <2 x i16>
 ; ALL-NEXT: ret <2 x i16> [[R]]
 ;
   %z = zext <2 x i8> %x to <2 x i16>
@@ -97,9 +133,9 @@ define <2 x i16> @zext_add_vec(<2 x i8> %x) {
 ; ALL-LABEL: @zext_sub_vec(
-; ALL-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i16>
-; ALL-NEXT: [[B:%.*]] = sub nuw nsw <2 x i16> , [[Z]]
-; ALL-NEXT: [[R:%.*]] = and <2 x i16> [[B]], [[Z]]
+; ALL-NEXT: [[TMP1:%.*]] = sub <2 x i8> , [[X:%.*]]
+; ALL-NEXT: [[TMP2:%.*]] = and <2 x i8> [[TMP1]], [[X]]
+; ALL-NEXT: [[R:%.*]] = zext <2 x i8> [[TMP2]] to <2 x i16>
 ; ALL-NEXT: ret <2 x i16> [[R]]
 ;
   %z = zext <2 x i8> %x to <2 x i16>
@@ -110,9 +146,9 @@ define <2 x i16> @zext_sub_vec(<2 x i8> %x) {
 ; ALL-LABEL: @zext_mul_vec(
-; ALL-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i16>
-; ALL-NEXT: [[B:%.*]] = mul nsw <2 x i16> [[Z]],
-; ALL-NEXT: [[R:%.*]] = and <2 x i16> [[B]], [[Z]]
+; ALL-NEXT: [[TMP1:%.*]] = mul <2 x i8> [[X:%.*]],
+; ALL-NEXT: [[TMP2:%.*]] = and <2 x i8> [[TMP1]], [[X]]
+; ALL-NEXT: [[R:%.*]] = zext <2 x i8> [[TMP2]] to <2 x i16>
 ; ALL-NEXT: ret <2 x i16> [[R]]
 ;
   %z = zext <2 x i8> %x to <2 x i16>
@@ -123,9 +159,9 @@ define <2 x i16> @zext_mul_vec(<2 x i8> %x) {
 ; ALL-LABEL: @zext_lshr_vec(
-; ALL-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i16>
-; ALL-NEXT: [[B:%.*]] = lshr <2 x i16> [[Z]],
-; ALL-NEXT: [[R:%.*]] = and <2 x i16> [[B]], [[Z]]
+; ALL-NEXT: [[TMP1:%.*]] = lshr <2 x i8> [[X:%.*]],
+; ALL-NEXT: [[TMP2:%.*]] = and <2 x i8> [[TMP1]], [[X]]
+; ALL-NEXT: [[R:%.*]] = zext <2 x i8> [[TMP2]] to <2 x i16>
 ; ALL-NEXT: ret <2 x i16> [[R]]
 ;
   %z = zext <2 x i8> %x to <2 x i16>
@@ -136,9 +172,9 @@ define <2 x i16> @zext_lshr_vec(<2 x i8> %x) {
 ; ALL-LABEL: @zext_ashr_vec(
-; ALL-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i16>
-; ALL-NEXT: [[B:%.*]] = lshr <2 x i16> [[Z]],
-; ALL-NEXT: [[R:%.*]] = and <2 x i16> [[B]], [[Z]]
+; ALL-NEXT: [[TMP1:%.*]] = lshr <2 x i8> [[X:%.*]],
+; ALL-NEXT: [[TMP2:%.*]] = and <2 x i8> [[TMP1]], [[X]]
+; ALL-NEXT: [[R:%.*]] = zext <2 x i8> [[TMP2]] to <2 x i16>
 ; ALL-NEXT: ret <2 x i16> [[R]]
 ;
   %z = zext <2 x i8> %x to <2 x i16>
@@ -149,9 +185,9 @@ define <2 x i16> @zext_ashr_vec(<2 x i8> %x) {
 ; ALL-LABEL: @zext_shl_vec(
-; ALL-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i16>
-; ALL-NEXT: [[B:%.*]] = shl <2 x i16> [[Z]],
-; ALL-NEXT: [[R:%.*]] = and <2 x i16> [[B]], [[Z]]
+; ALL-NEXT: [[TMP1:%.*]] = shl <2 x i8> [[X:%.*]],
+; ALL-NEXT: [[TMP2:%.*]] = and <2 x i8> [[TMP1]], [[X]]
+; ALL-NEXT: [[R:%.*]] = zext <2 x i8> [[TMP2]] to <2 x i16>
 ; ALL-NEXT: ret <2 x i16> [[R]]
 ;
   %z = zext <2 x i8> %x to <2 x i16>
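
Editor's note, not part of the patch: a hand-written before/after sketch of the new fold, using the same scalar pattern as the @zext_add test and assuming a data layout where i8 is a legal type (the LEGAL8 RUN configuration). The %narrow.* value names are illustrative only.

; Before: the add is performed in the wide type, and the mask operand is the
; zext itself, so only the low 8 bits of the add can survive the 'and'.
  %z = zext i8 %x to i16
  %b = add i16 %z, 44
  %r = and i16 %b, %z

; After narrowMaskedBinOp: the add and the and are performed in i8, and a
; single zext is sunk below them.
  %narrow.add = add i8 %x, 44
  %narrow.and = and i8 %narrow.add, %x
  %r = zext i8 %narrow.and to i16

For the shl/lshr cases, canNarrowShiftAmt only permits the fold when every non-undef constant shift amount is smaller than the narrow bit width; for example, an 'lshr i16 %z, 12' in this pattern is not narrowed here, since 12 would be an out-of-range (poison-producing) shift amount in i8.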