diff --git a/include/llvm/Analysis/AssumptionCache.h b/include/llvm/Analysis/AssumptionCache.h index b9ffd9a6c53..c4602d3449c 100644 --- a/include/llvm/Analysis/AssumptionCache.h +++ b/include/llvm/Analysis/AssumptionCache.h @@ -18,9 +18,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMapInfo.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" @@ -46,22 +44,6 @@ public: /// llvm.assume. enum : unsigned { ExprResultIdx = std::numeric_limits::max() }; - /// Callback handle to ensure we do not have dangling pointers to llvm.assume - /// calls in our cache. - class AssumeHandle final : public CallbackVH { - AssumptionCache *AC; - - /// Make sure llvm.assume calls that are deleted are removed from the cache. - void deleted() override; - - public: - AssumeHandle(Value *V, AssumptionCache *AC = nullptr) - : CallbackVH(V), AC(AC) {} - - operator Value *() const { return getValPtr(); } - CallInst *getAssumeCI() const { return cast(getValPtr()); } - }; - struct ResultElem { WeakVH Assume; @@ -77,9 +59,9 @@ private: /// We track this to lazily populate our assumptions. Function &F; - /// Set of value handles for calls of the \@llvm.assume intrinsic. - using AssumeHandleSet = DenseSet>; - AssumeHandleSet AssumeHandles; + /// Vector of weak value handles to calls of the \@llvm.assume + /// intrinsic. + SmallVector AssumeHandles; class AffectedValueCallbackVH final : public CallbackVH { AssumptionCache *AC; @@ -155,7 +137,13 @@ public: /// Access the list of assumption handles currently tracked for this /// function. - AssumeHandleSet &assumptions() { + /// + /// Note that these produce weak handles that may be null. The caller must + /// handle that case. + /// FIXME: We should replace this with pointee_iterator> + /// when we can write that to filter out the null values. Then caller code + /// will become simpler. + MutableArrayRef assumptions() { if (!Scanned) scanFunction(); return AssumeHandles; diff --git a/lib/Analysis/AssumptionCache.cpp b/lib/Analysis/AssumptionCache.cpp index e2a31d6618c..70053fdf8d3 100644 --- a/lib/Analysis/AssumptionCache.cpp +++ b/lib/Analysis/AssumptionCache.cpp @@ -163,12 +163,7 @@ void AssumptionCache::unregisterAssumption(CallInst *CI) { AffectedValues.erase(AVI); } - AssumeHandles.erase({CI, this}); -} - -void AssumptionCache::AssumeHandle::deleted() { - AC->AssumeHandles.erase(*this); - // 'this' now dangles! + erase_value(AssumeHandles, CI); } void AssumptionCache::AffectedValueCallbackVH::deleted() { @@ -209,14 +204,14 @@ void AssumptionCache::scanFunction() { for (BasicBlock &B : F) for (Instruction &II : B) if (match(&II, m_Intrinsic())) - AssumeHandles.insert({&II, this}); + AssumeHandles.push_back({&II, ExprResultIdx}); // Mark the scan as complete. Scanned = true; // Update affected values. - for (auto &AssumeVH : AssumeHandles) - updateAffectedValues(AssumeVH.getAssumeCI()); + for (auto &A : AssumeHandles) + updateAffectedValues(cast(A)); } void AssumptionCache::registerAssumption(CallInst *CI) { @@ -228,7 +223,7 @@ void AssumptionCache::registerAssumption(CallInst *CI) { if (!Scanned) return; - AssumeHandles.insert({CI, this}); + AssumeHandles.push_back({CI, ExprResultIdx}); #ifndef NDEBUG assert(CI->getParent() && @@ -236,11 +231,20 @@ void AssumptionCache::registerAssumption(CallInst *CI) { assert(&F == CI->getParent()->getParent() && "Cannot register @llvm.assume call not in this function"); - for (auto &AssumeVH : AssumeHandles) { - assert(&F == AssumeVH.getAssumeCI()->getCaller() && + // We expect the number of assumptions to be small, so in an asserts build + // check that we don't accumulate duplicates and that all assumptions point + // to the same function. + SmallPtrSet AssumptionSet; + for (auto &VH : AssumeHandles) { + if (!VH) + continue; + + assert(&F == cast(VH)->getParent()->getParent() && "Cached assumption not inside this function!"); - assert(match(AssumeVH.getAssumeCI(), m_Intrinsic()) && + assert(match(cast(VH), m_Intrinsic()) && "Cached something other than a call to @llvm.assume!"); + assert(AssumptionSet.insert(VH).second && + "Cache contains multiple copies of a call!"); } #endif @@ -254,8 +258,9 @@ PreservedAnalyses AssumptionPrinterPass::run(Function &F, AssumptionCache &AC = AM.getResult(F); OS << "Cached assumptions for function: " << F.getName() << "\n"; - for (auto &AssumeVH : AC.assumptions()) - OS << " " << *AssumeVH.getAssumeCI()->getArgOperand(0) << "\n"; + for (auto &VH : AC.assumptions()) + if (VH) + OS << " " << *cast(VH)->getArgOperand(0) << "\n"; return PreservedAnalyses::all(); } @@ -301,8 +306,9 @@ void AssumptionCacheTracker::verifyAnalysis() const { SmallPtrSet AssumptionSet; for (const auto &I : AssumptionCaches) { - for (auto &AssumeVH : I.second->assumptions()) - AssumptionSet.insert(AssumeVH.getAssumeCI()); + for (auto &VH : I.second->assumptions()) + if (VH) + AssumptionSet.insert(cast(VH)); for (const BasicBlock &B : cast(*I.first)) for (const Instruction &II : B) diff --git a/lib/Analysis/CodeMetrics.cpp b/lib/Analysis/CodeMetrics.cpp index 846e1eb3551..8c8e2ee6627 100644 --- a/lib/Analysis/CodeMetrics.cpp +++ b/lib/Analysis/CodeMetrics.cpp @@ -74,7 +74,9 @@ void CodeMetrics::collectEphemeralValues( SmallVector Worklist; for (auto &AssumeVH : AC->assumptions()) { - Instruction *I = AssumeVH.getAssumeCI(); + if (!AssumeVH) + continue; + Instruction *I = cast(AssumeVH); // Filter out call sites outside of the loop so we don't do a function's // worth of work for each of its loops (and, in the common case, ephemeral @@ -96,7 +98,9 @@ void CodeMetrics::collectEphemeralValues( SmallVector Worklist; for (auto &AssumeVH : AC->assumptions()) { - Instruction *I = AssumeVH.getAssumeCI(); + if (!AssumeVH) + continue; + Instruction *I = cast(AssumeVH); assert(I->getParent()->getParent() == F && "Found assumption for the wrong function!"); diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 1243ad4bfc6..13b07d74e52 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -1704,9 +1704,9 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); } - + // For a negative step, we can extend the operands iff doing so only - // traverses values in the range zext([0,UINT_MAX]). + // traverses values in the range zext([0,UINT_MAX]). if (isKnownNegative(Step)) { const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - getSignedRangeMin(Step)); @@ -9927,7 +9927,9 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, // Check conditions due to any @llvm.assume intrinsics. for (auto &AssumeVH : AC.assumptions()) { - auto *CI = AssumeVH.getAssumeCI(); + if (!AssumeVH) + continue; + auto *CI = cast(AssumeVH); if (!DT.dominates(CI, Latch->getTerminator())) continue; @@ -10074,7 +10076,9 @@ bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB, // Check conditions due to any @llvm.assume intrinsics. for (auto &AssumeVH : AC.assumptions()) { - auto *CI = AssumeVH.getAssumeCI(); + if (!AssumeVH) + continue; + auto *CI = cast(AssumeVH); if (!DT.dominates(CI, BB)) continue; @@ -13354,7 +13358,9 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { // Also collect information from assumptions dominating the loop. for (auto &AssumeVH : AC.assumptions()) { - auto *AssumeI = AssumeVH.getAssumeCI(); + if (!AssumeVH) + continue; + auto *AssumeI = cast(AssumeVH); auto *Cmp = dyn_cast(AssumeI->getOperand(0)); if (!Cmp || !DT.dominates(AssumeI, L->getHeader())) continue; diff --git a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp index 469060a9314..bccf94fc217 100644 --- a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp +++ b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp @@ -331,11 +331,12 @@ bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC, DT = DT_; bool Changed = false; - for (auto &AssumeVH : AC.assumptions()) { - CallInst *Call = AssumeVH.getAssumeCI(); - for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++) - Changed |= processAssumption(Call, Idx); - } + for (auto &AssumeVH : AC.assumptions()) + if (AssumeVH) { + CallInst *Call = cast(AssumeVH); + for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++) + Changed |= processAssumption(Call, Idx); + } return Changed; } diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp index ec316dec936..05189a1b6cf 100644 --- a/lib/Transforms/Utils/CodeExtractor.cpp +++ b/lib/Transforms/Utils/CodeExtractor.cpp @@ -1780,8 +1780,10 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) { bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc, const Function &NewFunc, AssumptionCache *AC) { - for (auto &AssumeVH : AC->assumptions()) { - auto *I = AssumeVH.getAssumeCI(); + for (auto AssumeVH : AC->assumptions()) { + auto *I = dyn_cast_or_null(AssumeVH); + if (!I) + continue; // There shouldn't be any llvm.assume intrinsics in the new function. if (I->getFunction() != &OldFunc) diff --git a/lib/Transforms/Utils/PredicateInfo.cpp b/lib/Transforms/Utils/PredicateInfo.cpp index fe69382b4ad..71b1926b92e 100644 --- a/lib/Transforms/Utils/PredicateInfo.cpp +++ b/lib/Transforms/Utils/PredicateInfo.cpp @@ -530,11 +530,10 @@ void PredicateInfoBuilder::buildPredicateInfo() { processSwitch(SI, BranchBB, OpsToRename); } } - for (auto &AssumeVH : AC.assumptions()) { - CallInst *AssumeCI = AssumeVH.getAssumeCI(); - if (DT.isReachableFromEntry(AssumeCI->getParent())) - processAssume(cast(AssumeCI), AssumeCI->getParent(), - OpsToRename); + for (auto &Assume : AC.assumptions()) { + if (auto *II = dyn_cast_or_null(Assume)) + if (DT.isReachableFromEntry(II->getParent())) + processAssume(II, II->getParent(), OpsToRename); } // Now rename all our operations. renameUses(OpsToRename); diff --git a/test/Analysis/AssumptionCache/basic.ll b/test/Analysis/AssumptionCache/basic.ll index 161fe10ed04..bd4e7b6449f 100644 --- a/test/Analysis/AssumptionCache/basic.ll +++ b/test/Analysis/AssumptionCache/basic.ll @@ -6,9 +6,9 @@ declare void @llvm.assume(i1) define void @test1(i32 %a) { ; CHECK-LABEL: Cached assumptions for function: test1 -; CHECK-DAG: icmp ne i32 %{{.*}}, 0 -; CHECK-DAG: icmp slt i32 %{{.*}}, 0 -; CHECK-DAG: icmp sgt i32 %{{.*}}, 0 +; CHECK-NEXT: icmp ne i32 %{{.*}}, 0 +; CHECK-NEXT: icmp slt i32 %{{.*}}, 0 +; CHECK-NEXT: icmp sgt i32 %{{.*}}, 0 entry: %cond1 = icmp ne i32 %a, 0