1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-10-19 11:02:59 +02:00

[SCEV] Do not cache comparison result upon reached max depth as "equivalence". PR48725

We use `EquivalenceClasses` to cache the notion that two SCEVs are equivalent,
so save time in situation when `A` is equivalent to `B` and `B` is equivalent to `C`,
making check "if `A` is equivalent to `C`?" cheaper.

We also return `0` in the comparator when we reach max analysis depth to save
compile time. After doing this, we also cache them as being equivalent.

Now, imagine the following situation:
- `A` is proved equivalent to `B`;
- `C` is proved equivalent to `D`;
- Comparison of `A` against `D` is proved non-zero;
- Comparison of `B` against `C` reaches max depth (and gets cached as equivalence).

Now, before the invocation of compare(`B`, `C`), `A` and `D` belonged
to different equivalence classes, and their comparison returned non-zero.
After the the invocation of compare(`B`, `C`), equivalence classes get merged
and `A`, `B`, `C` and `D` all fall into the same equivalence class. So the comparator
will change its behavior for couple `A` and `D`, with weird consequences following it.
This comparator is finally used in `std::stable_sort`, and this behavior change
makes it crash (looks like it's causing a memory corruption).

Solution: this patch changes `CompareSCEVComplexity` to return `None`
when the max depth is reached. So in this case, we do not cache these SCEVs
(and their parents in the tree) as being equivalent.

Differential Revision: https://reviews.llvm.org/D94654
Reviewed By: lebedev.ri
This commit is contained in:
Max Kazantsev 2021-01-29 12:08:34 +07:00
parent e86b7c773c
commit 6a837450cc
2 changed files with 133 additions and 20 deletions

View File

@ -689,11 +689,13 @@ CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
// Return negative, zero, or positive, if LHS is less than, equal to, or greater
// than RHS, respectively. A three-way result allows recursive comparisons to be
// more efficient.
static int CompareSCEVComplexity(
EquivalenceClasses<const SCEV *> &EqCacheSCEV,
EquivalenceClasses<const Value *> &EqCacheValue,
const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS,
DominatorTree &DT, unsigned Depth = 0) {
// If the max analysis depth was reached, return None, assuming we do not know
// if they are equivalent for sure.
static Optional<int>
CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
EquivalenceClasses<const Value *> &EqCacheValue,
const LoopInfo *const LI, const SCEV *LHS,
const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
// Fast-path: SCEVs are uniqued so we can do a quick equality check.
if (LHS == RHS)
return 0;
@ -703,8 +705,12 @@ static int CompareSCEVComplexity(
if (LType != RType)
return (int)LType - (int)RType;
if (Depth > MaxSCEVCompareDepth || EqCacheSCEV.isEquivalent(LHS, RHS))
if (EqCacheSCEV.isEquivalent(LHS, RHS))
return 0;
if (Depth > MaxSCEVCompareDepth)
return None;
// Aside from the getSCEVType() ordering, the particular ordering
// isn't very important except that it's beneficial to be consistent,
// so that (a + b) and (b + a) don't end up as different expressions.
@ -759,9 +765,9 @@ static int CompareSCEVComplexity(
// Lexicographically compare.
for (unsigned i = 0; i != LNumOps; ++i) {
int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
LA->getOperand(i), RA->getOperand(i), DT,
Depth + 1);
auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
LA->getOperand(i), RA->getOperand(i), DT,
Depth + 1);
if (X != 0)
return X;
}
@ -784,9 +790,9 @@ static int CompareSCEVComplexity(
return (int)LNumOps - (int)RNumOps;
for (unsigned i = 0; i != LNumOps; ++i) {
int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
LC->getOperand(i), RC->getOperand(i), DT,
Depth + 1);
auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
LC->getOperand(i), RC->getOperand(i), DT,
Depth + 1);
if (X != 0)
return X;
}
@ -799,8 +805,8 @@ static int CompareSCEVComplexity(
const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
// Lexicographically compare udiv expressions.
int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
RC->getLHS(), DT, Depth + 1);
auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
RC->getLHS(), DT, Depth + 1);
if (X != 0)
return X;
X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
@ -818,9 +824,9 @@ static int CompareSCEVComplexity(
const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
// Compare cast expressions by operand.
int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
LC->getOperand(), RC->getOperand(), DT,
Depth + 1);
auto X =
CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getOperand(),
RC->getOperand(), DT, Depth + 1);
if (X == 0)
EqCacheSCEV.unionSets(LHS, RHS);
return X;
@ -847,19 +853,25 @@ static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
EquivalenceClasses<const SCEV *> EqCacheSCEV;
EquivalenceClasses<const Value *> EqCacheValue;
// Whether LHS has provably less complexity than RHS.
auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
auto Complexity =
CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
return Complexity && *Complexity < 0;
};
if (Ops.size() == 2) {
// This is the common case, which also happens to be trivially simple.
// Special case it.
const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
if (CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, RHS, LHS, DT) < 0)
if (IsLessComplex(RHS, LHS))
std::swap(LHS, RHS);
return;
}
// Do the rough sort by complexity.
llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
return CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT) <
0;
return IsLessComplex(LHS, RHS);
});
// Now that we are sorted by complexity, group elements of the same

View File

@ -0,0 +1,101 @@
; RUN: opt -S -loop-reduce < %s | FileCheck %s
source_filename = "./simple.ll"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:1-p2:32:8:8:32-ni:2"
target triple = "x86_64-unknown-linux-gnu"
; CHECK-LABEL: test
define void @test() {
bb:
br label %bb1
bb1: ; preds = %bb1, %bb
%tmp = phi i32 [ undef, %bb ], [ %tmp87, %bb1 ]
%tmp2 = phi i32 [ undef, %bb ], [ %tmp86, %bb1 ]
%tmp3 = mul i32 %tmp, undef
%tmp4 = xor i32 %tmp3, -1
%tmp5 = add i32 %tmp, %tmp4
%tmp6 = add i32 %tmp2, -1
%tmp7 = add i32 %tmp5, %tmp6
%tmp8 = mul i32 %tmp7, %tmp3
%tmp9 = xor i32 %tmp8, -1
%tmp10 = add i32 %tmp7, %tmp9
%tmp11 = add i32 %tmp10, undef
%tmp12 = mul i32 %tmp11, %tmp8
%tmp13 = xor i32 %tmp12, -1
%tmp14 = add i32 %tmp11, %tmp13
%tmp15 = add i32 %tmp14, undef
%tmp16 = mul i32 %tmp15, %tmp12
%tmp17 = add i32 %tmp15, undef
%tmp18 = add i32 %tmp17, undef
%tmp19 = mul i32 %tmp18, %tmp16
%tmp20 = xor i32 %tmp19, -1
%tmp21 = add i32 %tmp18, %tmp20
%tmp22 = add i32 %tmp21, undef
%tmp23 = mul i32 %tmp22, %tmp19
%tmp24 = xor i32 %tmp23, -1
%tmp25 = add i32 %tmp22, %tmp24
%tmp26 = add i32 %tmp25, undef
%tmp27 = mul i32 %tmp26, %tmp23
%tmp28 = xor i32 %tmp27, -1
%tmp29 = add i32 %tmp26, %tmp28
%tmp30 = add i32 %tmp29, undef
%tmp31 = mul i32 %tmp30, %tmp27
%tmp32 = xor i32 %tmp31, -1
%tmp33 = add i32 %tmp30, %tmp32
%tmp34 = add i32 %tmp33, undef
%tmp35 = mul i32 %tmp34, %tmp31
%tmp36 = xor i32 %tmp35, -1
%tmp37 = add i32 %tmp34, %tmp36
%tmp38 = add i32 %tmp2, -9
%tmp39 = add i32 %tmp37, %tmp38
%tmp40 = mul i32 %tmp39, %tmp35
%tmp41 = xor i32 %tmp40, -1
%tmp42 = add i32 %tmp39, %tmp41
%tmp43 = add i32 %tmp42, undef
%tmp44 = mul i32 %tmp43, %tmp40
%tmp45 = xor i32 %tmp44, -1
%tmp46 = add i32 %tmp43, %tmp45
%tmp47 = add i32 %tmp46, undef
%tmp48 = mul i32 %tmp47, %tmp44
%tmp49 = xor i32 %tmp48, -1
%tmp50 = add i32 %tmp47, %tmp49
%tmp51 = add i32 %tmp50, undef
%tmp52 = mul i32 %tmp51, %tmp48
%tmp53 = xor i32 %tmp52, -1
%tmp54 = add i32 %tmp51, %tmp53
%tmp55 = add i32 %tmp54, undef
%tmp56 = mul i32 %tmp55, %tmp52
%tmp57 = xor i32 %tmp56, -1
%tmp58 = add i32 %tmp55, %tmp57
%tmp59 = add i32 %tmp2, -14
%tmp60 = add i32 %tmp58, %tmp59
%tmp61 = mul i32 %tmp60, %tmp56
%tmp62 = xor i32 %tmp61, -1
%tmp63 = add i32 %tmp60, %tmp62
%tmp64 = add i32 %tmp63, undef
%tmp65 = mul i32 %tmp64, %tmp61
%tmp66 = xor i32 %tmp65, -1
%tmp67 = add i32 %tmp64, %tmp66
%tmp68 = add i32 %tmp67, undef
%tmp69 = mul i32 %tmp68, %tmp65
%tmp70 = xor i32 %tmp69, -1
%tmp71 = add i32 %tmp68, %tmp70
%tmp72 = add i32 %tmp71, undef
%tmp73 = mul i32 %tmp72, %tmp69
%tmp74 = xor i32 %tmp73, -1
%tmp75 = add i32 %tmp72, %tmp74
%tmp76 = add i32 %tmp75, undef
%tmp77 = mul i32 %tmp76, %tmp73
%tmp78 = xor i32 %tmp77, -1
%tmp79 = add i32 %tmp76, %tmp78
%tmp80 = add i32 %tmp79, undef
%tmp81 = mul i32 %tmp80, %tmp77
%tmp82 = xor i32 %tmp81, -1
%tmp83 = add i32 %tmp80, %tmp82
%tmp84 = add i32 %tmp83, undef
%tmp85 = add i32 %tmp84, undef
%tmp86 = add i32 %tmp2, -21
%tmp87 = add i32 %tmp85, %tmp86
br label %bb1
}