From 8bb5944166c814ce4b5506da4ba84470c490872d Mon Sep 17 00:00:00 2001
From: Sanjay Patel
Date: Fri, 16 Apr 2021 08:39:22 -0400
Subject: [PATCH] [ValueTracking] don't recursively compute known bits using
 multiple llvm.assumes

This is an alternative to D99759 to avoid the compile-time explosion seen in:
https://llvm.org/PR49785

Another potential solution would make the exclusion logic stronger to avoid
blowing up, but note that we reduced the complexity of the exclusion mechanism
in D16204 because it was too costly. So I'm questioning the need for
recursion/exclusion entirely - what is the optimization value vs. cost of
recursively computing known bits based on assumptions?

This was built into the implementation from the start with 60db058, and we
have kept adding code/cost to deal with that capability.

By clearing the query's AssumptionCache inside computeKnownBitsFromAssume(),
this patch retains all existing assume functionality except refining known
bits based on even more assumptions.

We have 1 regression test that shows a difference in optimization power.

Differential Revision: https://reviews.llvm.org/D100573
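[Editor's note] For readers unfamiliar with the mechanism, here is a minimal,
self-contained C++ sketch of the pattern described above. It is an
illustration only: the types and function bodies are simplified stand-ins
invented for this note, not the LLVM implementations; only the QueryNoAC
copy-and-clear idea mirrors the real change.

  #include <cstdio>

  // Hypothetical, simplified stand-ins for the LLVM types involved.
  struct AssumptionCache {};
  struct Value {};

  struct Query {
    AssumptionCache *AC; // nullptr means "do not consult assumptions"
    int Depth;
  };

  void computeKnownBits(const Value *V, const Query &Q);

  void computeKnownBitsFromAssume(const Value *V, const Query &Q) {
    // Copy the query and drop the assumption cache: recursive known-bits
    // calls made while examining an assume's operands can no longer
    // re-enter assume processing, so assume(x == y) cannot ping-pong
    // between x and y.
    Query QueryNoAC = Q;
    QueryNoAC.AC = nullptr;
    computeKnownBits(V, QueryNoAC); // stands in for the operand queries
  }

  void computeKnownBits(const Value *V, const Query &Q) {
    // ...ordinary known-bits reasoning elided...
    if (Q.AC) // assumptions are consulted only while a cache is present
      computeKnownBitsFromAssume(V, Q);
  }

  int main() {
    AssumptionCache AC;
    Value X;
    computeKnownBits(&X, Query{&AC, /*Depth=*/0});
    std::puts("one level of assume refinement, then recursion stops");
  }

The trade-off is deliberate: the assume currently being examined can still
refine known bits, but its operands are analyzed without consulting any
further assumptions, so the recursion is bounded by construction rather than
by an exclusion set.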
---
 lib/Analysis/ValueTracking.cpp        | 88 ++++++++++-----------------
 test/Transforms/InstCombine/assume.ll | 13 ++--
 2 files changed, 42 insertions(+), 59 deletions(-)

diff --git a/lib/Analysis/ValueTracking.cpp b/lib/Analysis/ValueTracking.cpp
index d8963c6ee83..46aa84f72aa 100644
--- a/lib/Analysis/ValueTracking.cpp
+++ b/lib/Analysis/ValueTracking.cpp
@@ -107,40 +107,13 @@ struct Query {
   // provide it currently.
   OptimizationRemarkEmitter *ORE;
 
-  /// Set of assumptions that should be excluded from further queries.
-  /// This is because of the potential for mutual recursion to cause
-  /// computeKnownBits to repeatedly visit the same assume intrinsic. The
-  /// classic case of this is assume(x = y), which will attempt to determine
-  /// bits in x from bits in y, which will attempt to determine bits in y from
-  /// bits in x, etc. Regarding the mutual recursion, computeKnownBits can call
-  /// isKnownNonZero, which calls computeKnownBits and isKnownToBeAPowerOfTwo
-  /// (all of which can call computeKnownBits), and so on.
-  std::array<const Value *, 2> Excluded;
-
   /// If true, it is safe to use metadata during simplification.
   InstrInfoQuery IIQ;
 
-  unsigned NumExcluded = 0;
-
   Query(const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI,
         const DominatorTree *DT, bool UseInstrInfo,
         OptimizationRemarkEmitter *ORE = nullptr)
       : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE), IIQ(UseInstrInfo) {}
-
-  Query(const Query &Q, const Value *NewExcl)
-      : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE), IIQ(Q.IIQ),
-        NumExcluded(Q.NumExcluded) {
-    Excluded = Q.Excluded;
-    Excluded[NumExcluded++] = NewExcl;
-    assert(NumExcluded <= Excluded.size());
-  }
-
-  bool isExcluded(const Value *Value) const {
-    if (NumExcluded == 0)
-      return false;
-    auto End = Excluded.begin() + NumExcluded;
-    return std::find(Excluded.begin(), End, Value) != End;
-  }
 };
 
 } // end anonymous namespace
@@ -632,8 +605,6 @@ static bool isKnownNonZeroFromAssume(const Value *V, const Query &Q) {
     CallInst *I = cast<CallInst>(AssumeVH);
     assert(I->getFunction() == Q.CxtI->getFunction() &&
            "Got assumption for the wrong function!");
-    if (Q.isExcluded(I))
-      continue;
 
     // Warning: This loop can end up being somewhat performance sensitive.
     // We're running this loop for once for each value queried resulting in a
@@ -681,8 +652,6 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
     CallInst *I = cast<CallInst>(AssumeVH);
     assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() &&
            "Got assumption for the wrong function!");
-    if (Q.isExcluded(I))
-      continue;
 
     // Warning: This loop can end up being somewhat performance sensitive.
     // We're running this loop for once for each value queried resulting in a
@@ -713,6 +682,15 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
     if (!Cmp)
       continue;
 
+    // We are attempting to compute known bits for the operands of an assume.
+    // Do not try to use other assumptions for those recursive calls because
+    // that can lead to mutual recursion and a compile-time explosion.
+    // An example of the mutual recursion: computeKnownBits can call
+    // isKnownNonZero which calls computeKnownBitsFromAssume (this function)
+    // and so on.
+    Query QueryNoAC = Q;
+    QueryNoAC.AC = nullptr;
+
     // Note that ptrtoint may change the bitwidth.
     Value *A, *B;
     auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
@@ -727,7 +705,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A))) &&
          isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         Known.Zero |= RHSKnown.Zero;
         Known.One |= RHSKnown.One;
         // assume(v & b = a)
@@ -735,9 +713,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                        m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits MaskKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in the mask that are known to be one, we can propagate
         // known bits from the RHS to V.
@@ -748,9 +726,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                                 m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits MaskKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in the mask that are known to be one, we can propagate
         // inverted known bits from the RHS to V.
@@ -761,9 +739,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                        m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits BKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate known
         // bits from the RHS to V.
@@ -774,9 +752,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                                 m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits BKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate
         // inverted known bits from the RHS to V.
@@ -787,9 +765,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                        m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits BKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate known
         // bits from the RHS to V. For those bits in B that are known to be one,
@@ -803,9 +781,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                                 m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits BKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate
         // inverted known bits from the RHS to V. For those bits in B that are
@@ -819,7 +797,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                               m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in RHS that are known, we can propagate them to known
         // bits in V shifted to the right by C.
@@ -832,7 +810,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them inverted
         // to known bits in V shifted to the right by C.
         RHSKnown.One.lshrInPlace(C);
@@ -844,7 +822,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                               m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them to known
         // bits in V shifted to the right by C.
         Known.Zero |= RHSKnown.Zero << C;
@@ -854,7 +832,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them inverted
         // to known bits in V shifted to the right by C.
         Known.Zero |= RHSKnown.One << C;
@@ -866,7 +844,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isNonNegative()) {
           // We know that the sign bit is zero.
@@ -879,7 +857,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) {
           // We know that the sign bit is zero.
@@ -892,7 +870,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isNegative()) {
           // We know that the sign bit is one.
@@ -905,7 +883,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isZero() || RHSKnown.isNegative()) {
           // We know that the sign bit is one.
@@ -918,7 +896,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // Whatever high bits in c are zero are known to be zero.
         Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros());
@@ -929,7 +907,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // If the RHS is known zero, then this assumption must be wrong (nothing
         // is unsigned less than zero). Signal a conflict and get out of here.
@@ -941,7 +919,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
 
         // Whatever high bits in c are zero are known to be zero (if c is a power
         // of 2, then one more).
-        if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, Query(Q, I)))
+        if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, QueryNoAC))
           Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros() + 1);
         else
           Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros());
diff --git a/test/Transforms/InstCombine/assume.ll b/test/Transforms/InstCombine/assume.ll
index d25f1d6adf1..403d82c4e7a 100644
--- a/test/Transforms/InstCombine/assume.ll
+++ b/test/Transforms/InstCombine/assume.ll
@@ -175,15 +175,20 @@ entry:
   ret i32 %and1
 }
 
-define i32 @bar4(i32 %a, i32 %b) {
-; CHECK-LABEL: @bar4(
+; If we allow recursive known bits queries based on
+; assumptions, we could do better here:
+; a == b and a & 7 == 1, so b & 7 == 1, so b & 3 == 1, so return 1.
+
+define i32 @known_bits_recursion_via_assumes(i32 %a, i32 %b) {
+; CHECK-LABEL: @known_bits_recursion_via_assumes(
 ; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AND1:%.*]] = and i32 [[B:%.*]], 3
 ; CHECK-NEXT:    [[AND:%.*]] = and i32 [[A:%.*]], 7
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[AND]], 1
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[A]], [[B:%.*]]
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[A]], [[B]]
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP2]])
-; CHECK-NEXT:    ret i32 1
+; CHECK-NEXT:    ret i32 [[AND1]]
 ;
 entry:
   %and1 = and i32 %b, 3