1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-25 04:02:41 +01:00

[InstCombine] avoid infinite loops from min/max canonicalization

The intrinsics have an extra chunk of known bits logic
compared to the normal cmp+select idiom. That allows
folding the icmp in each case to something better, but
that then opposes the canonical form of min/max that
we try to form for a select.

I'm carving out a narrow exception to preserve all
existing regression tests while avoiding the inf-loop.
It seems unlikely that this is the only bug like this
left, but this should fix:
https://llvm.org/PR51419

(cherry picked from commit b267d3ce8defa092600bda717ff18440d002f316)
This commit is contained in:
Sanjay Patel 2021-08-10 14:30:43 -04:00 committed by Tom Stellard
parent 5a14ea148e
commit 4345e14522
2 changed files with 138 additions and 46 deletions

View File

@ -5158,6 +5158,83 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
if (!isa<Constant>(Op1) && Op1Min == Op1Max) if (!isa<Constant>(Op1) && Op1Min == Op1Max)
return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min)); return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min));
// Don't break up a clamp pattern -- (min(max X, Y), Z) -- by replacing a
// min/max canonical compare with some other compare. That could lead to
// conflict with select canonicalization and infinite looping.
// FIXME: This constraint may go away if min/max intrinsics are canonical.
auto isMinMaxCmp = [&](Instruction &Cmp) {
if (!Cmp.hasOneUse())
return false;
Value *A, *B;
SelectPatternFlavor SPF = matchSelectPattern(Cmp.user_back(), A, B).Flavor;
if (!SelectPatternResult::isMinOrMax(SPF))
return false;
return match(Op0, m_MaxOrMin(m_Value(), m_Value())) ||
match(Op1, m_MaxOrMin(m_Value(), m_Value()));
};
// Folds that use the known min/max of each operand to turn an ordered compare
// into an equality/inequality compare. They are skipped when isMinMaxCmp
// reports that the compare belongs to a min/max select pattern, so that the
// min/max canonical compare is left untouched.
if (!isMinMaxCmp(I)) {
switch (Pred) {
default:
break;
case ICmpInst::ICMP_ULT: {
if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
const APInt *CmpC;
if (match(Op1, m_APInt(CmpC))) {
// A <u C -> A == C-1 if min(A)+1 == C
if (*CmpC == Op0Min + 1)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
ConstantInt::get(Op1->getType(), *CmpC - 1));
// X <u C --> X == 0, if the number of zero bits in the bottom of X
// exceeds the log2 of C.
if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
Constant::getNullValue(Op1->getType()));
}
break;
}
case ICmpInst::ICMP_UGT: {
if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
const APInt *CmpC;
if (match(Op1, m_APInt(CmpC))) {
// A >u C -> A == C+1 if max(A)-1 == C
if (*CmpC == Op0Max - 1)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
ConstantInt::get(Op1->getType(), *CmpC + 1));
// X >u C --> X != 0, if the number of zero bits in the bottom of X
// exceeds the log2 of C.
if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
return new ICmpInst(ICmpInst::ICMP_NE, Op0,
Constant::getNullValue(Op1->getType()));
}
break;
}
case ICmpInst::ICMP_SLT: {
if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
const APInt *CmpC;
if (match(Op1, m_APInt(CmpC))) {
if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
ConstantInt::get(Op1->getType(), *CmpC - 1));
}
break;
}
case ICmpInst::ICMP_SGT: {
if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
const APInt *CmpC;
if (match(Op1, m_APInt(CmpC))) {
if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
ConstantInt::get(Op1->getType(), *CmpC + 1));
}
break;
}
}
}
// Based on the range information we know about the LHS, see if we can // Based on the range information we know about the LHS, see if we can
// simplify this comparison. For example, (x&4) < 8 is always true. // simplify this comparison. For example, (x&4) < 8 is always true.
switch (Pred) { switch (Pred) {
@ -5219,21 +5296,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B) if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
const APInt *CmpC;
if (match(Op1, m_APInt(CmpC))) {
// A <u C -> A == C-1 if min(A)+1 == C
if (*CmpC == Op0Min + 1)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
ConstantInt::get(Op1->getType(), *CmpC - 1));
// X <u C --> X == 0, if the number of zero bits in the bottom of X
// exceeds the log2 of C.
if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
Constant::getNullValue(Op1->getType()));
}
break; break;
} }
case ICmpInst::ICMP_UGT: { case ICmpInst::ICMP_UGT: {
@ -5241,21 +5303,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= min(B) if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= min(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
const APInt *CmpC;
if (match(Op1, m_APInt(CmpC))) {
// A >u C -> A == C+1 if max(a)-1 == C
if (*CmpC == Op0Max - 1)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
ConstantInt::get(Op1->getType(), *CmpC + 1));
// X >u C --> X != 0, if the number of zero bits in the bottom of X
// exceeds the log2 of C.
if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
return new ICmpInst(ICmpInst::ICMP_NE, Op0,
Constant::getNullValue(Op1->getType()));
}
break; break;
} }
case ICmpInst::ICMP_SLT: { case ICmpInst::ICMP_SLT: {
@ -5263,14 +5310,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(B) if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
const APInt *CmpC;
if (match(Op1, m_APInt(CmpC))) {
if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
ConstantInt::get(Op1->getType(), *CmpC - 1));
}
break; break;
} }
case ICmpInst::ICMP_SGT: { case ICmpInst::ICMP_SGT: {
@ -5278,14 +5317,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
const APInt *CmpC;
if (match(Op1, m_APInt(CmpC))) {
if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
ConstantInt::get(Op1->getType(), *CmpC + 1));
}
break; break;
} }
case ICmpInst::ICMP_SGE: case ICmpInst::ICMP_SGE:

View File

@ -192,3 +192,64 @@ define <3 x i5> @umax_select_const(<3 x i1> %b, <3 x i5> %x) {
%c = call <3 x i5> @llvm.umax.v3i5(<3 x i5> <i5 5, i5 8, i5 1>, <3 x i5> %s) %c = call <3 x i5> @llvm.umax.v3i5(<3 x i5> <i5 5, i5 8, i5 1>, <3 x i5> %s)
ret <3 x i5> %c ret <3 x i5> %c
} }
declare i32 @llvm.smax.i32(i32, i32);
declare i32 @llvm.smin.i32(i32, i32);
declare i8 @llvm.umax.i8(i8, i8);
declare i8 @llvm.umin.i8(i8, i8);
; Each of the following 4 tests would infinite loop because
; we had conflicting transforms for icmp and select using
; known bits.
; Signed clamp of %x to [0, 1]: smax(%x, 0) via the intrinsic, then the upper
; bound written as icmp+select on %x. The CHECK lines expect the compare to be
; re-expressed on the intrinsic result [[M]].
define i32 @smax_smin(i32 %x) {
; CHECK-LABEL: @smax_smin(
; CHECK-NEXT: [[M:%.*]] = call i32 @llvm.smax.i32(i32 [[X:%.*]], i32 0)
; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[M]], 1
; CHECK-NEXT: [[S:%.*]] = select i1 [[TMP1]], i32 [[M]], i32 1
; CHECK-NEXT: ret i32 [[S]]
;
%m = call i32 @llvm.smax.i32(i32 %x, i32 0)
%c = icmp slt i32 %x, 1
%s = select i1 %c, i32 %m, i32 1
ret i32 %s
}
; Signed clamp of %x to [-2, -1]: smin(%x, -1) via the intrinsic, then the
; lower bound written as icmp+select on %x. The CHECK lines expect the compare
; to be re-expressed on the intrinsic result [[M]].
define i32 @smin_smax(i32 %x) {
; CHECK-LABEL: @smin_smax(
; CHECK-NEXT: [[M:%.*]] = call i32 @llvm.smin.i32(i32 [[X:%.*]], i32 -1)
; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[M]], -2
; CHECK-NEXT: [[S:%.*]] = select i1 [[TMP1]], i32 [[M]], i32 -2
; CHECK-NEXT: ret i32 [[S]]
;
%m = call i32 @llvm.smin.i32(i32 %x, i32 -1)
%c = icmp sgt i32 %x, -2
%s = select i1 %c, i32 %m, i32 -2
ret i32 %s
}
; Unsigned clamp of %x to [128, 129]: umax(%x, 128) via the intrinsic, then the
; upper bound written as icmp+select on %x. The CHECK lines print the i8
; constants 128/129 as -128/-127.
define i8 @umax_umin(i8 %x) {
; CHECK-LABEL: @umax_umin(
; CHECK-NEXT: [[M:%.*]] = call i8 @llvm.umax.i8(i8 [[X:%.*]], i8 -128)
; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i8 [[M]], -127
; CHECK-NEXT: [[S:%.*]] = select i1 [[TMP1]], i8 [[M]], i8 -127
; CHECK-NEXT: ret i8 [[S]]
;
%m = call i8 @llvm.umax.i8(i8 %x, i8 128)
%c = icmp ult i8 %x, 129
%s = select i1 %c, i8 %m, i8 129
ret i8 %s
}
; Unsigned clamp of %x to [126, 127]: umin(%x, 127) via the intrinsic, then the
; lower bound written as icmp+select on %x. The CHECK lines expect the compare
; to be re-expressed on the intrinsic result [[M]].
define i8 @umin_umax(i8 %x) {
; CHECK-LABEL: @umin_umax(
; CHECK-NEXT: [[M:%.*]] = call i8 @llvm.umin.i8(i8 [[X:%.*]], i8 127)
; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i8 [[M]], 126
; CHECK-NEXT: [[S:%.*]] = select i1 [[TMP1]], i8 [[M]], i8 126
; CHECK-NEXT: ret i8 [[S]]
;
%m = call i8 @llvm.umin.i8(i8 %x, i8 127)
%c = icmp ugt i8 %x, 126
%s = select i1 %c, i8 %m, i8 126
ret i8 %s
}