diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index 62ee7d00780..eef56c8645f 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -711,6 +711,7 @@ public: Value *A, Value *B, Instruction &Outer, SelectPatternFlavor SPF2, Value *C); Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI); + Instruction *foldSelectValueEquivalence(SelectInst &SI, ICmpInst &ICI); Instruction *OptAndOp(BinaryOperator *Op, ConstantInt *OpRHS, ConstantInt *AndRHS, BinaryOperator &TheAnd); diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index ce473410f4c..087586ede80 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1165,9 +1165,8 @@ static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp, /// /// We can't replace %sel with %add unless we strip away the flags. /// TODO: Wrapping flags could be preserved in some cases with better analysis. -static Instruction *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp, - const SimplifyQuery &Q, - InstCombiner &IC) { +Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, + ICmpInst &Cmp) { if (!Cmp.isEquality()) return nullptr; @@ -1179,18 +1178,20 @@ static Instruction *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp, Swapped = true; } - // In X == Y ? f(X) : Z, try to evaluate f(X) and replace the operand. - // Take care to avoid replacing X == Y ? X : Z with X == Y ? Y : Z, as that - // would lead to an infinite replacement cycle. + // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand. + // Make sure Y cannot be undef though, as we might pick different values for + // undef in the icmp and in f(Y). Additionally, take care to avoid replacing + // X == Y ? X : Z with X == Y ? Y : Z, as that would lead to an infinite + // replacement cycle. Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1); - if (TrueVal != CmpLHS) - if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, + if (TrueVal != CmpLHS && isGuaranteedNotToBeUndefOrPoison(CmpRHS, &Sel, &DT)) + if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ, /* AllowRefinement */ true)) - return IC.replaceOperand(Sel, Swapped ? 2 : 1, V); - if (TrueVal != CmpRHS) - if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, + return replaceOperand(Sel, Swapped ? 2 : 1, V); + if (TrueVal != CmpRHS && isGuaranteedNotToBeUndefOrPoison(CmpLHS, &Sel, &DT)) + if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ, /* AllowRefinement */ true)) - return IC.replaceOperand(Sel, Swapped ? 2 : 1, V); + return replaceOperand(Sel, Swapped ? 2 : 1, V); auto *FalseInst = dyn_cast(FalseVal); if (!FalseInst) @@ -1215,11 +1216,11 @@ static Instruction *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp, // We have an 'EQ' comparison, so the select's false value will propagate. // Example: // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1 - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ, /* AllowRefinement */ false) == TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ, /* AllowRefinement */ false) == TrueVal) { - return IC.replaceInstUsesWith(Sel, FalseVal); + return replaceInstUsesWith(Sel, FalseVal); } // Restore poison-generating flags if the transform did not apply. @@ -1455,7 +1456,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, /// Visit a SelectInst that has an ICmpInst as its first operand. Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI) { - if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI, SQ, *this)) + if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI)) return NewSel; if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, *this)) diff --git a/test/Transforms/InstCombine/select-binop-cmp.ll b/test/Transforms/InstCombine/select-binop-cmp.ll index aa450f8af8b..c4a9d0941b9 100644 --- a/test/Transforms/InstCombine/select-binop-cmp.ll +++ b/test/Transforms/InstCombine/select-binop-cmp.ll @@ -564,10 +564,12 @@ define <2 x i8> @select_xor_icmp_vec_bad(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z) ret <2 x i8> %C } +; Folding this would only be legal if we sanitized undef to 0. define <2 x i8> @select_xor_icmp_vec_undef(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z) { ; CHECK-LABEL: @select_xor_icmp_vec_undef( ; CHECK-NEXT: [[A:%.*]] = icmp eq <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[C:%.*]] = select <2 x i1> [[A]], <2 x i8> [[Z:%.*]], <2 x i8> [[Y:%.*]] +; CHECK-NEXT: [[B:%.*]] = xor <2 x i8> [[X]], [[Z:%.*]] +; CHECK-NEXT: [[C:%.*]] = select <2 x i1> [[A]], <2 x i8> [[B]], <2 x i8> [[Y:%.*]] ; CHECK-NEXT: ret <2 x i8> [[C]] ; %A = icmp eq <2 x i8> %x, diff --git a/test/Transforms/InstCombine/select.ll b/test/Transforms/InstCombine/select.ll index b7c4cb5c642..df506477eed 100644 --- a/test/Transforms/InstCombine/select.ll +++ b/test/Transforms/InstCombine/select.ll @@ -2641,8 +2641,8 @@ define i8 @select_replacement_add_nuw(i8 %x, i8 %y) { ret i8 %sel } -define i8 @select_replacement_sub(i8 %x, i8 %y, i8 %z) { -; CHECK-LABEL: @select_replacement_sub( +define i8 @select_replacement_sub_noundef(i8 %x, i8 noundef %y, i8 %z) { +; CHECK-LABEL: @select_replacement_sub_noundef( ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 0, i8 [[Z:%.*]] ; CHECK-NEXT: ret i8 [[SEL]] @@ -2653,11 +2653,43 @@ define i8 @select_replacement_sub(i8 %x, i8 %y, i8 %z) { ret i8 %sel } +; TODO: The transform is also safe without noundef. +define i8 @select_replacement_sub(i8 %x, i8 %y, i8 %z) { +; CHECK-LABEL: @select_replacement_sub( +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: [[SUB:%.*]] = sub i8 [[X]], [[Y]] +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[SUB]], i8 [[Z:%.*]] +; CHECK-NEXT: ret i8 [[SEL]] +; + %cmp = icmp eq i8 %x, %y + %sub = sub i8 %x, %y + %sel = select i1 %cmp, i8 %sub, i8 %z + ret i8 %sel +} + +define i8 @select_replacement_shift_noundef(i8 %x, i8 %y, i8 %z) { +; CHECK-LABEL: @select_replacement_shift_noundef( +; CHECK-NEXT: [[SHR:%.*]] = lshr exact i8 [[X:%.*]], 1 +; CHECK-NEXT: call void @use_i8(i8 noundef [[SHR]]) +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[SHR]], [[Y:%.*]] +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[X]], i8 [[Z:%.*]] +; CHECK-NEXT: ret i8 [[SEL]] +; + %shr = lshr exact i8 %x, 1 + call void @use_i8(i8 noundef %shr) + %cmp = icmp eq i8 %shr, %y + %shl = shl i8 %y, 1 + %sel = select i1 %cmp, i8 %shl, i8 %z + ret i8 %sel +} + +; TODO: The transform is also safe without noundef. define i8 @select_replacement_shift(i8 %x, i8 %y, i8 %z) { ; CHECK-LABEL: @select_replacement_shift( ; CHECK-NEXT: [[SHR:%.*]] = lshr exact i8 [[X:%.*]], 1 ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[SHR]], [[Y:%.*]] -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[X]], i8 [[Z:%.*]] +; CHECK-NEXT: [[SHL:%.*]] = shl i8 [[Y]], 1 +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[SHL]], i8 [[Z:%.*]] ; CHECK-NEXT: ret i8 [[SEL]] ; %shr = lshr exact i8 %x, 1 @@ -2694,4 +2726,5 @@ define i32 @select_replacement_loop2(i32 %arg, i32 %arg2) { } declare void @use(i1) +declare void @use_i8(i8) declare i32 @llvm.cttz.i32(i32, i1 immarg)