From 14b186a8284f04e165d7d7b9a3b155d9374964dc Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Thu, 19 May 2016 22:55:46 +0000 Subject: [PATCH] [GuardWidening] Introduce range check merging Sequences of range checks expressed using guards, like guard((I - 2) u< L) guard((I - 1) u< L) guard((I + 0) u< L) guard((I + 1) u< L) guard((I + 2) u< L) can sometimes be combined into a smaller sequence: guard((I - 2) u< L AND (I + 2) u< L) if we can prove that (I - 2) u< L AND (I + 2) u< L implies all of the checks expressed in the previous sequence. This change teaches GuardWidening to do this kind of merging when feasible. llvm-svn: 270151 --- lib/Transforms/Scalar/GuardWidening.cpp | 251 ++++++++++++++++++ .../GuardWidening/range-check-merging.ll | 197 ++++++++++++++ 2 files changed, 448 insertions(+) create mode 100644 test/Transforms/GuardWidening/range-check-merging.ll diff --git a/lib/Transforms/Scalar/GuardWidening.cpp b/lib/Transforms/Scalar/GuardWidening.cpp index 5ac4374038e..24be4f508cc 100644 --- a/lib/Transforms/Scalar/GuardWidening.cpp +++ b/lib/Transforms/Scalar/GuardWidening.cpp @@ -130,6 +130,55 @@ class GuardWideningImpl { bool widenCondCommon(Value *Cond0, Value *Cond1, Instruction *InsertPt, Value *&Result); + /// Represents a range check of the form \c Base + \c Offset u< \c Length, + /// with the constraint that \c Length is not negative. \c CheckInst is the + /// pre-existing instruction in the IR that computes the result of this range + /// check. 
+ struct RangeCheck { + Value *Base; + ConstantInt *Offset; + Value *Length; + ICmpInst *CheckInst; + + RangeCheck() {} + + explicit RangeCheck(Value *Base, ConstantInt *Offset, Value *Length, + ICmpInst *CheckInst) + : Base(Base), Offset(Offset), Length(Length), CheckInst(CheckInst) {} + + void print(raw_ostream &OS, bool PrintTypes = false) { + OS << "Base: "; + Base->printAsOperand(OS, PrintTypes); + OS << " Offset: "; + Offset->printAsOperand(OS, PrintTypes); + OS << " Length: "; + Length->printAsOperand(OS, PrintTypes); + } + + LLVM_DUMP_METHOD void dump() { + print(dbgs()); + dbgs() << "\n"; + } + }; + + /// Parse \p CheckCond into a conjunction (logical-and) of range checks; and + /// append them to \p Checks. Returns true on success, may clobber \c Checks + /// on failure. + bool parseRangeChecks(Value *CheckCond, SmallVectorImpl &Checks) { + SmallPtrSet Visited; + return parseRangeChecks(CheckCond, Checks, Visited); + } + + bool parseRangeChecks(Value *CheckCond, SmallVectorImpl &Checks, + SmallPtrSetImpl &Visited); + + /// Combine the checks in \p Checks into a smaller set of checks and append + /// them into \p CombinedChecks. Return true on success (i.e. all of the checks + /// in \p Checks were combined into \p CombinedChecks). Clobbers \p Checks + /// and \p CombinedChecks on success and on failure. + bool combineRangeChecks(SmallVectorImpl &Checks, + SmallVectorImpl &CombinedChecks); + /// Can we compute the logical AND of \p Cond0 and \p Cond1 for the price of /// computing only one of the two expressions? 
bool isWideningCondProfitable(Value *Cond0, Value *Cond1) { @@ -386,6 +435,27 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, } } + { + SmallVector Checks, CombinedChecks; + if (parseRangeChecks(Cond0, Checks) && parseRangeChecks(Cond1, Checks) && + combineRangeChecks(Checks, CombinedChecks)) { + if (InsertPt) { + Result = nullptr; + for (auto &RC : CombinedChecks) { + makeAvailableAt(RC.CheckInst, InsertPt); + if (Result) + Result = + BinaryOperator::CreateAnd(RC.CheckInst, Result, "", InsertPt); + else + Result = RC.CheckInst; + } + + Result->setName("wide.chk"); + } + return true; + } + } + // Base case -- just logical-and the two conditions together. if (InsertPt) { @@ -399,6 +469,187 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, return false; } +bool GuardWideningImpl::parseRangeChecks( + Value *CheckCond, SmallVectorImpl &Checks, + SmallPtrSetImpl &Visited) { + if (!Visited.insert(CheckCond).second) + return true; + + using namespace llvm::PatternMatch; + + { + Value *AndLHS, *AndRHS; + // Pass Visited down so that a condition reachable through several operands + // of the conjunction is parsed only once; the two-argument overload would + // start over with an empty set. + if (match(CheckCond, m_And(m_Value(AndLHS), m_Value(AndRHS)))) + return parseRangeChecks(AndLHS, Checks, Visited) && + parseRangeChecks(AndRHS, Checks, Visited); + } + + auto *IC = dyn_cast(CheckCond); + if (!IC || !IC->getOperand(0)->getType()->isIntegerTy() || + (IC->getPredicate() != ICmpInst::ICMP_ULT && + IC->getPredicate() != ICmpInst::ICMP_UGT)) + return false; + + Value *CmpLHS = IC->getOperand(0), *CmpRHS = IC->getOperand(1); + if (IC->getPredicate() == ICmpInst::ICMP_UGT) + std::swap(CmpLHS, CmpRHS); + + auto &DL = IC->getModule()->getDataLayout(); + + GuardWideningImpl::RangeCheck Check; + Check.Base = CmpLHS; + Check.Offset = + cast(ConstantInt::getNullValue(CmpRHS->getType())); + Check.Length = CmpRHS; + Check.CheckInst = IC; + + if (!isKnownNonNegative(Check.Length, DL)) + return false; + + // What we have in \c Check now is a correct interpretation of \p CheckCond. 
+ // Try to see if we can move some constant offsets into the \c Offset field. + + bool Changed; + + do { + Value *OpLHS; + ConstantInt *OpRHS; + Changed = false; + +#ifndef NDEBUG + auto *BaseInst = dyn_cast(Check.Base); + assert((!BaseInst || DT.isReachableFromEntry(BaseInst->getParent())) && + "Unreachable instruction?"); +#endif + + if (match(Check.Base, m_Add(m_Value(OpLHS), m_ConstantInt(OpRHS)))) { + Check.Base = OpLHS; + Check.Offset = + ConstantInt::get(Check.Offset->getContext(), + Check.Offset->getValue() + OpRHS->getValue()); + Changed = true; + } else if (match(Check.Base, m_Or(m_Value(OpLHS), m_ConstantInt(OpRHS)))) { + unsigned BitWidth = OpLHS->getType()->getScalarSizeInBits(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + computeKnownBits(OpLHS, KnownZero, KnownOne, DL); + if ((OpRHS->getValue() & KnownZero) == OpRHS->getValue()) { + Check.Base = OpLHS; + Check.Offset = + ConstantInt::get(Check.Offset->getContext(), + Check.Offset->getValue() + OpRHS->getValue()); + Changed = true; + } + } + } while (Changed); + + Checks.push_back(Check); + return true; +} + +bool GuardWideningImpl::combineRangeChecks( + SmallVectorImpl &Checks, + SmallVectorImpl &RangeChecksOut) { + unsigned OldCount = Checks.size(); + while (!Checks.empty()) { + Value *Base = Checks[0].Base; + Value *Length = Checks[0].Length; + // Move the checks that share Base and Length to the tail of Checks. This + // has to be a partition: remove_if leaves the elements past its returned + // iterator unspecified, so it cannot be used to collect the matches. + auto ChecksStart = std::stable_partition( + Checks.begin(), Checks.end(), + [&](const GuardWideningImpl::RangeCheck &RC) { + return RC.Base != Base || RC.Length != Length; + }); + + unsigned CheckCount = std::distance(ChecksStart, Checks.end()); + assert(CheckCount != 0 && "We know we have at least one!"); + + if (CheckCount < 3) { + RangeChecksOut.insert(RangeChecksOut.end(), ChecksStart, Checks.end()); + Checks.erase(ChecksStart, Checks.end()); + continue; + } + + // CheckCount will typically be 3 here, but so far there has been no need to + // hard-code that fact. 
+ + std::sort(ChecksStart, Checks.end(), + [&](GuardWideningImpl::RangeCheck &LHS, + GuardWideningImpl::RangeCheck &RHS) { + return LHS.Offset->getValue().slt(RHS.Offset->getValue()); + }); + + // Note: std::sort should not invalidate the ChecksStart iterator. + + ConstantInt *MinOffset = ChecksStart->Offset, + *MaxOffset = Checks.back().Offset; + + unsigned BitWidth = MaxOffset->getValue().getBitWidth(); + if ((MaxOffset->getValue() - MinOffset->getValue()) + .ugt(APInt::getSignedMinValue(BitWidth))) + return false; + + APInt MaxDiff = MaxOffset->getValue() - MinOffset->getValue(); + APInt HighOffset = MaxOffset->getValue(); + auto OffsetOK = [&](GuardWideningImpl::RangeCheck &RC) { + return (HighOffset - RC.Offset->getValue()).ult(MaxDiff); + }; + + if (MaxDiff.isMinValue() || + !std::all_of(std::next(ChecksStart), Checks.end(), OffsetOK)) + return false; + + // We have a series of f+1 checks as: + // + // I+k_0 u< L ... Chk_0 + // I+k_1 u< L ... Chk_1 + // ... + // I+k_f u< L ... Chk_f + // + // with forall i in (0,f]: k_f-k_i u< k_f-k_0 ... Precond_0 + // k_f-k_0 u<= INT_MIN ... Precond_1 + // k_f != k_0 ... Precond_2 + // + // Claim: + // Chk_0 AND Chk_f implies all the other checks + // + // Informal proof sketch: + // + // We will show that the integer range [I+k_0,I+k_f] does not unsigned-wrap + // (i.e. going from I+k_0 to I+k_f does not cross the -1,0 boundary) and + // thus I+k_f is the greatest unsigned value in that range. + // + // This combined with Chk_f shows that everything in that range is u< L. + // Via Precond_0 we know that all of the indices in Chk_0 through Chk_f + // lie in [I+k_0,I+k_f], thus proving our claim. + // + // To see that [I+k_0,I+k_f] is not a wrapping range, note that there are + // two possibilities: I+k_0 u< I+k_f or I+k_0 >u I+k_f (they can't be equal + // since k_0 != k_f). 
In the former case, [I+k_0,I+k_f] is not a wrapping + // range by definition, and the latter case is impossible: + // + // 0-----I+k_f---I+k_0----L---INT_MAX,INT_MIN------------------(-1) + // xxxxxx xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx + // + // For Chk_0 to succeed, we'd have to have k_f-k_0 (the range highlighted + // with 'x' above) to be at least >u INT_MIN. + + RangeChecksOut.emplace_back(Base, MinOffset, Length, + ChecksStart->CheckInst); + RangeChecksOut.emplace_back(Base, MaxOffset, Length, + Checks.back().CheckInst); + + Checks.erase(ChecksStart, Checks.end()); + } + + assert(RangeChecksOut.size() <= OldCount && "We pessimized!"); + return RangeChecksOut.size() != OldCount; +} + PreservedAnalyses GuardWideningPass::run(Function &F, AnalysisManager &AM) { auto &DT = AM.getResult(F); diff --git a/test/Transforms/GuardWidening/range-check-merging.ll b/test/Transforms/GuardWidening/range-check-merging.ll new file mode 100644 index 00000000000..304eebc9b86 --- /dev/null +++ b/test/Transforms/GuardWidening/range-check-merging.ll @@ -0,0 +1,197 @@ +; RUN: opt -S -guard-widening < %s | FileCheck %s + +declare void @llvm.experimental.guard(i1,...) + +define void @f_0(i32 %x, i32* %length_buf) { +; CHECK-LABEL: @f_0( +; CHECK-NOT: @llvm.experimental.guard +; CHECK: %wide.chk2 = and i1 %chk3, %chk0 +; CHECK: call void (i1, ...) @llvm.experimental.guard(i1 %wide.chk2) [ "deopt"() ] +; CHECK: ret void +entry: + %length = load i32, i32* %length_buf, !range !0 + %chk0 = icmp ult i32 %x, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk0) [ "deopt"() ] + + %x.inc1 = add i32 %x, 1 + %chk1 = icmp ult i32 %x.inc1, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk1) [ "deopt"() ] + + %x.inc2 = add i32 %x, 2 + %chk2 = icmp ult i32 %x.inc2, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk2) [ "deopt"() ] + + %x.inc3 = add i32 %x, 3 + %chk3 = icmp ult i32 %x.inc3, %length + call void(i1, ...) 
@llvm.experimental.guard(i1 %chk3) [ "deopt"() ] + ret void +} + +define void @f_1(i32 %x, i32* %length_buf) { +; CHECK-LABEL: @f_1( +; CHECK-NOT: llvm.experimental.guard +; CHECK: %wide.chk2 = and i1 %chk3, %chk0 +; CHECK: call void (i1, ...) @llvm.experimental.guard(i1 %wide.chk2) [ "deopt"() ] +; CHECK: ret void +entry: + %length = load i32, i32* %length_buf, !range !0 + %chk0 = icmp ult i32 %x, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk0) [ "deopt"() ] + + %x.inc1 = add i32 %x, 1 + %chk1 = icmp ult i32 %x.inc1, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk1) [ "deopt"() ] + + %x.inc2 = add i32 %x.inc1, 2 + %chk2 = icmp ult i32 %x.inc2, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk2) [ "deopt"() ] + + %x.inc3 = add i32 %x.inc2, 3 + %chk3 = icmp ult i32 %x.inc3, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk3) [ "deopt"() ] + ret void +} + +define void @f_2(i32 %a, i32* %length_buf) { +; CHECK-LABEL: @f_2( +; CHECK-NOT: llvm.experimental.guard +; CHECK: %wide.chk2 = and i1 %chk3, %chk0 +; CHECK: call void (i1, ...) @llvm.experimental.guard(i1 %wide.chk2) [ "deopt"() ] +; CHECK: ret void +entry: + %x = and i32 %a, 4294967040 ;; 4294967040 == 0xffffff00 + %length = load i32, i32* %length_buf, !range !0 + %chk0 = icmp ult i32 %x, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk0) [ "deopt"() ] + + %x.inc1 = or i32 %x, 1 + %chk1 = icmp ult i32 %x.inc1, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk1) [ "deopt"() ] + + %x.inc2 = or i32 %x, 2 + %chk2 = icmp ult i32 %x.inc2, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk2) [ "deopt"() ] + + %x.inc3 = or i32 %x, 3 + %chk3 = icmp ult i32 %x.inc3, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk3) [ "deopt"() ] + ret void +} + +define void @f_3(i32 %a, i32* %length_buf) { +; CHECK-LABEL: @f_3( +; CHECK-NOT: llvm.experimental.guard +; CHECK: %wide.chk2 = and i1 %chk3, %chk0 +; CHECK: call void (i1, ...) 
@llvm.experimental.guard(i1 %wide.chk2) [ "deopt"() ] +; CHECK: ret void +entry: + %x = and i32 %a, 4294967040 ;; 4294967040 == 0xffffff00 + %length = load i32, i32* %length_buf, !range !0 + %chk0 = icmp ult i32 %x, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk0) [ "deopt"() ] + + %x.inc1 = add i32 %x, 1 + %chk1 = icmp ult i32 %x.inc1, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk1) [ "deopt"() ] + + %x.inc2 = or i32 %x.inc1, 2 + %chk2 = icmp ult i32 %x.inc2, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk2) [ "deopt"() ] + + %x.inc3 = add i32 %x.inc2, 3 + %chk3 = icmp ult i32 %x.inc3, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk3) [ "deopt"() ] + ret void +} + +define void @f_4(i32 %x, i32* %length_buf) { +; CHECK-LABEL: @f_4( +; CHECK-NOT: llvm.experimental.guard + +; Note: we are NOT guarding on "and i1 %chk3, %chk0", that would be incorrect. +; CHECK: %wide.chk2 = and i1 %chk3, %chk1 +; CHECK: call void (i1, ...) @llvm.experimental.guard(i1 %wide.chk2) [ "deopt"() ] +; CHECK: ret void +entry: + %length = load i32, i32* %length_buf, !range !0 + %chk0 = icmp ult i32 %x, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk0) [ "deopt"() ] + + %x.inc1 = add i32 %x, -1024 + %chk1 = icmp ult i32 %x.inc1, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk1) [ "deopt"() ] + + %x.inc2 = add i32 %x, 2 + %chk2 = icmp ult i32 %x.inc2, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk2) [ "deopt"() ] + + %x.inc3 = add i32 %x, 3 + %chk3 = icmp ult i32 %x.inc3, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk3) [ "deopt"() ] + ret void +} + +define void @f_5(i32 %x, i32* %length_buf) { +; CHECK-LABEL: @f_5( +; CHECK-NOT: llvm.experimental.guard +; CHECK: %wide.chk2 = and i1 %chk1, %chk2 +; CHECK: call void (i1, ...) 
@llvm.experimental.guard(i1 %wide.chk2) [ "deopt"() ] +; CHECK: ret void +entry: + %length = load i32, i32* %length_buf, !range !0 + %chk0 = icmp ult i32 %x, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk0) [ "deopt"() ] + + %x.inc1 = add i32 %x, 1 + %chk1 = icmp ult i32 %x.inc1, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk1) [ "deopt"() ] + + %x.inc2 = add i32 %x.inc1, -200 + %chk2 = icmp ult i32 %x.inc2, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk2) [ "deopt"() ] + + %x.inc3 = add i32 %x.inc2, 3 + %chk3 = icmp ult i32 %x.inc3, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk3) [ "deopt"() ] + ret void +} + + +; Negative test: we can't merge these checks into +; +; (%x + -2147483647) u< L && (%x + 3) u< L +; +; because if %length == INT_MAX and %x == -3 then +; +; (%x + -2147483647) == i32 2147483646 u< L (L is 2147483647) +; (%x + 3) == 0 u< L +; +; But (%x + 2) == -1 is not u< L +; +define void @f_6(i32 %x, i32* %length_buf) { +; CHECK-LABEL: @f_6( +; CHECK-NOT: llvm.experimental.guard +; CHECK: %wide.chk = and i1 %chk0, %chk1 +; CHECK: %wide.chk1 = and i1 %wide.chk, %chk2 +; CHECK: %wide.chk2 = and i1 %wide.chk1, %chk3 +; CHECK: call void (i1, ...) @llvm.experimental.guard(i1 %wide.chk2) [ "deopt"() ] +entry: + %length = load i32, i32* %length_buf, !range !0 + %chk0 = icmp ult i32 %x, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk0) [ "deopt"() ] + + %x.inc1 = add i32 %x, -2147483647 ;; -2147483647 == (i32 INT_MIN)+1 == -(i32 INT_MAX) + %chk1 = icmp ult i32 %x.inc1, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk1) [ "deopt"() ] + + %x.inc2 = add i32 %x, 2 + %chk2 = icmp ult i32 %x.inc2, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk2) [ "deopt"() ] + + %x.inc3 = add i32 %x, 3 + %chk3 = icmp ult i32 %x.inc3, %length + call void(i1, ...) @llvm.experimental.guard(i1 %chk3) [ "deopt"() ] + ret void +} + +!0 = !{i32 0, i32 2147483648}