InstCombine: move the icmp-with-constant folds out of visitICmpInst() into a
new helper, InstCombiner::foldICmpWithConstant(), and move the static helper
ProcessUGT_ADDCST_ADD() above it. The moved code is only reformatted.

diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp
index e744dde78d3..2de9aaceec6 100644
--- a/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1283,6 +1283,225 @@ Instruction *InstCombiner::foldICmpCstShlConst(ICmpInst &I, Value *Op, Value *A,
   return getConstant(false);
 }
 
+/// The caller has matched a pattern of the form:
+///   I = icmp ugt (add (add A, B), CI2), CI1
+/// If this is of the form:
+///   sum = a + b
+///   if (sum+128 >u 255)
+/// Then replace it with llvm.sadd.with.overflow.i8.
+///
+static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
+                                          ConstantInt *CI2, ConstantInt *CI1,
+                                          InstCombiner &IC) {
+  // The transformation we're trying to do here is to transform this into an
+  // llvm.sadd.with.overflow. To do this, we have to replace the original add
+  // with a narrower add, and discard the add-with-constant that is part of the
+  // range check (if we can't eliminate it, this isn't profitable).
+
+  // In order to eliminate the add-with-constant, the compare can be its only
+  // use.
+  Instruction *AddWithCst = cast<Instruction>(I.getOperand(0));
+  if (!AddWithCst->hasOneUse())
+    return nullptr;
+
+  // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow.
+  if (!CI2->getValue().isPowerOf2())
+    return nullptr;
+  unsigned NewWidth = CI2->getValue().countTrailingZeros();
+  if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31)
+    return nullptr;
+
+  // The width of the new add formed is 1 more than the bias.
+  ++NewWidth;
+
+  // Check to see that CI1 is an all-ones value with NewWidth bits.
+  if (CI1->getBitWidth() == NewWidth ||
+      CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth))
+    return nullptr;
+
+  // This is only really a signed overflow check if the inputs have been
+  // sign-extended; check for that condition. For example, if CI2 is 2^31 and
+  // the operands of the add are 64 bits wide, we need at least 33 sign bits.
+  unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1;
+  if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits ||
+      IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits)
+    return nullptr;
+
+  // In order to replace the original add with a narrower
+  // llvm.sadd.with.overflow, the only uses allowed are the add-with-constant
+  // and truncates that discard the high bits of the add. Verify that this is
+  // the case.
+  Instruction *OrigAdd = cast<Instruction>(AddWithCst->getOperand(0));
+  for (User *U : OrigAdd->users()) {
+    if (U == AddWithCst)
+      continue;
+
+    // Only accept truncates for now. We would really like a nice recursive
+    // predicate like SimplifyDemandedBits, but which goes downwards the use-def
+    // chain to see which bits of a value are actually demanded. If the
+    // original add had another add which was then immediately truncated, we
+    // could still do the transformation.
+    TruncInst *TI = dyn_cast<TruncInst>(U);
+    if (!TI || TI->getType()->getPrimitiveSizeInBits() > NewWidth)
+      return nullptr;
+  }
+
+  // If the pattern matches, truncate the inputs to the narrower type and
+  // use the sadd_with_overflow intrinsic to efficiently compute both the
+  // result and the overflow bit.
+  Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth);
+  Value *F = Intrinsic::getDeclaration(I.getModule(),
+                                       Intrinsic::sadd_with_overflow, NewType);
+
+  InstCombiner::BuilderTy *Builder = IC.Builder;
+
+  // Put the new code above the original add, in case there are any uses of the
+  // add between the add and the compare.
+  Builder->SetInsertPoint(OrigAdd);
+
+  Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName() + ".trunc");
+  Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName() + ".trunc");
+  CallInst *Call = Builder->CreateCall(F, {TruncA, TruncB}, "sadd");
+  Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result");
+  Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType());
+
+  // The inner add was the result of the narrow add, zero extended to the
+  // wider type. Replace it with the result computed by the intrinsic.
+  IC.replaceInstUsesWith(*OrigAdd, ZExt);
+
+  // The original icmp gets replaced with the overflow value.
+  return ExtractValueInst::Create(Call, 1, "sadd.overflow");
+}
+
+// Fold icmp Pred X, C.
+Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &I) {
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+    Value *A = nullptr, *B = nullptr;
+
+    // Match the following pattern, which is a common idiom when writing
+    // overflow-safe integer arithmetic function. The source performs an
+    // addition in wider type, and explicitly checks for overflow using
+    // comparisons against INT_MIN and INT_MAX. Simplify this by using the
+    // sadd_with_overflow intrinsic.
+    //
+    // TODO: This could probably be generalized to handle other overflow-safe
+    // operations if we worked out the formulas to compute the appropriate
+    // magic constants.
+    //
+    // sum = a + b
+    // if (sum+128 >u 255) ... -> llvm.sadd.with.overflow.i8
+    {
+      ConstantInt *CI2; // I = icmp ugt (add (add A, B), CI2), CI
+      if (I.getPredicate() == ICmpInst::ICMP_UGT &&
+          match(Op0, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2))))
+        if (Instruction *Res = ProcessUGT_ADDCST_ADD(I, A, B, CI2, CI, *this))
+          return Res;
+    }
+
+    // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0)
+    if (CI->isZero() && I.getPredicate() == ICmpInst::ICMP_SGT)
+      if (auto *SI = dyn_cast<SelectInst>(Op0)) {
+        SelectPatternResult SPR = matchSelectPattern(SI, A, B);
+        if (SPR.Flavor == SPF_SMIN) {
+          if (isKnownPositive(A, DL))
+            return new ICmpInst(I.getPredicate(), B, CI);
+          if (isKnownPositive(B, DL))
+            return new ICmpInst(I.getPredicate(), A, CI);
+        }
+      }
+
+    // The following transforms are only 'worth it' if the only user of the
+    // subtraction is the icmp.
+    if (Op0->hasOneUse()) {
+      // (icmp ne/eq (sub A B) 0) -> (icmp ne/eq A, B)
+      if (I.isEquality() && CI->isZero() &&
+          match(Op0, m_Sub(m_Value(A), m_Value(B))))
+        return new ICmpInst(I.getPredicate(), A, B);
+
+      // (icmp sgt (sub nsw A B), -1) -> (icmp sge A, B)
+      if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isAllOnesValue() &&
+          match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
+        return new ICmpInst(ICmpInst::ICMP_SGE, A, B);
+
+      // (icmp sgt (sub nsw A B), 0) -> (icmp sgt A, B)
+      if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isZero() &&
+          match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
+        return new ICmpInst(ICmpInst::ICMP_SGT, A, B);
+
+      // (icmp slt (sub nsw A B), 0) -> (icmp slt A, B)
+      if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isZero() &&
+          match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
+        return new ICmpInst(ICmpInst::ICMP_SLT, A, B);
+
+      // (icmp slt (sub nsw A B), 1) -> (icmp sle A, B)
+      if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isOne() &&
+          match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
+        return new ICmpInst(ICmpInst::ICMP_SLE, A, B);
+    }
+
+    if (I.isEquality()) {
+      ConstantInt *CI2;
+      if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) ||
+          match(Op0, m_LShr(m_ConstantInt(CI2), m_Value(A)))) {
+        // (icmp eq/ne (ashr/lshr const2, A), const1)
+        if (Instruction *Inst = foldICmpCstShrConst(I, Op0, A, CI, CI2))
+          return Inst;
+      }
+      if (match(Op0, m_Shl(m_ConstantInt(CI2), m_Value(A)))) {
+        // (icmp eq/ne (shl const2, A), const1)
+        if (Instruction *Inst = foldICmpCstShlConst(I, Op0, A, CI, CI2))
+          return Inst;
+      }
+    }
+
+    // Canonicalize icmp instructions based on dominating conditions.
+    BasicBlock *Parent = I.getParent();
+    BasicBlock *Dom = Parent->getSinglePredecessor();
+    auto *BI = Dom ? dyn_cast<BranchInst>(Dom->getTerminator()) : nullptr;
+    ICmpInst::Predicate Pred;
+    BasicBlock *TrueBB, *FalseBB;
+    ConstantInt *CI2;
+    if (BI && match(BI, m_Br(m_ICmp(Pred, m_Specific(Op0), m_ConstantInt(CI2)),
+                             TrueBB, FalseBB)) &&
+        TrueBB != FalseBB) {
+      ConstantRange CR = ConstantRange::makeAllowedICmpRegion(I.getPredicate(),
+                                                              CI->getValue());
+      ConstantRange DominatingCR =
+          (Parent == TrueBB)
+              ? ConstantRange::makeExactICmpRegion(Pred, CI2->getValue())
+              : ConstantRange::makeExactICmpRegion(
+                    CmpInst::getInversePredicate(Pred), CI2->getValue());
+      ConstantRange Intersection = DominatingCR.intersectWith(CR);
+      ConstantRange Difference = DominatingCR.difference(CR);
+      if (Intersection.isEmptySet())
+        return replaceInstUsesWith(I, Builder->getFalse());
+      if (Difference.isEmptySet())
+        return replaceInstUsesWith(I, Builder->getTrue());
+
+      // If this is a normal comparison, it demands all bits. If it is a sign
+      // bit comparison, it only demands the sign bit.
+      bool UnusedBit;
+      bool IsSignBit =
+          isSignBitCheck(I.getPredicate(), CI->getValue(), UnusedBit);
+
+      // Canonicalizing a sign bit comparison that gets used in a branch,
+      // pessimizes codegen by generating branch on zero instruction instead
+      // of a test and branch. So we avoid canonicalizing in such situations
+      // because test and branch instruction has better branch displacement
+      // than compare and branch instruction.
+      if (!isBranchOnSignBitCheck(I, IsSignBit) && !I.isEquality()) {
+        if (auto *AI = Intersection.getSingleElement())
+          return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Builder->getInt(*AI));
+        if (auto *AD = Difference.getSingleElement())
+          return new ICmpInst(ICmpInst::ICMP_NE, Op0, Builder->getInt(*AD));
+      }
+    }
+  }
+
+  return nullptr;
+}
+
 /// Fold icmp (trunc X, Y), C.
 Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp,
                                                  Instruction *Trunc,
@@ -2528,92 +2747,6 @@ Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) {
   return BinaryOperator::CreateNot(Result);
 }
 
-/// The caller has matched a pattern of the form:
-///   I = icmp ugt (add (add A, B), CI2), CI1
-/// If this is of the form:
-///   sum = a + b
-///   if (sum+128 >u 255)
-/// Then replace it with llvm.sadd.with.overflow.i8.
-///
-static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
-                                          ConstantInt *CI2, ConstantInt *CI1,
-                                          InstCombiner &IC) {
-  // The transformation we're trying to do here is to transform this into an
-  // llvm.sadd.with.overflow. To do this, we have to replace the original add
-  // with a narrower add, and discard the add-with-constant that is part of the
-  // range check (if we can't eliminate it, this isn't profitable).
-
-  // In order to eliminate the add-with-constant, the compare can be its only
-  // use.
-  Instruction *AddWithCst = cast<Instruction>(I.getOperand(0));
-  if (!AddWithCst->hasOneUse()) return nullptr;
-
-  // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow.
-  if (!CI2->getValue().isPowerOf2()) return nullptr;
-  unsigned NewWidth = CI2->getValue().countTrailingZeros();
-  if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) return nullptr;
-
-  // The width of the new add formed is 1 more than the bias.
-  ++NewWidth;
-
-  // Check to see that CI1 is an all-ones value with NewWidth bits.
-  if (CI1->getBitWidth() == NewWidth ||
-      CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth))
-    return nullptr;
-
-  // This is only really a signed overflow check if the inputs have been
-  // sign-extended; check for that condition. For example, if CI2 is 2^31 and
-  // the operands of the add are 64 bits wide, we need at least 33 sign bits.
-  unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1;
-  if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits ||
-      IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits)
-    return nullptr;
-
-  // In order to replace the original add with a narrower
-  // llvm.sadd.with.overflow, the only uses allowed are the add-with-constant
-  // and truncates that discard the high bits of the add. Verify that this is
-  // the case.
-  Instruction *OrigAdd = cast<Instruction>(AddWithCst->getOperand(0));
-  for (User *U : OrigAdd->users()) {
-    if (U == AddWithCst) continue;
-
-    // Only accept truncates for now. We would really like a nice recursive
-    // predicate like SimplifyDemandedBits, but which goes downwards the use-def
-    // chain to see which bits of a value are actually demanded. If the
-    // original add had another add which was then immediately truncated, we
-    // could still do the transformation.
-    TruncInst *TI = dyn_cast<TruncInst>(U);
-    if (!TI || TI->getType()->getPrimitiveSizeInBits() > NewWidth)
-      return nullptr;
-  }
-
-  // If the pattern matches, truncate the inputs to the narrower type and
-  // use the sadd_with_overflow intrinsic to efficiently compute both the
-  // result and the overflow bit.
-  Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth);
-  Value *F = Intrinsic::getDeclaration(I.getModule(),
-                                       Intrinsic::sadd_with_overflow, NewType);
-
-  InstCombiner::BuilderTy *Builder = IC.Builder;
-
-  // Put the new code above the original add, in case there are any uses of the
-  // add between the add and the compare.
-  Builder->SetInsertPoint(OrigAdd);
-
-  Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName()+".trunc");
-  Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName()+".trunc");
-  CallInst *Call = Builder->CreateCall(F, {TruncA, TruncB}, "sadd");
-  Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result");
-  Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType());
-
-  // The inner add was the result of the narrow add, zero extended to the
-  // wider type. Replace it with the result computed by the intrinsic.
-  IC.replaceInstUsesWith(*OrigAdd, ZExt);
-
-  // The original icmp gets replaced with the overflow value.
-  return ExtractValueInst::Create(Call, 1, "sadd.overflow");
-}
-
 bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS,
                                          Value *RHS, Instruction &OrigI,
                                          Value *&Result, Constant *&Overflow) {
@@ -3406,6 +3539,7 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
 
       if (isa<UndefValue>(Elt))
         continue;
+
       // Bail out if we can't determine if this constant is min/max or if we
       // know that this constant is min/max.
       auto *CI = dyn_cast<ConstantInt>(Elt);
@@ -3509,130 +3643,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
   if (ICmpInst *NewICmp = canonicalizeCmpWithConstant(I))
     return NewICmp;
 
-  // See if we are doing a comparison with a constant.
-  if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
-    Value *A = nullptr, *B = nullptr;
-
-    // Match the following pattern, which is a common idiom when writing
-    // overflow-safe integer arithmetic function. The source performs an
-    // addition in wider type, and explicitly checks for overflow using
-    // comparisons against INT_MIN and INT_MAX. Simplify this by using the
-    // sadd_with_overflow intrinsic.
-    //
-    // TODO: This could probably be generalized to handle other overflow-safe
-    // operations if we worked out the formulas to compute the appropriate
-    // magic constants.
-    //
-    // sum = a + b
-    // if (sum+128 >u 255) ... -> llvm.sadd.with.overflow.i8
-    {
-      ConstantInt *CI2; // I = icmp ugt (add (add A, B), CI2), CI
-      if (I.getPredicate() == ICmpInst::ICMP_UGT &&
-          match(Op0, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2))))
-        if (Instruction *Res = ProcessUGT_ADDCST_ADD(I, A, B, CI2, CI, *this))
-          return Res;
-    }
-
-    // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0)
-    if (CI->isZero() && I.getPredicate() == ICmpInst::ICMP_SGT)
-      if (auto *SI = dyn_cast<SelectInst>(Op0)) {
-        SelectPatternResult SPR = matchSelectPattern(SI, A, B);
-        if (SPR.Flavor == SPF_SMIN) {
-          if (isKnownPositive(A, DL))
-            return new ICmpInst(I.getPredicate(), B, CI);
-          if (isKnownPositive(B, DL))
-            return new ICmpInst(I.getPredicate(), A, CI);
-        }
-      }
-
-
-    // The following transforms are only 'worth it' if the only user of the
-    // subtraction is the icmp.
-    if (Op0->hasOneUse()) {
-      // (icmp ne/eq (sub A B) 0) -> (icmp ne/eq A, B)
-      if (I.isEquality() && CI->isZero() &&
-          match(Op0, m_Sub(m_Value(A), m_Value(B))))
-        return new ICmpInst(I.getPredicate(), A, B);
-
-      // (icmp sgt (sub nsw A B), -1) -> (icmp sge A, B)
-      if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isAllOnesValue() &&
-          match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
-        return new ICmpInst(ICmpInst::ICMP_SGE, A, B);
-
-      // (icmp sgt (sub nsw A B), 0) -> (icmp sgt A, B)
-      if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isZero() &&
-          match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
-        return new ICmpInst(ICmpInst::ICMP_SGT, A, B);
-
-      // (icmp slt (sub nsw A B), 0) -> (icmp slt A, B)
-      if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isZero() &&
-          match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
-        return new ICmpInst(ICmpInst::ICMP_SLT, A, B);
-
-      // (icmp slt (sub nsw A B), 1) -> (icmp sle A, B)
-      if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isOne() &&
-          match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
-        return new ICmpInst(ICmpInst::ICMP_SLE, A, B);
-    }
-
-    if (I.isEquality()) {
-      ConstantInt *CI2;
-      if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) ||
-          match(Op0, m_LShr(m_ConstantInt(CI2), m_Value(A)))) {
-        // (icmp eq/ne (ashr/lshr const2, A), const1)
-        if (Instruction *Inst = foldICmpCstShrConst(I, Op0, A, CI, CI2))
-          return Inst;
-      }
-      if (match(Op0, m_Shl(m_ConstantInt(CI2), m_Value(A)))) {
-        // (icmp eq/ne (shl const2, A), const1)
-        if (Instruction *Inst = foldICmpCstShlConst(I, Op0, A, CI, CI2))
-          return Inst;
-      }
-    }
-
-    // Canonicalize icmp instructions based on dominating conditions.
-    BasicBlock *Parent = I.getParent();
-    BasicBlock *Dom = Parent->getSinglePredecessor();
-    auto *BI = Dom ? dyn_cast<BranchInst>(Dom->getTerminator()) : nullptr;
-    ICmpInst::Predicate Pred;
-    BasicBlock *TrueBB, *FalseBB;
-    ConstantInt *CI2;
-    if (BI && match(BI, m_Br(m_ICmp(Pred, m_Specific(Op0), m_ConstantInt(CI2)),
-                             TrueBB, FalseBB)) &&
-        TrueBB != FalseBB) {
-      ConstantRange CR = ConstantRange::makeAllowedICmpRegion(I.getPredicate(),
-                                                              CI->getValue());
-      ConstantRange DominatingCR =
-          (Parent == TrueBB)
-              ? ConstantRange::makeExactICmpRegion(Pred, CI2->getValue())
-              : ConstantRange::makeExactICmpRegion(
-                    CmpInst::getInversePredicate(Pred), CI2->getValue());
-      ConstantRange Intersection = DominatingCR.intersectWith(CR);
-      ConstantRange Difference = DominatingCR.difference(CR);
-      if (Intersection.isEmptySet())
-        return replaceInstUsesWith(I, Builder->getFalse());
-      if (Difference.isEmptySet())
-        return replaceInstUsesWith(I, Builder->getTrue());
-
-      // If this is a normal comparison, it demands all bits. If it is a sign
-      // bit comparison, it only demands the sign bit.
-      bool UnusedBit;
-      bool IsSignBit =
-          isSignBitCheck(I.getPredicate(), CI->getValue(), UnusedBit);
-
-      // Canonicalizing a sign bit comparison that gets used in a branch,
-      // pessimizes codegen by generating branch on zero instruction instead
-      // of a test and branch. So we avoid canonicalizing in such situations
-      // because test and branch instruction has better branch displacement
-      // than compare and branch instruction.
-      if (!isBranchOnSignBitCheck(I, IsSignBit) && !I.isEquality()) {
-        if (auto *AI = Intersection.getSingleElement())
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Builder->getInt(*AI));
-        if (auto *AD = Difference.getSingleElement())
-          return new ICmpInst(ICmpInst::ICMP_NE, Op0, Builder->getInt(*AD));
-      }
-    }
-  }
+  if (Instruction *Res = foldICmpWithConstant(I))
+    return Res;
 
   if (Instruction *Res = foldICmpUsingKnownBits(I))
     return Res;
diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h
index 00acab2cbec..5b256cd9912 100644
--- a/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -557,6 +557,7 @@ private:
   Instruction *foldICmpWithCastAndCast(ICmpInst &ICI);
 
   Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp);
+  Instruction *foldICmpWithConstant(ICmpInst &Cmp);
   Instruction *foldICmpInstWithConstant(ICmpInst &Cmp);
   Instruction *foldICmpTruncConstant(ICmpInst &Cmp, Instruction *Trunc,