From 59f49118ec9d9e953297425fb44a04de09ce9564 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Tue, 17 May 2016 00:57:57 +0000 Subject: [PATCH] [InstCombine] check vector elements before trying to transform LE/GE vector icmp (PR27756) Fix a bug introduced with rL269426 : [InstCombine] canonicalize* LE/GE vector integer comparisons to LT/GT (PR26701, PR26819) We were assuming that a ConstantDataVector / ConstantVector / ConstantAggregateZero operand of an ICMP was composed of ConstantInt elements, but it might have ConstantExpr or UndefValue elements. Handle those appropriately. Also, refactor this function to join the scalar and vector paths and eliminate the switches. Differential Revision: http://reviews.llvm.org/D20289 llvm-svn: 269728 --- .../InstCombine/InstCombineCompares.cpp | 120 ++++++------------ test/Transforms/InstCombine/icmp-vec.ll | 22 ++++ 2 files changed, 64 insertions(+), 78 deletions(-) diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index e06ec3945e3..1ecd7aeb549 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -3102,90 +3102,54 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, /// If we have an icmp le or icmp ge instruction with a constant operand, turn /// it into the appropriate icmp lt or icmp gt instruction. This transform /// allows them to be folded in visitICmpInst. -static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I, - InstCombiner::BuilderTy &Builder) { +static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { + ICmpInst::Predicate Pred = I.getPredicate(); + if (Pred != ICmpInst::ICMP_SLE && Pred != ICmpInst::ICMP_SGE && + Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_UGE) + return nullptr; + Value *Op0 = I.getOperand(0); Value *Op1 = I.getOperand(1); + auto *Op1C = dyn_cast(Op1); + if (!Op1C) + return nullptr; - if (auto *Op1C = dyn_cast(Op1)) { - // For scalars, SimplifyICmpInst has already handled the edge cases for us, - // so we just assert on them. - APInt Op1Val = Op1C->getValue(); - switch (I.getPredicate()) { - case ICmpInst::ICMP_ULE: - assert(!Op1C->isMaxValue(false)); // A <=u MAX -> TRUE - return new ICmpInst(ICmpInst::ICMP_ULT, Op0, Builder.getInt(Op1Val + 1)); - case ICmpInst::ICMP_SLE: - assert(!Op1C->isMaxValue(true)); // A <=s MAX -> TRUE - return new ICmpInst(ICmpInst::ICMP_SLT, Op0, Builder.getInt(Op1Val + 1)); - case ICmpInst::ICMP_UGE: - assert(!Op1C->isMinValue(false)); // A >=u MIN -> TRUE - return new ICmpInst(ICmpInst::ICMP_UGT, Op0, Builder.getInt(Op1Val - 1)); - case ICmpInst::ICMP_SGE: - assert(!Op1C->isMinValue(true)); // A >=s MIN -> TRUE - return new ICmpInst(ICmpInst::ICMP_SGT, Op0, Builder.getInt(Op1Val - 1)); - default: - return nullptr; - } - } - - // The usual vector types are ConstantDataVector. Exotic vector types are - // ConstantVector. Zeros are special. They all derive from Constant. - if (isa(Op1) || isa(Op1) || - isa(Op1)) { - Constant *Op1C = cast(Op1); - Type *Op1Type = Op1->getType(); - unsigned NumElts = Op1Type->getVectorNumElements(); - - // Set the new comparison predicate and splat a vector of 1 or -1 to - // increment or decrement the vector constants. But first, check that no - // elements of the constant vector would overflow/underflow when we - // increment/decrement the constants. - // + // Check if the constant operand can be safely incremented/decremented without + // overflowing/underflowing. For scalars, SimplifyICmpInst has already handled + // the edge cases for us, so we just assert on them. For vectors, we must + // handle the edge cases. + Type *Op1Type = Op1->getType(); + bool IsSigned = I.isSigned(); + bool IsLE = (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_ULE); + if (auto *CI = dyn_cast(Op1C)) { + // A <= MAX -> TRUE ; A >= MIN -> TRUE + assert(IsLE ? !CI->isMaxValue(IsSigned) : !CI->isMinValue(IsSigned)); + } else if (Op1Type->isVectorTy()) { // TODO? If the edge cases for vectors were guaranteed to be handled as they - // are for scalar, we could remove the min/max checks here. However, to do - // that, we would have to use insertelement/shufflevector to replace edge - // values. - - CmpInst::Predicate NewPred; - Constant *OneOrNegOne = nullptr; - switch (I.getPredicate()) { - case ICmpInst::ICMP_ULE: - for (unsigned i = 0; i != NumElts; ++i) - if (cast(Op1C->getAggregateElement(i))->isMaxValue(false)) - return nullptr; - NewPred = ICmpInst::ICMP_ULT; - OneOrNegOne = ConstantInt::get(Op1Type, 1); - break; - case ICmpInst::ICMP_SLE: - for (unsigned i = 0; i != NumElts; ++i) - if (cast(Op1C->getAggregateElement(i))->isMaxValue(true)) - return nullptr; - NewPred = ICmpInst::ICMP_SLT; - OneOrNegOne = ConstantInt::get(Op1Type, 1); - break; - case ICmpInst::ICMP_UGE: - for (unsigned i = 0; i != NumElts; ++i) - if (cast(Op1C->getAggregateElement(i))->isMinValue(false)) - return nullptr; - NewPred = ICmpInst::ICMP_UGT; - OneOrNegOne = ConstantInt::get(Op1Type, -1); - break; - case ICmpInst::ICMP_SGE: - for (unsigned i = 0; i != NumElts; ++i) - if (cast(Op1C->getAggregateElement(i))->isMinValue(true)) - return nullptr; - NewPred = ICmpInst::ICMP_SGT; - OneOrNegOne = ConstantInt::get(Op1Type, -1); - break; - default: - return nullptr; + // are for scalar, we could remove the min/max checks. However, to do that, + // we would have to use insertelement/shufflevector to replace edge values. + unsigned NumElts = Op1Type->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = Op1C->getAggregateElement(i); + if (isa(Elt)) + continue; + // Bail out if we can't determine if this constant is min/max or if we + // know that this constant is min/max. + auto *CI = dyn_cast(Elt); + if (!CI || (IsLE ? CI->isMaxValue(IsSigned) : CI->isMinValue(IsSigned))) + return nullptr; } - - return new ICmpInst(NewPred, Op0, ConstantExpr::getAdd(Op1C, OneOrNegOne)); + } else { + // ConstantExpr? + return nullptr; } - return nullptr; + // Increment or decrement the constant and set the new comparison predicate: + // ULE -> ULT ; UGE -> UGT ; SLE -> SLT ; SGE -> SGT + Constant *OneOrNegOne = ConstantInt::get(Op1Type, IsLE ? 1 : -1); + CmpInst::Predicate NewPred = IsLE ? ICmpInst::ICMP_ULT: ICmpInst::ICMP_UGT; + NewPred = IsSigned ? ICmpInst::getSignedPredicate(NewPred) : NewPred; + return new ICmpInst(NewPred, Op0, ConstantExpr::getAdd(Op1C, OneOrNegOne)); } Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { @@ -3271,7 +3235,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } } - if (ICmpInst *NewICmp = canonicalizeCmpWithConstant(I, *Builder)) + if (ICmpInst *NewICmp = canonicalizeCmpWithConstant(I)) return NewICmp; unsigned BitWidth = 0; diff --git a/test/Transforms/InstCombine/icmp-vec.ll b/test/Transforms/InstCombine/icmp-vec.ll index df653caa56d..290eacebe9d 100644 --- a/test/Transforms/InstCombine/icmp-vec.ll +++ b/test/Transforms/InstCombine/icmp-vec.ll @@ -159,3 +159,25 @@ define <2 x i1> @ule_max(<2 x i3> %x) { ret <2 x i1> %cmp } +; If we can't determine if a constant element is min/max (eg, it's a ConstantExpr), do nothing. + +define <2 x i1> @PR27756_1(<2 x i8> %a) { +; CHECK-LABEL: @PR27756_1( +; CHECK-NEXT: [[CMP:%.*]] = icmp sle <2 x i8> %a, to i8), i8 0> +; CHECK-NEXT: ret <2 x i1> [[CMP]] +; + %cmp = icmp sle <2 x i8> %a, to i8), i8 0> + ret <2 x i1> %cmp +} + +; Undef elements don't prevent the transform of the comparison. + +define <2 x i1> @PR27756_2(<2 x i8> %a) { +; CHECK-LABEL: @PR27756_2( +; CHECK-NEXT: [[CMP:%.*]] = icmp slt <2 x i8> %a, +; CHECK-NEXT: ret <2 x i1> [[CMP]] +; + %cmp = icmp sle <2 x i8> %a, + ret <2 x i1> %cmp +} +