1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-23 03:02:36 +01:00

Update MachineBranchProbabilityInfo::normalizeEdgeWeights to make sure there is no zero weight in the output, and also add a missing test for JumpThreading.

The test covers the patch in http://reviews.llvm.org/D10979 but was omitted when that patch was committed.

llvm-svn: 250240
This commit is contained in:
Cong Hou 2015-10-13 22:27:41 +00:00
parent 61ae84603b
commit b3e26fc217
2 changed files with 72 additions and 21 deletions

View File

@ -86,35 +86,43 @@ public:
const MachineBasicBlock *Dst) const;
// Normalize a list of weights by scaling them down so that the sum of them
// doesn't exceed UINT32_MAX. Return the scale.
// doesn't exceed UINT32_MAX.
template <class WeightListIter>
static uint32_t normalizeEdgeWeights(WeightListIter Begin,
WeightListIter End);
static void normalizeEdgeWeights(WeightListIter Begin, WeightListIter End);
};
// Out-of-class template definition of
// MachineBranchProbabilityInfo::normalizeEdgeWeights, which rewrites the
// weights in [Begin, End) in place.
// NOTE(review): this span is a rendered diff hunk without +/- markers, so two
// versions of the function appear interleaved (one returning uint32_t, one
// returning void) and the text does not compile as shown. Judging by the
// commit description, the void version with zero-weight elimination is the
// post-commit one -- confirm against the repository before reusing any of it.
template <class WeightListIter>
// Signature variant that returns the applied scale (presumably the removed one).
uint32_t
MachineBranchProbabilityInfo::normalizeEdgeWeights(WeightListIter Begin,
WeightListIter End) {
// Signature variant that returns nothing (presumably the added one).
void MachineBranchProbabilityInfo::normalizeEdgeWeights(WeightListIter Begin,
WeightListIter End) {
// First we compute the sum with 64-bits of precision.
uint64_t Sum = std::accumulate(Begin, End, uint64_t(0));
// If Sum is zero, set all weights to 1.
if (Sum == 0)
std::fill(Begin, End, uint64_t(1));
// Scale-down step of the void variant: runs only when the 64-bit total does
// not fit in 32 bits, then recomputes Sum from the shrunken weights.
if (Sum > UINT32_MAX) {
// Compute the scale necessary to cause the weights to fit, and re-sum with
// that scale applied.
assert(Sum / UINT32_MAX < UINT32_MAX &&
"The sum of weights exceeds UINT32_MAX^2!");
// The +1 rounds the divisor up so the divided weights total at most
// UINT32_MAX; integer division may turn small weights into zero.
uint32_t Scale = Sum / UINT32_MAX + 1;
for (auto I = Begin; I != End; ++I)
*I /= Scale;
Sum = std::accumulate(Begin, End, uint64_t(0));
}
// Scale-down step of the uint32_t variant: early-out with a scale of 1 when
// the total already fits, otherwise divide every weight and return the divisor.
// If the computed sum fits in 32-bits, we're done.
if (Sum <= UINT32_MAX)
return 1;
// Otherwise, compute the scale necessary to cause the weights to fit, and
// re-sum with that scale applied.
assert((Sum / UINT32_MAX) < UINT32_MAX &&
"The sum of weights exceeds UINT32_MAX^2!");
uint32_t Scale = (Sum / UINT32_MAX) + 1;
for (auto I = Begin; I != End; ++I)
*I /= Scale;
return Scale;
// Eliminate zero weights.
auto ZeroWeightNum = std::count(Begin, End, 0u);
if (ZeroWeightNum > 0) {
// If all weights are zeros, replace them by 1.
if (Sum == 0)
std::fill(Begin, End, 1u);
else {
// Scale up non-zero weights and turn zero weights into ones.
// The factor is chosen so that ScalingFactor * Sum + ZeroWeightNum stays
// within UINT32_MAX: the zeros are replaced by ones after the multiply.
uint64_t ScalingFactor = (UINT32_MAX - ZeroWeightNum) / Sum;
// A factor of 0 or 1 leaves the weights untouched.
if (ScalingFactor > 1)
for (auto I = Begin; I != End; ++I)
*I *= ScalingFactor;
// NOTE(review): when no up-scaling happened, replacing zeros with ones still
// raises the total by ZeroWeightNum; with Sum close to UINT32_MAX the result
// could exceed UINT32_MAX -- confirm whether callers can reach this case.
std::replace(Begin, End, 0u, 1u);
}
}
}
}

View File

@ -0,0 +1,43 @@
; RUN: opt -S -jump-threading %s | FileCheck %s
; Test if edge weights are properly updated after jump threading.
; CHECK: !2 = !{!"branch_weights", i32 22, i32 7}
; @foo makes two consecutive conditional branches on the same argument %n,
; each annotated with branch-weight profile metadata. Since %n > 10 implies
; %n > 5, the outcome of the second comparison is already known on the path
; through if.then.1 -- presumably the opportunity the jump-threading pass
; takes here, after which the weights on the remaining branch must be updated.
define void @foo(i32 %n) !prof !0 {
entry:
; First decision: signed %n > 10, profiled by !1.
%cmp = icmp sgt i32 %n, 10
br i1 %cmp, label %if.then.1, label %if.else.1, !prof !1
; Reached when %n > 10; falls through to the shared second test.
if.then.1:
tail call void @a()
br label %if.cond
; Reached when %n <= 10; also falls through to the shared second test.
if.else.1:
tail call void @b()
br label %if.cond
; Second decision: signed %n > 5, profiled by !2. Both arms of the first
; branch join here, so this is the block whose incoming paths can be split.
if.cond:
%cmp1 = icmp sgt i32 %n, 5
br i1 %cmp1, label %if.then.2, label %if.else.2, !prof !2
if.then.2:
tail call void @c()
br label %if.end
if.else.2:
tail call void @d()
br label %if.end
; Common exit for both arms of the second branch.
if.end:
ret void
}
; External callees with no bodies; each one is called from exactly one block
; of @foo (@a: if.then.1, @b: if.else.1, @c: if.then.2, @d: if.else.2).
declare void @a()
declare void @b()
declare void @c()
declare void @d()
; Profile metadata attached through !prof in @foo.
; !0: function entry count for @foo.
!0 = !{!"function_entry_count", i64 1}
; !1: weights on the first branch (entry), 10 and 5.
!1 = !{!"branch_weights", i32 10, i32 5}
; !2: weights on the second branch (if.cond) as written in the input, 10 and 1.
; The expected-output line at the top of this test requires !2 to read 22 and 7
; in the output of opt, i.e. after the pass has adjusted the weights.
!2 = !{!"branch_weights", i32 10, i32 1}