1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-25 12:12:47 +01:00

[InstCombine] add helper function for foldICmpWithConstant; NFC

This is a big glob of transforms that probably should work for vectors,
but currently they are disallowed because of ConstantInt guards.

llvm-svn: 281614
This commit is contained in:
Sanjay Patel 2016-09-15 14:37:50 +00:00
parent f14620ca02
commit 94f49655d4
2 changed files with 223 additions and 210 deletions

View File

@ -1283,6 +1283,225 @@ Instruction *InstCombiner::foldICmpCstShlConst(ICmpInst &I, Value *Op, Value *A,
return getConstant(false);
}
/// The caller has matched a pattern of the form:
/// I = icmp ugt (add (add A, B), CI2), CI1
/// If this is of the form:
/// sum = a + b
/// if (sum+128 >u 255)
/// Then replace it with llvm.sadd.with.overflow.i8.
///
static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
ConstantInt *CI2, ConstantInt *CI1,
InstCombiner &IC) {
// The transformation we're trying to do here is to transform this into an
// llvm.sadd.with.overflow. To do this, we have to replace the original add
// with a narrower add, and discard the add-with-constant that is part of the
// range check (if we can't eliminate it, this isn't profitable).
// In order to eliminate the add-with-constant, the compare can be its only
// use.
Instruction *AddWithCst = cast<Instruction>(I.getOperand(0));
if (!AddWithCst->hasOneUse())
return nullptr;
// If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow.
if (!CI2->getValue().isPowerOf2())
return nullptr;
unsigned NewWidth = CI2->getValue().countTrailingZeros();
if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31)
return nullptr;
// The width of the new add formed is 1 more than the bias.
++NewWidth;
// Check to see that CI1 is an all-ones value with NewWidth bits.
if (CI1->getBitWidth() == NewWidth ||
CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth))
return nullptr;
// This is only really a signed overflow check if the inputs have been
// sign-extended; check for that condition. For example, if CI2 is 2^31 and
// the operands of the add are 64 bits wide, we need at least 33 sign bits.
unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1;
if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits ||
IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits)
return nullptr;
// In order to replace the original add with a narrower
// llvm.sadd.with.overflow, the only uses allowed are the add-with-constant
// and truncates that discard the high bits of the add. Verify that this is
// the case.
Instruction *OrigAdd = cast<Instruction>(AddWithCst->getOperand(0));
for (User *U : OrigAdd->users()) {
if (U == AddWithCst)
continue;
// Only accept truncates for now. We would really like a nice recursive
// predicate like SimplifyDemandedBits, but which goes downwards the use-def
// chain to see which bits of a value are actually demanded. If the
// original add had another add which was then immediately truncated, we
// could still do the transformation.
TruncInst *TI = dyn_cast<TruncInst>(U);
if (!TI || TI->getType()->getPrimitiveSizeInBits() > NewWidth)
return nullptr;
}
// If the pattern matches, truncate the inputs to the narrower type and
// use the sadd_with_overflow intrinsic to efficiently compute both the
// result and the overflow bit.
Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth);
Value *F = Intrinsic::getDeclaration(I.getModule(),
Intrinsic::sadd_with_overflow, NewType);
InstCombiner::BuilderTy *Builder = IC.Builder;
// Put the new code above the original add, in case there are any uses of the
// add between the add and the compare.
Builder->SetInsertPoint(OrigAdd);
Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName() + ".trunc");
Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName() + ".trunc");
CallInst *Call = Builder->CreateCall(F, {TruncA, TruncB}, "sadd");
Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result");
Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType());
// The inner add was the result of the narrow add, zero extended to the
// wider type. Replace it with the result computed by the intrinsic.
IC.replaceInstUsesWith(*OrigAdd, ZExt);
// The original icmp gets replaced with the overflow value.
return ExtractValueInst::Create(Call, 1, "sadd.overflow");
}
// Fold icmp Pred X, C.
Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
Value *A = nullptr, *B = nullptr;
// Match the following pattern, which is a common idiom when writing
// overflow-safe integer arithmetic function. The source performs an
// addition in wider type, and explicitly checks for overflow using
// comparisons against INT_MIN and INT_MAX. Simplify this by using the
// sadd_with_overflow intrinsic.
//
// TODO: This could probably be generalized to handle other overflow-safe
// operations if we worked out the formulas to compute the appropriate
// magic constants.
//
// sum = a + b
// if (sum+128 >u 255) ... -> llvm.sadd.with.overflow.i8
{
ConstantInt *CI2; // I = icmp ugt (add (add A, B), CI2), CI
if (I.getPredicate() == ICmpInst::ICMP_UGT &&
match(Op0, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2))))
if (Instruction *Res = ProcessUGT_ADDCST_ADD(I, A, B, CI2, CI, *this))
return Res;
}
// (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0)
if (CI->isZero() && I.getPredicate() == ICmpInst::ICMP_SGT)
if (auto *SI = dyn_cast<SelectInst>(Op0)) {
SelectPatternResult SPR = matchSelectPattern(SI, A, B);
if (SPR.Flavor == SPF_SMIN) {
if (isKnownPositive(A, DL))
return new ICmpInst(I.getPredicate(), B, CI);
if (isKnownPositive(B, DL))
return new ICmpInst(I.getPredicate(), A, CI);
}
}
// The following transforms are only 'worth it' if the only user of the
// subtraction is the icmp.
if (Op0->hasOneUse()) {
// (icmp ne/eq (sub A B) 0) -> (icmp ne/eq A, B)
if (I.isEquality() && CI->isZero() &&
match(Op0, m_Sub(m_Value(A), m_Value(B))))
return new ICmpInst(I.getPredicate(), A, B);
// (icmp sgt (sub nsw A B), -1) -> (icmp sge A, B)
if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isAllOnesValue() &&
match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
return new ICmpInst(ICmpInst::ICMP_SGE, A, B);
// (icmp sgt (sub nsw A B), 0) -> (icmp sgt A, B)
if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isZero() &&
match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
return new ICmpInst(ICmpInst::ICMP_SGT, A, B);
// (icmp slt (sub nsw A B), 0) -> (icmp slt A, B)
if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isZero() &&
match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
return new ICmpInst(ICmpInst::ICMP_SLT, A, B);
// (icmp slt (sub nsw A B), 1) -> (icmp sle A, B)
if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isOne() &&
match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
return new ICmpInst(ICmpInst::ICMP_SLE, A, B);
}
if (I.isEquality()) {
ConstantInt *CI2;
if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) ||
match(Op0, m_LShr(m_ConstantInt(CI2), m_Value(A)))) {
// (icmp eq/ne (ashr/lshr const2, A), const1)
if (Instruction *Inst = foldICmpCstShrConst(I, Op0, A, CI, CI2))
return Inst;
}
if (match(Op0, m_Shl(m_ConstantInt(CI2), m_Value(A)))) {
// (icmp eq/ne (shl const2, A), const1)
if (Instruction *Inst = foldICmpCstShlConst(I, Op0, A, CI, CI2))
return Inst;
}
}
// Canonicalize icmp instructions based on dominating conditions.
BasicBlock *Parent = I.getParent();
BasicBlock *Dom = Parent->getSinglePredecessor();
auto *BI = Dom ? dyn_cast<BranchInst>(Dom->getTerminator()) : nullptr;
ICmpInst::Predicate Pred;
BasicBlock *TrueBB, *FalseBB;
ConstantInt *CI2;
if (BI && match(BI, m_Br(m_ICmp(Pred, m_Specific(Op0), m_ConstantInt(CI2)),
TrueBB, FalseBB)) &&
TrueBB != FalseBB) {
ConstantRange CR = ConstantRange::makeAllowedICmpRegion(I.getPredicate(),
CI->getValue());
ConstantRange DominatingCR =
(Parent == TrueBB)
? ConstantRange::makeExactICmpRegion(Pred, CI2->getValue())
: ConstantRange::makeExactICmpRegion(
CmpInst::getInversePredicate(Pred), CI2->getValue());
ConstantRange Intersection = DominatingCR.intersectWith(CR);
ConstantRange Difference = DominatingCR.difference(CR);
if (Intersection.isEmptySet())
return replaceInstUsesWith(I, Builder->getFalse());
if (Difference.isEmptySet())
return replaceInstUsesWith(I, Builder->getTrue());
// If this is a normal comparison, it demands all bits. If it is a sign
// bit comparison, it only demands the sign bit.
bool UnusedBit;
bool IsSignBit =
isSignBitCheck(I.getPredicate(), CI->getValue(), UnusedBit);
// Canonicalizing a sign bit comparison that gets used in a branch,
// pessimizes codegen by generating branch on zero instruction instead
// of a test and branch. So we avoid canonicalizing in such situations
// because test and branch instruction has better branch displacement
// than compare and branch instruction.
if (!isBranchOnSignBitCheck(I, IsSignBit) && !I.isEquality()) {
if (auto *AI = Intersection.getSingleElement())
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Builder->getInt(*AI));
if (auto *AD = Difference.getSingleElement())
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Builder->getInt(*AD));
}
}
}
return nullptr;
}
/// Fold icmp (trunc X, Y), C.
Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp,
Instruction *Trunc,
@ -2528,92 +2747,6 @@ Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) {
return BinaryOperator::CreateNot(Result);
}
/// The caller has matched a pattern of the form:
/// I = icmp ugt (add (add A, B), CI2), CI1
/// If this is of the form:
/// sum = a + b
/// if (sum+128 >u 255)
/// Then replace it with llvm.sadd.with.overflow.i8.
///
static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
ConstantInt *CI2, ConstantInt *CI1,
InstCombiner &IC) {
// The transformation we're trying to do here is to transform this into an
// llvm.sadd.with.overflow. To do this, we have to replace the original add
// with a narrower add, and discard the add-with-constant that is part of the
// range check (if we can't eliminate it, this isn't profitable).
// In order to eliminate the add-with-constant, the compare can be its only
// use.
Instruction *AddWithCst = cast<Instruction>(I.getOperand(0));
if (!AddWithCst->hasOneUse()) return nullptr;
// If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow.
if (!CI2->getValue().isPowerOf2()) return nullptr;
unsigned NewWidth = CI2->getValue().countTrailingZeros();
if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) return nullptr;
// The width of the new add formed is 1 more than the bias.
++NewWidth;
// Check to see that CI1 is an all-ones value with NewWidth bits.
if (CI1->getBitWidth() == NewWidth ||
CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth))
return nullptr;
// This is only really a signed overflow check if the inputs have been
// sign-extended; check for that condition. For example, if CI2 is 2^31 and
// the operands of the add are 64 bits wide, we need at least 33 sign bits.
unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1;
if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits ||
IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits)
return nullptr;
// In order to replace the original add with a narrower
// llvm.sadd.with.overflow, the only uses allowed are the add-with-constant
// and truncates that discard the high bits of the add. Verify that this is
// the case.
Instruction *OrigAdd = cast<Instruction>(AddWithCst->getOperand(0));
for (User *U : OrigAdd->users()) {
if (U == AddWithCst) continue;
// Only accept truncates for now. We would really like a nice recursive
// predicate like SimplifyDemandedBits, but which goes downwards the use-def
// chain to see which bits of a value are actually demanded. If the
// original add had another add which was then immediately truncated, we
// could still do the transformation.
TruncInst *TI = dyn_cast<TruncInst>(U);
if (!TI || TI->getType()->getPrimitiveSizeInBits() > NewWidth)
return nullptr;
}
// If the pattern matches, truncate the inputs to the narrower type and
// use the sadd_with_overflow intrinsic to efficiently compute both the
// result and the overflow bit.
Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth);
Value *F = Intrinsic::getDeclaration(I.getModule(),
Intrinsic::sadd_with_overflow, NewType);
InstCombiner::BuilderTy *Builder = IC.Builder;
// Put the new code above the original add, in case there are any uses of the
// add between the add and the compare.
Builder->SetInsertPoint(OrigAdd);
Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName()+".trunc");
Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName()+".trunc");
CallInst *Call = Builder->CreateCall(F, {TruncA, TruncB}, "sadd");
Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result");
Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType());
// The inner add was the result of the narrow add, zero extended to the
// wider type. Replace it with the result computed by the intrinsic.
IC.replaceInstUsesWith(*OrigAdd, ZExt);
// The original icmp gets replaced with the overflow value.
return ExtractValueInst::Create(Call, 1, "sadd.overflow");
}
bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS,
Value *RHS, Instruction &OrigI,
Value *&Result, Constant *&Overflow) {
@ -3406,6 +3539,7 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
if (isa<UndefValue>(Elt))
continue;
// Bail out if we can't determine if this constant is min/max or if we
// know that this constant is min/max.
auto *CI = dyn_cast<ConstantInt>(Elt);
@ -3509,130 +3643,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
if (ICmpInst *NewICmp = canonicalizeCmpWithConstant(I))
return NewICmp;
// See if we are doing a comparison with a constant.
if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
Value *A = nullptr, *B = nullptr;
// Match the following pattern, which is a common idiom when writing
// overflow-safe integer arithmetic function. The source performs an
// addition in wider type, and explicitly checks for overflow using
// comparisons against INT_MIN and INT_MAX. Simplify this by using the
// sadd_with_overflow intrinsic.
//
// TODO: This could probably be generalized to handle other overflow-safe
// operations if we worked out the formulas to compute the appropriate
// magic constants.
//
// sum = a + b
// if (sum+128 >u 255) ... -> llvm.sadd.with.overflow.i8
{
ConstantInt *CI2; // I = icmp ugt (add (add A, B), CI2), CI
if (I.getPredicate() == ICmpInst::ICMP_UGT &&
match(Op0, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2))))
if (Instruction *Res = ProcessUGT_ADDCST_ADD(I, A, B, CI2, CI, *this))
return Res;
}
// (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0)
if (CI->isZero() && I.getPredicate() == ICmpInst::ICMP_SGT)
if (auto *SI = dyn_cast<SelectInst>(Op0)) {
SelectPatternResult SPR = matchSelectPattern(SI, A, B);
if (SPR.Flavor == SPF_SMIN) {
if (isKnownPositive(A, DL))
return new ICmpInst(I.getPredicate(), B, CI);
if (isKnownPositive(B, DL))
return new ICmpInst(I.getPredicate(), A, CI);
}
}
// The following transforms are only 'worth it' if the only user of the
// subtraction is the icmp.
if (Op0->hasOneUse()) {
// (icmp ne/eq (sub A B) 0) -> (icmp ne/eq A, B)
if (I.isEquality() && CI->isZero() &&
match(Op0, m_Sub(m_Value(A), m_Value(B))))
return new ICmpInst(I.getPredicate(), A, B);
// (icmp sgt (sub nsw A B), -1) -> (icmp sge A, B)
if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isAllOnesValue() &&
match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
return new ICmpInst(ICmpInst::ICMP_SGE, A, B);
// (icmp sgt (sub nsw A B), 0) -> (icmp sgt A, B)
if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isZero() &&
match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
return new ICmpInst(ICmpInst::ICMP_SGT, A, B);
// (icmp slt (sub nsw A B), 0) -> (icmp slt A, B)
if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isZero() &&
match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
return new ICmpInst(ICmpInst::ICMP_SLT, A, B);
// (icmp slt (sub nsw A B), 1) -> (icmp sle A, B)
if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isOne() &&
match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
return new ICmpInst(ICmpInst::ICMP_SLE, A, B);
}
if (I.isEquality()) {
ConstantInt *CI2;
if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) ||
match(Op0, m_LShr(m_ConstantInt(CI2), m_Value(A)))) {
// (icmp eq/ne (ashr/lshr const2, A), const1)
if (Instruction *Inst = foldICmpCstShrConst(I, Op0, A, CI, CI2))
return Inst;
}
if (match(Op0, m_Shl(m_ConstantInt(CI2), m_Value(A)))) {
// (icmp eq/ne (shl const2, A), const1)
if (Instruction *Inst = foldICmpCstShlConst(I, Op0, A, CI, CI2))
return Inst;
}
}
// Canonicalize icmp instructions based on dominating conditions.
BasicBlock *Parent = I.getParent();
BasicBlock *Dom = Parent->getSinglePredecessor();
auto *BI = Dom ? dyn_cast<BranchInst>(Dom->getTerminator()) : nullptr;
ICmpInst::Predicate Pred;
BasicBlock *TrueBB, *FalseBB;
ConstantInt *CI2;
if (BI && match(BI, m_Br(m_ICmp(Pred, m_Specific(Op0), m_ConstantInt(CI2)),
TrueBB, FalseBB)) &&
TrueBB != FalseBB) {
ConstantRange CR = ConstantRange::makeAllowedICmpRegion(I.getPredicate(),
CI->getValue());
ConstantRange DominatingCR =
(Parent == TrueBB)
? ConstantRange::makeExactICmpRegion(Pred, CI2->getValue())
: ConstantRange::makeExactICmpRegion(
CmpInst::getInversePredicate(Pred), CI2->getValue());
ConstantRange Intersection = DominatingCR.intersectWith(CR);
ConstantRange Difference = DominatingCR.difference(CR);
if (Intersection.isEmptySet())
return replaceInstUsesWith(I, Builder->getFalse());
if (Difference.isEmptySet())
return replaceInstUsesWith(I, Builder->getTrue());
// If this is a normal comparison, it demands all bits. If it is a sign
// bit comparison, it only demands the sign bit.
bool UnusedBit;
bool IsSignBit =
isSignBitCheck(I.getPredicate(), CI->getValue(), UnusedBit);
// Canonicalizing a sign bit comparison that gets used in a branch,
// pessimizes codegen by generating branch on zero instruction instead
// of a test and branch. So we avoid canonicalizing in such situations
// because test and branch instruction has better branch displacement
// than compare and branch instruction.
if (!isBranchOnSignBitCheck(I, IsSignBit) && !I.isEquality()) {
if (auto *AI = Intersection.getSingleElement())
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Builder->getInt(*AI));
if (auto *AD = Difference.getSingleElement())
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Builder->getInt(*AD));
}
}
}
if (Instruction *Res = foldICmpWithConstant(I))
return Res;
if (Instruction *Res = foldICmpUsingKnownBits(I))
return Res;

View File

@ -557,6 +557,7 @@ private:
Instruction *foldICmpWithCastAndCast(ICmpInst &ICI);
Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp);
Instruction *foldICmpWithConstant(ICmpInst &Cmp);
Instruction *foldICmpInstWithConstant(ICmpInst &Cmp);
Instruction *foldICmpTruncConstant(ICmpInst &Cmp, Instruction *Trunc,