From 54b4ece1ffc2bcd8ab36a7faffbdf23c1d990f0e Mon Sep 17 00:00:00 2001 From: Philip Reames Date: Tue, 9 Mar 2021 12:35:16 -0800 Subject: [PATCH] [SCEV] Infer known bits from known sign bits This was suggested by lebedev.ri over on D96534. You'll note lack of tests. During review, we weren't actually able to find a case which exercises it, but both I and lebedev.ri feel it's a reasonable change, straight forward, and near free. Differential Revision: https://reviews.llvm.org/D97064 --- lib/Analysis/ScalarEvolution.cpp | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index b8d55e6eb68..7ffb3846738 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -5848,24 +5848,31 @@ ScalarEvolution::getRangeRef(const SCEV *S, KnownBits Known = computeKnownBits(U->getValue(), DL, 0, &AC, nullptr, &DT); if (Known.getBitWidth() != BitWidth) Known = Known.zextOrTrunc(BitWidth); - // If Known does not result in full-set, intersect with it. - if (Known.getMinValue() != Known.getMaxValue() + 1) - ConservativeResult = ConservativeResult.intersectWith( - ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1), - RangeType); // ValueTracking may be able to compute a tighter result for the number of // sign bits than for the value of those sign bits. unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, &AC, nullptr, &DT); - // If the pointer size is larger than the index size type, this can cause - // NS to be larger than BitWidth. So compensate for this. if (U->getType()->isPointerTy()) { + // If the pointer size is larger than the index size type, this can cause + // NS to be larger than BitWidth. So compensate for this. unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType()); int ptrIdxDiff = ptrSize - BitWidth; if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff) NS -= ptrIdxDiff; } + if (NS > 1) { + // If we know any of the sign bits, we know all of the sign bits. + if (!Known.Zero.getHiBits(NS).isNullValue()) + Known.Zero.setHighBits(NS); + if (!Known.One.getHiBits(NS).isNullValue()) + Known.One.setHighBits(NS); + } + + if (Known.getMinValue() != Known.getMaxValue() + 1) + ConservativeResult = ConservativeResult.intersectWith( + ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1), + RangeType); if (NS > 1) ConservativeResult = ConservativeResult.intersectWith( ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),