From 4345e145226757d70f0ed16dd650018ad3208af8 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Tue, 10 Aug 2021 14:30:43 -0400 Subject: [PATCH] [InstCombine] avoid infinite loops from min/max canonicalization The intrinsics have an extra chunk of known bits logic compared to the normal cmp+select idiom. That allows folding the icmp in each case to something better, but that then opposes the canonical form of min/max that we try to form for a select. I'm carving out a narrow exception to preserve all existing regression tests while avoiding the inf-loop. It seems unlikely that this is the only bug like this left, but this should fix: https://llvm.org/PR51419 (cherry picked from commit b267d3ce8defa092600bda717ff18440d002f316) --- .../InstCombine/InstCombineCompares.cpp | 123 +++++++++++------- test/Transforms/InstCombine/select-min-max.ll | 61 +++++++++ 2 files changed, 138 insertions(+), 46 deletions(-) diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 2b0ef0c5f2c..c5e14ebf3ae 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -5158,6 +5158,83 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { if (!isa<Constant>(Op1) && Op1Min == Op1Max) return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min)); + // Don't break up a clamp pattern -- (min(max X, Y), Z) -- by replacing a + // min/max canonical compare with some other compare. That could lead to + // conflict with select canonicalization and infinite looping. + // FIXME: This constraint may go away if min/max intrinsics are canonical. 
+ auto isMinMaxCmp = [&](Instruction &Cmp) { + if (!Cmp.hasOneUse()) + return false; + Value *A, *B; + SelectPatternFlavor SPF = matchSelectPattern(Cmp.user_back(), A, B).Flavor; + if (!SelectPatternResult::isMinOrMax(SPF)) + return false; + return match(Op0, m_MaxOrMin(m_Value(), m_Value())) || + match(Op1, m_MaxOrMin(m_Value(), m_Value())); + }; + if (!isMinMaxCmp(I)) { + switch (Pred) { + default: + break; + case ICmpInst::ICMP_ULT: { + if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + // A <u C -> A == C-1 if min(A)+1 == C + if (*CmpC == Op0Min + 1) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC - 1)); + // X <u C --> X == 0, if the number of zero bits in the bottom of X + // exceeds the log2 of C. + if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2()) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + Constant::getNullValue(Op1->getType())); + } + break; + } + case ICmpInst::ICMP_UGT: { + if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + // A >u C -> A == C+1 if max(a)-1 == C + if (*CmpC == Op0Max - 1) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC + 1)); + // X >u C --> X != 0, if the number of zero bits in the bottom of X + // exceeds the log2 of C. 
+ if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits()) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, + Constant::getNullValue(Op1->getType())); + } + break; + } + case ICmpInst::ICMP_SLT: { + if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC - 1)); + } + break; + } + case ICmpInst::ICMP_SGT: { + if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC + 1)); + } + break; + } + } + } + // Based on the range information we know about the LHS, see if we can // simplify this comparison. For example, (x&4) < 8 is always true. switch (Pred) { @@ -5219,21 +5296,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - - const APInt *CmpC; - if (match(Op1, m_APInt(CmpC))) { - // A <u C -> A == C-1 if min(A)+1 == C - if (*CmpC == Op0Min + 1) - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - ConstantInt::get(Op1->getType(), *CmpC - 1)); - // X <u C --> X == 0, if the number of zero bits in the bottom of X - // exceeds the log2 of C. 
- if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2()) - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Constant::getNullValue(Op1->getType())); - } break; } case ICmpInst::ICMP_UGT: { @@ -5241,21 +5303,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - - const APInt *CmpC; - if (match(Op1, m_APInt(CmpC))) { - // A >u C -> A == C+1 if max(a)-1 == C - if (*CmpC == Op0Max - 1) - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - ConstantInt::get(Op1->getType(), *CmpC + 1)); - // X >u C --> X != 0, if the number of zero bits in the bottom of X - // exceeds the log2 of C. - if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits()) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, - Constant::getNullValue(Op1->getType())); - } break; } case ICmpInst::ICMP_SLT: { @@ -5263,14 +5310,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - const APInt *CmpC; - if (match(Op1, m_APInt(CmpC))) { - if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - ConstantInt::get(Op1->getType(), *CmpC - 1)); - } break; } case ICmpInst::ICMP_SGT: { @@ -5278,14 +5317,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) return replaceInstUsesWith(I, 
ConstantInt::getFalse(I.getType())); - if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - const APInt *CmpC; - if (match(Op1, m_APInt(CmpC))) { - if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - ConstantInt::get(Op1->getType(), *CmpC + 1)); - } break; } case ICmpInst::ICMP_SGE: diff --git a/test/Transforms/InstCombine/select-min-max.ll b/test/Transforms/InstCombine/select-min-max.ll index ae6ee317059..3cd697853c3 100644 --- a/test/Transforms/InstCombine/select-min-max.ll +++ b/test/Transforms/InstCombine/select-min-max.ll @@ -192,3 +192,64 @@ define <3 x i5> @umax_select_const(<3 x i1> %b, <3 x i5> %x) { %c = call <3 x i5> @llvm.umax.v3i5(<3 x i5> , <3 x i5> %s) ret <3 x i5> %c } + +declare i32 @llvm.smax.i32(i32, i32); +declare i32 @llvm.smin.i32(i32, i32); +declare i8 @llvm.umax.i8(i8, i8); +declare i8 @llvm.umin.i8(i8, i8); + +; Each of the following 4 tests would infinite loop because +; we had conflicting transforms for icmp and select using +; known bits. 
+ +define i32 @smax_smin(i32 %x) { +; CHECK-LABEL: @smax_smin( +; CHECK-NEXT: [[M:%.*]] = call i32 @llvm.smax.i32(i32 [[X:%.*]], i32 0) +; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[M]], 1 +; CHECK-NEXT: [[S:%.*]] = select i1 [[TMP1]], i32 [[M]], i32 1 +; CHECK-NEXT: ret i32 [[S]] +; + %m = call i32 @llvm.smax.i32(i32 %x, i32 0) + %c = icmp slt i32 %x, 1 + %s = select i1 %c, i32 %m, i32 1 + ret i32 %s +} + +define i32 @smin_smax(i32 %x) { +; CHECK-LABEL: @smin_smax( +; CHECK-NEXT: [[M:%.*]] = call i32 @llvm.smin.i32(i32 [[X:%.*]], i32 -1) +; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[M]], -2 +; CHECK-NEXT: [[S:%.*]] = select i1 [[TMP1]], i32 [[M]], i32 -2 +; CHECK-NEXT: ret i32 [[S]] +; + %m = call i32 @llvm.smin.i32(i32 %x, i32 -1) + %c = icmp sgt i32 %x, -2 + %s = select i1 %c, i32 %m, i32 -2 + ret i32 %s +} + +define i8 @umax_umin(i8 %x) { +; CHECK-LABEL: @umax_umin( +; CHECK-NEXT: [[M:%.*]] = call i8 @llvm.umax.i8(i8 [[X:%.*]], i8 -128) +; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i8 [[M]], -127 +; CHECK-NEXT: [[S:%.*]] = select i1 [[TMP1]], i8 [[M]], i8 -127 +; CHECK-NEXT: ret i8 [[S]] +; + %m = call i8 @llvm.umax.i8(i8 %x, i8 128) + %c = icmp ult i8 %x, 129 + %s = select i1 %c, i8 %m, i8 129 + ret i8 %s +} + +define i8 @umin_umax(i8 %x) { +; CHECK-LABEL: @umin_umax( +; CHECK-NEXT: [[M:%.*]] = call i8 @llvm.umin.i8(i8 [[X:%.*]], i8 127) +; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i8 [[M]], 126 +; CHECK-NEXT: [[S:%.*]] = select i1 [[TMP1]], i8 [[M]], i8 126 +; CHECK-NEXT: ret i8 [[S]] +; + %m = call i8 @llvm.umin.i8(i8 %x, i8 127) + %c = icmp ugt i8 %x, 126 + %s = select i1 %c, i8 %m, i8 126 + ret i8 %s +}