diff --git a/include/llvm/ProfileData/InstrProf.h b/include/llvm/ProfileData/InstrProf.h index 13f6c70b3e8..95648511910 100644 --- a/include/llvm/ProfileData/InstrProf.h +++ b/include/llvm/ProfileData/InstrProf.h @@ -428,19 +428,22 @@ instrprof_error InstrProfRecord::merge(InstrProfRecord &Other) { if (Counts.size() != Other.Counts.size()) return instrprof_error::count_mismatch; + instrprof_error Result = instrprof_error::success; + for (size_t I = 0, E = Other.Counts.size(); I < E; ++I) { - if (Counts[I] + Other.Counts[I] < Counts[I]) - return instrprof_error::counter_overflow; - Counts[I] += Other.Counts[I]; + bool ResultOverflowed; + Counts[I] = SaturatingAdd(Counts[I], Other.Counts[I], ResultOverflowed); + if (ResultOverflowed) + Result = instrprof_error::counter_overflow; } for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind) { - instrprof_error result = mergeValueProfData(Kind, Other); - if (result != instrprof_error::success) - return result; + instrprof_error MergeValueResult = mergeValueProfData(Kind, Other); + if (MergeValueResult != instrprof_error::success) + Result = MergeValueResult; } - return instrprof_error::success; + return Result; } inline support::endianness getHostEndianness() { diff --git a/lib/ProfileData/InstrProfWriter.cpp b/lib/ProfileData/InstrProfWriter.cpp index f9cc2afe3da..78bec012eeb 100644 --- a/lib/ProfileData/InstrProfWriter.cpp +++ b/lib/ProfileData/InstrProfWriter.cpp @@ -107,22 +107,23 @@ std::error_code InstrProfWriter::addRecord(InstrProfRecord &&I) { std::tie(Where, NewFunc) = ProfileDataMap.insert(std::make_pair(I.Hash, InstrProfRecord())); InstrProfRecord &Dest = Where->second; + + instrprof_error Result; if (NewFunc) { // We've never seen a function with this name and hash, add it. Dest = std::move(I); + Result = instrprof_error::success; } else { // We're updating a function we've seen before. - instrprof_error MergeResult = Dest.merge(I); - if (MergeResult != instrprof_error::success) { - return MergeResult; - } + Result = Dest.merge(I); } // We keep track of the max function count as we go for simplicity. + // Update this statistic no matter the result of the merge. if (Dest.Counts[0] > MaxFunctionCount) MaxFunctionCount = Dest.Counts[0]; - return instrprof_error::success; + return Result; } std::pair InstrProfWriter::writeImpl(raw_ostream &OS) { diff --git a/test/tools/llvm-profdata/overflow.proftext b/test/tools/llvm-profdata/overflow.proftext index cbf3bf16182..b8401ffd84d 100644 --- a/test/tools/llvm-profdata/overflow.proftext +++ b/test/tools/llvm-profdata/overflow.proftext @@ -1,12 +1,20 @@ -# RUN: llvm-profdata merge %s -o %t.out 2>&1 | FileCheck %s -# CHECK: overflow.proftext: overflow: Counter overflow +# RUN: llvm-profdata merge %s -o %t.out 2>&1 | FileCheck %s --check-prefix=MERGE +# RUN: llvm-profdata show %t.out | FileCheck %s --check-prefix=SHOW +# MERGE: overflow.proftext: overflow: Counter overflow +# SHOW: Total functions: 1 +# SHOW: Maximum function count: 18446744073709551615 +# SHOW: Maximum internal block count: 18446744073709551615 overflow 1 -1 +3 +18446744073709551615 9223372036854775808 +18446744073709551615 overflow 1 -1 +3 9223372036854775808 +9223372036854775808 +0 diff --git a/unittests/ProfileData/InstrProfTest.cpp b/unittests/ProfileData/InstrProfTest.cpp index 2f3adb65a0e..635a5431a51 100644 --- a/unittests/ProfileData/InstrProfTest.cpp +++ b/unittests/ProfileData/InstrProfTest.cpp @@ -354,7 +354,7 @@ TEST_F(InstrProfTest, get_icall_data_merge1_saturation) { const uint64_t Max = std::numeric_limits::max(); InstrProfRecord Record1("caller", 0x1234, {1}); - InstrProfRecord Record2("caller", 0x1234, {1}); + InstrProfRecord Record2("caller", 0x1234, {Max}); InstrProfRecord Record3("callee1", 0x1235, {3, 4}); Record1.reserveSites(IPVK_IndirectCallTarget, 1); @@ -375,6 +375,9 @@ TEST_F(InstrProfTest, get_icall_data_merge1_saturation) { // Verify saturation of counts. ErrorOr R = Reader->getInstrProfRecord("caller", 0x1234); ASSERT_TRUE(NoError(R.getError())); + + ASSERT_EQ(Max, R.get().Counts[0]); + ASSERT_EQ(1U, R.get().getNumValueSites(IPVK_IndirectCallTarget)); std::unique_ptr VD = R.get().getValueForSite(IPVK_IndirectCallTarget, 0);