mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-11-21 18:22:53 +01:00
[InstCombine] avoid infinite loops from min/max canonicalization
The intrinsics have an extra chunk of known bits logic compared to the normal cmp+select idiom. That allows folding the icmp in each case to something better, but that then opposes the canonical form of min/max that we try to form for a select. I'm carving out a narrow exception to preserve all existing regression tests while avoiding the inf-loop. It seems unlikely that this is the only bug like this left, but this should fix: https://llvm.org/PR51419 (cherry picked from commit b267d3ce8defa092600bda717ff18440d002f316)
This commit is contained in:
parent
5a14ea148e
commit
4345e14522
@ -5158,6 +5158,83 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
|
||||
if (!isa<Constant>(Op1) && Op1Min == Op1Max)
|
||||
return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min));
|
||||
|
||||
// Don't break up a clamp pattern -- min(max(X, Y), Z) -- by replacing a
|
||||
// min/max canonical compare with some other compare. That could lead to
|
||||
// conflict with select canonicalization and infinite looping.
|
||||
// FIXME: This constraint may go away if min/max intrinsics are canonical.
|
||||
// Returns true when Cmp's only user matches a min/max select pattern and at
// least one of the captured operands Op0/Op1 is itself a min/max.
// NOTE(review): the operand check reads the captured Op0/Op1 rather than
// Cmp's own operands; presumably these are the operands of I, the only
// argument this lambda is called with below -- confirm against the part of
// the function that is not visible here.
auto isMinMaxCmp = [&](Instruction &Cmp) {
|
||||
// Only a compare with exactly one use is considered.
if (!Cmp.hasOneUse())
|
||||
return false;
|
||||
// A and B are passed to matchSelectPattern and not read afterwards; only
// the resulting flavor is used.
Value *A, *B;
|
||||
SelectPatternFlavor SPF = matchSelectPattern(Cmp.user_back(), A, B).Flavor;
|
||||
// The single user must be classified as a min or max.
if (!SelectPatternResult::isMinOrMax(SPF))
|
||||
return false;
|
||||
// Inner half of the clamp: either compare operand is a min/max.
return match(Op0, m_MaxOrMin(m_Value(), m_Value())) ||
|
||||
match(Op1, m_MaxOrMin(m_Value(), m_Value()));
|
||||
};
|
||||
if (!isMinMaxCmp(I)) {
|
||||
switch (Pred) {
|
||||
default:
|
||||
break;
|
||||
case ICmpInst::ICMP_ULT: {
|
||||
if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
|
||||
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
|
||||
const APInt *CmpC;
|
||||
if (match(Op1, m_APInt(CmpC))) {
|
||||
// A <u C -> A == C-1 if min(A)+1 == C
|
||||
if (*CmpC == Op0Min + 1)
|
||||
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
|
||||
ConstantInt::get(Op1->getType(), *CmpC - 1));
|
||||
// X <u C --> X == 0, if the number of zero bits in the bottom of X
|
||||
// exceeds the log2 of C.
|
||||
if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
|
||||
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
|
||||
Constant::getNullValue(Op1->getType()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case ICmpInst::ICMP_UGT: {
|
||||
if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
|
||||
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
|
||||
const APInt *CmpC;
|
||||
if (match(Op1, m_APInt(CmpC))) {
|
||||
// A >u C -> A == C+1 if max(A)-1 == C
|
||||
if (*CmpC == Op0Max - 1)
|
||||
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
|
||||
ConstantInt::get(Op1->getType(), *CmpC + 1));
|
||||
// X >u C --> X != 0, if the number of zero bits in the bottom of X
|
||||
// exceeds the log2 of C.
|
||||
if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
|
||||
return new ICmpInst(ICmpInst::ICMP_NE, Op0,
|
||||
Constant::getNullValue(Op1->getType()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case ICmpInst::ICMP_SLT: {
|
||||
if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
|
||||
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
|
||||
const APInt *CmpC;
|
||||
if (match(Op1, m_APInt(CmpC))) {
|
||||
if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
|
||||
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
|
||||
ConstantInt::get(Op1->getType(), *CmpC - 1));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case ICmpInst::ICMP_SGT: {
|
||||
if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
|
||||
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
|
||||
const APInt *CmpC;
|
||||
if (match(Op1, m_APInt(CmpC))) {
|
||||
if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
|
||||
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
|
||||
ConstantInt::get(Op1->getType(), *CmpC + 1));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Based on the range information we know about the LHS, see if we can
|
||||
// simplify this comparison. For example, (x&4) < 8 is always true.
|
||||
switch (Pred) {
|
||||
@ -5219,21 +5296,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
|
||||
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
|
||||
if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
|
||||
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
|
||||
if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
|
||||
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
|
||||
|
||||
const APInt *CmpC;
|
||||
if (match(Op1, m_APInt(CmpC))) {
|
||||
// A <u C -> A == C-1 if min(A)+1 == C
|
||||
if (*CmpC == Op0Min + 1)
|
||||
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
|
||||
ConstantInt::get(Op1->getType(), *CmpC - 1));
|
||||
// X <u C --> X == 0, if the number of zero bits in the bottom of X
|
||||
// exceeds the log2 of C.
|
||||
if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
|
||||
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
|
||||
Constant::getNullValue(Op1->getType()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case ICmpInst::ICMP_UGT: {
|
||||
@ -5241,21 +5303,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
|
||||
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
|
||||
if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= min(B)
|
||||
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
|
||||
if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
|
||||
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
|
||||
|
||||
const APInt *CmpC;
|
||||
if (match(Op1, m_APInt(CmpC))) {
|
||||
// A >u C -> A == C+1 if max(A)-1 == C
|
||||
if (*CmpC == Op0Max - 1)
|
||||
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
|
||||
ConstantInt::get(Op1->getType(), *CmpC + 1));
|
||||
// X >u C --> X != 0, if the number of zero bits in the bottom of X
|
||||
// exceeds the log2 of C.
|
||||
if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
|
||||
return new ICmpInst(ICmpInst::ICMP_NE, Op0,
|
||||
Constant::getNullValue(Op1->getType()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case ICmpInst::ICMP_SLT: {
|
||||
@ -5263,14 +5310,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
|
||||
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
|
||||
if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(B)
|
||||
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
|
||||
if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
|
||||
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
|
||||
const APInt *CmpC;
|
||||
if (match(Op1, m_APInt(CmpC))) {
|
||||
if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
|
||||
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
|
||||
ConstantInt::get(Op1->getType(), *CmpC - 1));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case ICmpInst::ICMP_SGT: {
|
||||
@ -5278,14 +5317,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
|
||||
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
|
||||
if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
|
||||
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
|
||||
if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
|
||||
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
|
||||
const APInt *CmpC;
|
||||
if (match(Op1, m_APInt(CmpC))) {
|
||||
if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
|
||||
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
|
||||
ConstantInt::get(Op1->getType(), *CmpC + 1));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case ICmpInst::ICMP_SGE:
|
||||
|
@ -192,3 +192,64 @@ define <3 x i5> @umax_select_const(<3 x i1> %b, <3 x i5> %x) {
|
||||
%c = call <3 x i5> @llvm.umax.v3i5(<3 x i5> <i5 5, i5 8, i5 1>, <3 x i5> %s)
|
||||
ret <3 x i5> %c
|
||||
}
|
||||
|
||||
declare i32 @llvm.smax.i32(i32, i32);
|
||||
declare i32 @llvm.smin.i32(i32, i32);
|
||||
declare i8 @llvm.umax.i8(i8, i8);
|
||||
declare i8 @llvm.umin.i8(i8, i8);
|
||||
|
||||
; Each of the following 4 tests would infinite loop because
|
||||
; we had conflicting transforms for icmp and select using
|
||||
; known bits.
|
||||
|
||||
; Signed clamp of %x to [0, 1]: smax(%x, 0), then a compare+select against 1.
; The input compares %x directly; the expected output compares the smax
; result instead and keeps the select.
define i32 @smax_smin(i32 %x) {
|
||||
; CHECK-LABEL: @smax_smin(
|
||||
; CHECK-NEXT: [[M:%.*]] = call i32 @llvm.smax.i32(i32 [[X:%.*]], i32 0)
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[M]], 1
|
||||
; CHECK-NEXT: [[S:%.*]] = select i1 [[TMP1]], i32 [[M]], i32 1
|
||||
; CHECK-NEXT: ret i32 [[S]]
|
||||
;
|
||||
%m = call i32 @llvm.smax.i32(i32 %x, i32 0)
|
||||
%c = icmp slt i32 %x, 1
|
||||
%s = select i1 %c, i32 %m, i32 1
|
||||
ret i32 %s
|
||||
}
|
||||
|
||||
; Signed clamp of %x to [-2, -1]: smin(%x, -1), then a compare+select against
; -2. The input compares %x directly; the expected output compares the smin
; result instead and keeps the select.
define i32 @smin_smax(i32 %x) {
|
||||
; CHECK-LABEL: @smin_smax(
|
||||
; CHECK-NEXT: [[M:%.*]] = call i32 @llvm.smin.i32(i32 [[X:%.*]], i32 -1)
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[M]], -2
|
||||
; CHECK-NEXT: [[S:%.*]] = select i1 [[TMP1]], i32 [[M]], i32 -2
|
||||
; CHECK-NEXT: ret i32 [[S]]
|
||||
;
|
||||
%m = call i32 @llvm.smin.i32(i32 %x, i32 -1)
|
||||
%c = icmp sgt i32 %x, -2
|
||||
%s = select i1 %c, i32 %m, i32 -2
|
||||
ret i32 %s
|
||||
}
|
||||
|
||||
; Unsigned clamp of %x to [128, 129]: umax(%x, 128), then a compare+select
; against 129. The i8 constants 128 and 129 of the input appear as -128 and
; -127 in the expected output.
define i8 @umax_umin(i8 %x) {
|
||||
; CHECK-LABEL: @umax_umin(
|
||||
; CHECK-NEXT: [[M:%.*]] = call i8 @llvm.umax.i8(i8 [[X:%.*]], i8 -128)
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i8 [[M]], -127
|
||||
; CHECK-NEXT: [[S:%.*]] = select i1 [[TMP1]], i8 [[M]], i8 -127
|
||||
; CHECK-NEXT: ret i8 [[S]]
|
||||
;
|
||||
%m = call i8 @llvm.umax.i8(i8 %x, i8 128)
|
||||
%c = icmp ult i8 %x, 129
|
||||
%s = select i1 %c, i8 %m, i8 129
|
||||
ret i8 %s
|
||||
}
|
||||
|
||||
; Unsigned clamp of %x to [126, 127]: umin(%x, 127), then a compare+select
; against 126. The input compares %x directly; the expected output compares
; the umin result instead and keeps the select.
define i8 @umin_umax(i8 %x) {
|
||||
; CHECK-LABEL: @umin_umax(
|
||||
; CHECK-NEXT: [[M:%.*]] = call i8 @llvm.umin.i8(i8 [[X:%.*]], i8 127)
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i8 [[M]], 126
|
||||
; CHECK-NEXT: [[S:%.*]] = select i1 [[TMP1]], i8 [[M]], i8 126
|
||||
; CHECK-NEXT: ret i8 [[S]]
|
||||
;
|
||||
%m = call i8 @llvm.umin.i8(i8 %x, i8 127)
|
||||
%c = icmp ugt i8 %x, 126
|
||||
%s = select i1 %c, i8 %m, i8 126
|
||||
ret i8 %s
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user