From dd8e45c6d7404e408cbbe779f96c50b27b05b231 Mon Sep 17 00:00:00 2001 From: "Duncan P. N. Exon Smith" Date: Mon, 23 Jun 2014 20:40:45 +0000 Subject: [PATCH] Support: Extract ScaledNumbers::matchScale() llvm-svn: 211531 --- .../llvm/Analysis/BlockFrequencyInfoImpl.h | 50 ++------------- include/llvm/Support/ScaledNumber.h | 51 +++++++++++++++ unittests/Support/ScaledNumberTest.cpp | 64 +++++++++++++++++++ 3 files changed, 119 insertions(+), 46 deletions(-) diff --git a/include/llvm/Analysis/BlockFrequencyInfoImpl.h b/include/llvm/Analysis/BlockFrequencyInfoImpl.h index 054f26b1824..b462474adf4 100644 --- a/include/llvm/Analysis/BlockFrequencyInfoImpl.h +++ b/include/llvm/Analysis/BlockFrequencyInfoImpl.h @@ -247,14 +247,10 @@ private: /// /// The value that compares smaller will lose precision, and possibly become /// \a isZero(). - UnsignedFloat matchExponents(UnsignedFloat X); - - /// \brief Increase exponent to match another float. - /// - /// Increases \c this to have an exponent matching \c X. May decrease the - /// exponent of \c X in the process, and \c this may possibly become \a - /// isZero(). - void increaseExponentToMatch(UnsignedFloat &X, int32_t ExponentDiff); + UnsignedFloat matchExponents(UnsignedFloat X) { + ScaledNumbers::matchScales(Digits, Exponent, X.Digits, X.Exponent); + return X; + } public: /// \brief Scale a large number accurately. @@ -403,44 +399,6 @@ IntT UnsignedFloat::toInt() const { return N; } -template -UnsignedFloat UnsignedFloat::matchExponents(UnsignedFloat X) { - if (isZero() || X.isZero() || Exponent == X.Exponent) - return X; - - int32_t Diff = int32_t(X.Exponent) - int32_t(Exponent); - if (Diff > 0) - increaseExponentToMatch(X, Diff); - else - X.increaseExponentToMatch(*this, -Diff); - return X; -} -template -void UnsignedFloat::increaseExponentToMatch(UnsignedFloat &X, - int32_t ExponentDiff) { - assert(ExponentDiff > 0); - if (ExponentDiff >= 2 * Width) { - *this = getZero(); - return; - } - - // Use up any leading zeros on X, and then shift this. - int32_t ShiftX = std::min(countLeadingZerosWidth(X.Digits), ExponentDiff); - assert(ShiftX < Width); - - int32_t ShiftThis = ExponentDiff - ShiftX; - if (ShiftThis >= Width) { - *this = getZero(); - return; - } - - X.Digits <<= ShiftX; - X.Exponent -= ShiftX; - Digits >>= ShiftThis; - Exponent += ShiftThis; - return; -} - template UnsignedFloat &UnsignedFloat:: operator+=(const UnsignedFloat &X) { diff --git a/include/llvm/Support/ScaledNumber.h b/include/llvm/Support/ScaledNumber.h index 240d5b64aa8..b1f6eadab84 100644 --- a/include/llvm/Support/ScaledNumber.h +++ b/include/llvm/Support/ScaledNumber.h @@ -24,6 +24,7 @@ #include "llvm/Support/MathExtras.h" +#include #include #include #include @@ -263,6 +264,56 @@ int compare(DigitsT LDigits, int16_t LScale, DigitsT RDigits, int16_t RScale) { return -compareImpl(RDigits, LDigits, LScale - RScale); } +/// \brief Match scales of two numbers. +/// +/// Given two scaled numbers, match up their scales. Change the digits and +/// scales in place. Shift the digits as necessary to form equivalent numbers, +/// losing precision only when necessary. +/// +/// If the output value of \c LDigits (\c RDigits) is \c 0, the output value of +/// \c LScale (\c RScale) is unspecified. If both \c LDigits and \c RDigits +/// are \c 0, the output value is one of \c LScale and \c RScale; which is +/// unspecified. +template +void matchScales(DigitsT &LDigits, int16_t &LScale, DigitsT &RDigits, + int16_t &RScale) { + static_assert(!std::numeric_limits::is_signed, "expected unsigned"); + + if (LScale < RScale) { + // Swap arguments. + matchScales(RDigits, RScale, LDigits, LScale); + return; + } + if (!LDigits || !RDigits || LScale == RScale) + return; + + // Now LScale > RScale. Get the difference. + int32_t ScaleDiff = int32_t(LScale) - RScale; + if (ScaleDiff >= 2 * getWidth()) { + // Don't bother shifting. RDigits will get zero-ed out anyway. + RDigits = 0; + return; + } + + // Shift LDigits left as much as possible, then shift RDigits right. + int32_t ShiftL = std::min(countLeadingZeros(LDigits), ScaleDiff); + assert(ShiftL < getWidth() && "can't shift more than width"); + + int32_t ShiftR = ScaleDiff - ShiftL; + if (ShiftR >= getWidth()) { + // Don't bother shifting. RDigits will get zero-ed out anyway. + RDigits = 0; + return; + } + + LDigits <<= ShiftL; + RDigits >>= ShiftR; + + LScale -= ShiftL; + RScale += ShiftR; + assert(LScale == RScale && "scales should match"); +} + } // end namespace ScaledNumbers } // end namespace llvm diff --git a/unittests/Support/ScaledNumberTest.cpp b/unittests/Support/ScaledNumberTest.cpp index 4a274d7e50e..550947b8ea3 100644 --- a/unittests/Support/ScaledNumberTest.cpp +++ b/unittests/Support/ScaledNumberTest.cpp @@ -322,4 +322,68 @@ TEST(ScaledNumberHelpersTest, compare) { EXPECT_EQ(-1, compare(UINT64_MAX, 0, UINT64_C(1), 64)); } +TEST(ScaledNumberHelpersTest, matchScales) { + typedef std::tuple Pair32; + typedef std::tuple Pair64; + +#define MATCH_SCALES(T, LDIn, LSIn, RDIn, RSIn, LDOut, RDOut, SOut) \ + do { \ + T LDx = LDIn; \ + T RDx = RDIn; \ + T LDy = LDOut; \ + T RDy = RDOut; \ + int16_t LSx = LSIn; \ + int16_t RSx = RSIn; \ + int16_t Sy = SOut; \ + \ + matchScales(LDx, LSx, RDx, RSx); \ + EXPECT_EQ(LDy, LDx); \ + EXPECT_EQ(RDy, RDx); \ + if (LDy) \ + EXPECT_EQ(Sy, LSx); \ + if (RDy) \ + EXPECT_EQ(Sy, RSx); \ + } while (false) + + MATCH_SCALES(uint32_t, 0, 0, 0, 0, 0, 0, 0); + MATCH_SCALES(uint32_t, 0, 50, 7, 1, 0, 7, 1); + MATCH_SCALES(uint32_t, UINT32_C(1) << 31, 1, 9, 0, UINT32_C(1) << 31, 4, 1); + MATCH_SCALES(uint32_t, UINT32_C(1) << 31, 2, 9, 0, UINT32_C(1) << 31, 2, 2); + MATCH_SCALES(uint32_t, UINT32_C(1) << 31, 3, 9, 0, UINT32_C(1) << 31, 1, 3); + MATCH_SCALES(uint32_t, UINT32_C(1) << 31, 4, 9, 0, UINT32_C(1) << 31, 0, 4); + MATCH_SCALES(uint32_t, UINT32_C(1) << 30, 4, 9, 0, UINT32_C(1) << 31, 1, 3); + MATCH_SCALES(uint32_t, UINT32_C(1) << 29, 4, 9, 0, UINT32_C(1) << 31, 2, 2); + MATCH_SCALES(uint32_t, UINT32_C(1) << 28, 4, 9, 0, UINT32_C(1) << 31, 4, 1); + MATCH_SCALES(uint32_t, UINT32_C(1) << 27, 4, 9, 0, UINT32_C(1) << 31, 9, 0); + MATCH_SCALES(uint32_t, 7, 1, 0, 50, 7, 0, 1); + MATCH_SCALES(uint32_t, 9, 0, UINT32_C(1) << 31, 1, 4, UINT32_C(1) << 31, 1); + MATCH_SCALES(uint32_t, 9, 0, UINT32_C(1) << 31, 2, 2, UINT32_C(1) << 31, 2); + MATCH_SCALES(uint32_t, 9, 0, UINT32_C(1) << 31, 3, 1, UINT32_C(1) << 31, 3); + MATCH_SCALES(uint32_t, 9, 0, UINT32_C(1) << 31, 4, 0, UINT32_C(1) << 31, 4); + MATCH_SCALES(uint32_t, 9, 0, UINT32_C(1) << 30, 4, 1, UINT32_C(1) << 31, 3); + MATCH_SCALES(uint32_t, 9, 0, UINT32_C(1) << 29, 4, 2, UINT32_C(1) << 31, 2); + MATCH_SCALES(uint32_t, 9, 0, UINT32_C(1) << 28, 4, 4, UINT32_C(1) << 31, 1); + MATCH_SCALES(uint32_t, 9, 0, UINT32_C(1) << 27, 4, 9, UINT32_C(1) << 31, 0); + + MATCH_SCALES(uint64_t, 0, 0, 0, 0, 0, 0, 0); + MATCH_SCALES(uint64_t, 0, 100, 7, 1, 0, 7, 1); + MATCH_SCALES(uint64_t, UINT64_C(1) << 63, 1, 9, 0, UINT64_C(1) << 63, 4, 1); + MATCH_SCALES(uint64_t, UINT64_C(1) << 63, 2, 9, 0, UINT64_C(1) << 63, 2, 2); + MATCH_SCALES(uint64_t, UINT64_C(1) << 63, 3, 9, 0, UINT64_C(1) << 63, 1, 3); + MATCH_SCALES(uint64_t, UINT64_C(1) << 63, 4, 9, 0, UINT64_C(1) << 63, 0, 4); + MATCH_SCALES(uint64_t, UINT64_C(1) << 62, 4, 9, 0, UINT64_C(1) << 63, 1, 3); + MATCH_SCALES(uint64_t, UINT64_C(1) << 61, 4, 9, 0, UINT64_C(1) << 63, 2, 2); + MATCH_SCALES(uint64_t, UINT64_C(1) << 60, 4, 9, 0, UINT64_C(1) << 63, 4, 1); + MATCH_SCALES(uint64_t, UINT64_C(1) << 59, 4, 9, 0, UINT64_C(1) << 63, 9, 0); + MATCH_SCALES(uint64_t, 7, 1, 0, 100, 7, 0, 1); + MATCH_SCALES(uint64_t, 9, 0, UINT64_C(1) << 63, 1, 4, UINT64_C(1) << 63, 1); + MATCH_SCALES(uint64_t, 9, 0, UINT64_C(1) << 63, 2, 2, UINT64_C(1) << 63, 2); + MATCH_SCALES(uint64_t, 9, 0, UINT64_C(1) << 63, 3, 1, UINT64_C(1) << 63, 3); + MATCH_SCALES(uint64_t, 9, 0, UINT64_C(1) << 63, 4, 0, UINT64_C(1) << 63, 4); + MATCH_SCALES(uint64_t, 9, 0, UINT64_C(1) << 62, 4, 1, UINT64_C(1) << 63, 3); + MATCH_SCALES(uint64_t, 9, 0, UINT64_C(1) << 61, 4, 2, UINT64_C(1) << 63, 2); + MATCH_SCALES(uint64_t, 9, 0, UINT64_C(1) << 60, 4, 4, UINT64_C(1) << 63, 1); + MATCH_SCALES(uint64_t, 9, 0, UINT64_C(1) << 59, 4, 9, UINT64_C(1) << 63, 0); +} + } // end namespace