From d9b2a4b5e792ba975d7fa27e846f4aefbe2b1b6e Mon Sep 17 00:00:00 2001 From: Roman Lebedev Date: Sat, 10 Apr 2021 19:37:59 +0300 Subject: [PATCH] [NFC][ConstantRange] Add 'icmp' helper method "Does the predicate hold between two ranges?" Not very surprisingly, some places were already doing this check, without explicitly naming the algorithm, cleanup them all. --- include/llvm/Analysis/ValueLattice.h | 6 +-- include/llvm/IR/ConstantRange.h | 4 ++ include/llvm/IR/IntrinsicInst.h | 41 +++++++++++++++++ lib/Analysis/InstructionSimplify.cpp | 7 +-- lib/Analysis/ScalarEvolution.cpp | 12 ++--- lib/IR/ConstantRange.cpp | 5 ++ lib/Transforms/IPO/AttributorAttributes.cpp | 5 +- unittests/IR/ConstantRangeTest.cpp | 51 ++++++++++++++++++++- 8 files changed, 108 insertions(+), 23 deletions(-) diff --git a/include/llvm/Analysis/ValueLattice.h b/include/llvm/Analysis/ValueLattice.h index 5ff9c4a6b08..1b32fca5069 100644 --- a/include/llvm/Analysis/ValueLattice.h +++ b/include/llvm/Analysis/ValueLattice.h @@ -474,11 +474,9 @@ public: const auto &CR = getConstantRange(); const auto &OtherCR = Other.getConstantRange(); - if (ConstantRange::makeSatisfyingICmpRegion(Pred, OtherCR).contains(CR)) + if (CR.icmp(Pred, OtherCR)) return ConstantInt::getTrue(Ty); - if (ConstantRange::makeSatisfyingICmpRegion( - CmpInst::getInversePredicate(Pred), OtherCR) - .contains(CR)) + if (CR.icmp(CmpInst::getInversePredicate(Pred), OtherCR)) return ConstantInt::getFalse(Ty); return nullptr; diff --git a/include/llvm/IR/ConstantRange.h b/include/llvm/IR/ConstantRange.h index 20e8e67436a..44b8c395c89 100644 --- a/include/llvm/IR/ConstantRange.h +++ b/include/llvm/IR/ConstantRange.h @@ -124,6 +124,10 @@ public: static ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other); + /// Does the predicate \p Pred hold between ranges this and \p Other? + /// NOTE: false does not mean that inverse predicate holds! + bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const; + /// Produce the largest range containing all X such that "X BinOp Y" is /// guaranteed not to wrap (overflow) for *all* Y in Other. However, there may /// be *some* Y in Other for which additional X not contained in the result diff --git a/include/llvm/IR/IntrinsicInst.h b/include/llvm/IR/IntrinsicInst.h index 6c825d380fc..b688ece7067 100644 --- a/include/llvm/IR/IntrinsicInst.h +++ b/include/llvm/IR/IntrinsicInst.h @@ -458,6 +458,47 @@ public: } }; +/// This class represents min/max intrinsics. +class LimitingIntrinsic : public IntrinsicInst { +public: + static bool classof(const IntrinsicInst *I) { + switch (I->getIntrinsicID()) { + case Intrinsic::umin: + case Intrinsic::umax: + case Intrinsic::smin: + case Intrinsic::smax: + return true; + default: + return false; + } + } + static bool classof(const Value *V) { + return isa(V) && classof(cast(V)); + } + + Value *getLHS() const { return const_cast(getArgOperand(0)); } + Value *getRHS() const { return const_cast(getArgOperand(1)); } + + /// Returns the comparison predicate underlying the intrinsic. + ICmpInst::Predicate getPredicate() const { + switch (getIntrinsicID()) { + case Intrinsic::umin: + return ICmpInst::Predicate::ICMP_ULT; + case Intrinsic::umax: + return ICmpInst::Predicate::ICMP_UGT; + case Intrinsic::smin: + return ICmpInst::Predicate::ICMP_SLT; + case Intrinsic::smax: + return ICmpInst::Predicate::ICMP_SGT; + default: + llvm_unreachable("Invalid intrinsic"); + } + } + + /// Whether the intrinsic is signed or unsigned. + bool isSigned() const { return ICmpInst::isSigned(getPredicate()); }; +}; + /// This class represents an intrinsic that is based on a binary operation. /// This includes op.with.overflow and saturating add/sub intrinsics. class BinaryOpIntrinsic : public IntrinsicInst { diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp index a6d3ca64189..b233a0f3eb2 100644 --- a/lib/Analysis/InstructionSimplify.cpp +++ b/lib/Analysis/InstructionSimplify.cpp @@ -3451,13 +3451,10 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, auto LHS_CR = getConstantRangeFromMetadata( *LHS_Instr->getMetadata(LLVMContext::MD_range)); - auto Satisfied_CR = ConstantRange::makeSatisfyingICmpRegion(Pred, RHS_CR); - if (Satisfied_CR.contains(LHS_CR)) + if (LHS_CR.icmp(Pred, RHS_CR)) return ConstantInt::getTrue(RHS->getContext()); - auto InversedSatisfied_CR = ConstantRange::makeSatisfyingICmpRegion( - CmpInst::getInversePredicate(Pred), RHS_CR); - if (InversedSatisfied_CR.contains(LHS_CR)) + if (LHS_CR.icmp(CmpInst::getInversePredicate(Pred), RHS_CR)) return ConstantInt::getFalse(RHS->getContext()); } } diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index a481c23c3d0..4630c556262 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -9843,10 +9843,9 @@ bool ScalarEvolution::isKnownPredicateViaConstantRanges( // This code is split out from isKnownPredicate because it is called from // within isLoopEntryGuardedByCond. - auto CheckRanges = - [&](const ConstantRange &RangeLHS, const ConstantRange &RangeRHS) { - return ConstantRange::makeSatisfyingICmpRegion(Pred, RangeRHS) - .contains(RangeLHS); + auto CheckRanges = [&](const ConstantRange &RangeLHS, + const ConstantRange &RangeRHS) { + return RangeLHS.icmp(Pred, RangeRHS); }; // The check at the top of the function catches the case where the values are @@ -11148,12 +11147,9 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, // We can also compute the range of values for `LHS` that satisfy the // consequent, "`LHS` `Pred` `RHS`": const APInt &ConstRHS = cast(RHS)->getAPInt(); - ConstantRange SatisfyingLHSRange = - ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS); - // The antecedent implies the consequent if every value of `LHS` that // satisfies the antecedent also satisfies the consequent. - return SatisfyingLHSRange.contains(LHSRange); + return LHSRange.icmp(Pred, ConstRHS); } bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, diff --git a/lib/IR/ConstantRange.cpp b/lib/IR/ConstantRange.cpp index 4dbe1a1b902..b38599fa7d9 100644 --- a/lib/IR/ConstantRange.cpp +++ b/lib/IR/ConstantRange.cpp @@ -181,6 +181,11 @@ bool ConstantRange::getEquivalentICmp(CmpInst::Predicate &Pred, return Success; } +bool ConstantRange::icmp(CmpInst::Predicate Pred, + const ConstantRange &Other) const { + return makeSatisfyingICmpRegion(Pred, Other).contains(*this); +} + /// Exact mul nuw region for single element RHS. static ConstantRange makeExactMulNUWRegion(const APInt &V) { unsigned BitWidth = V.getBitWidth(); diff --git a/lib/Transforms/IPO/AttributorAttributes.cpp b/lib/Transforms/IPO/AttributorAttributes.cpp index 867dd3118ce..33ed2b4423a 100644 --- a/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/lib/Transforms/IPO/AttributorAttributes.cpp @@ -7328,13 +7328,10 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { auto AllowedRegion = ConstantRange::makeAllowedICmpRegion(CmpI->getPredicate(), RHSAARange); - auto SatisfyingRegion = ConstantRange::makeSatisfyingICmpRegion( - CmpI->getPredicate(), RHSAARange); - if (AllowedRegion.intersectWith(LHSAARange).isEmptySet()) MustFalse = true; - if (SatisfyingRegion.contains(LHSAARange)) + if (LHSAARange.icmp(CmpI->getPredicate(), RHSAARange)) MustTrue = true; assert((!MustTrue || !MustFalse) && diff --git a/unittests/IR/ConstantRangeTest.cpp b/unittests/IR/ConstantRangeTest.cpp index 12362b9460f..f8816e4d43d 100644 --- a/unittests/IR/ConstantRangeTest.cpp +++ b/unittests/IR/ConstantRangeTest.cpp @@ -6,9 +6,10 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ADT/BitVector.h" -#include "llvm/ADT/SmallBitVector.h" #include "llvm/IR/ConstantRange.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallBitVector.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/Support/KnownBits.h" @@ -1509,6 +1510,52 @@ TEST(ConstantRange, MakeSatisfyingICmpRegion) { ConstantRange(APInt(8, 4), APInt(8, -128))); } +static bool icmp(CmpInst::Predicate Pred, const APInt &LHS, const APInt &RHS) { + switch (Pred) { + case CmpInst::Predicate::ICMP_EQ: + return LHS.eq(RHS); + case CmpInst::Predicate::ICMP_NE: + return LHS.ne(RHS); + case CmpInst::Predicate::ICMP_UGT: + return LHS.ugt(RHS); + case CmpInst::Predicate::ICMP_UGE: + return LHS.uge(RHS); + case CmpInst::Predicate::ICMP_ULT: + return LHS.ult(RHS); + case CmpInst::Predicate::ICMP_ULE: + return LHS.ule(RHS); + case CmpInst::Predicate::ICMP_SGT: + return LHS.sgt(RHS); + case CmpInst::Predicate::ICMP_SGE: + return LHS.sge(RHS); + case CmpInst::Predicate::ICMP_SLT: + return LHS.slt(RHS); + case CmpInst::Predicate::ICMP_SLE: + return LHS.sle(RHS); + default: + llvm_unreachable("Not an ICmp predicate!"); + } +} + +void ICmpTestImpl(CmpInst::Predicate Pred) { + unsigned Bits = 4; + EnumerateTwoConstantRanges( + Bits, [&](const ConstantRange &CR1, const ConstantRange &CR2) { + bool Exhaustive = true; + ForeachNumInConstantRange(CR1, [&](const APInt &N1) { + ForeachNumInConstantRange( + CR2, [&](const APInt &N2) { Exhaustive &= icmp(Pred, N1, N2); }); + }); + EXPECT_EQ(CR1.icmp(Pred, CR2), Exhaustive); + }); +} + +TEST(ConstantRange, ICmp) { + for (auto Pred : seq(CmpInst::Predicate::FIRST_ICMP_PREDICATE, + 1 + CmpInst::Predicate::LAST_ICMP_PREDICATE)) + ICmpTestImpl((CmpInst::Predicate)Pred); +} + TEST(ConstantRange, MakeGuaranteedNoWrapRegion) { const int IntMin4Bits = 8; const int IntMax4Bits = 7;