diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index c4e639e54e4..980c75f3073 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -12591,6 +12591,32 @@ const SCEV* ScalarEvolution::computeMaxBackedgeTakenCount(const Loop *L) { } const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { + auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS, + const SCEV *RHS, ValueToSCEVMapTy &RewriteMap) { + // For now, limit to conditions that provide information about unknown + // expressions. + auto *LHSUnknown = dyn_cast(LHS); + if (!LHSUnknown) + return; + + // TODO: use information from more predicates. + switch (Predicate) { + case CmpInst::ICMP_ULT: { + if (!containsAddRecurrence(RHS)) { + const SCEV *Base = LHS; + auto I = RewriteMap.find(LHSUnknown->getValue()); + if (I != RewriteMap.end()) + Base = I->second; + + RewriteMap[LHSUnknown->getValue()] = + getUMinExpr(Base, getMinusSCEV(RHS, getOne(RHS->getType()))); + } + break; + } + default: + break; + } + }; // Starting at the loop predecessor, climb up the predecessor chain, as long // as there are predecessors that can be found that have unique successors // leading to the original header. @@ -12613,26 +12639,8 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { auto Predicate = Cmp->getPredicate(); if (LoopEntryPredicate->getSuccessor(1) == Pair.second) Predicate = CmpInst::getInversePredicate(Predicate); - // TODO: use information from more predicates. - switch (Predicate) { - case CmpInst::ICMP_ULT: { - const SCEV *LHS = getSCEV(Cmp->getOperand(0)); - const SCEV *RHS = getSCEV(Cmp->getOperand(1)); - if (isa(LHS) && !isa(Cmp->getOperand(0)) && - !containsAddRecurrence(RHS)) { - const SCEV *Base = LHS; - auto I = RewriteMap.find(Cmp->getOperand(0)); - if (I != RewriteMap.end()) - Base = I->second; - - RewriteMap[Cmp->getOperand(0)] = - getUMinExpr(Base, getMinusSCEV(RHS, getOne(RHS->getType()))); - } - break; - } - default: - break; - } + CollectCondition(Predicate, getSCEV(Cmp->getOperand(0)), + getSCEV(Cmp->getOperand(1)), RewriteMap); } if (RewriteMap.empty())