From a17954e6f1e566de73e2fdea4a7880bba26f8a62 Mon Sep 17 00:00:00 2001 From: Philip Reames Date: Thu, 29 Apr 2021 13:06:26 -0700 Subject: [PATCH] Revert "Generalize getInvertibleOperand recurrence handling slightly" This reverts commit 0c01b37eeb18a51a7e9c9153330d8009de0f600e while a problem reported is investigated. --- lib/Analysis/ValueTracking.cpp | 57 ++++++++++--------- .../Analysis/ValueTracking/known-non-equal.ll | 12 ++-- 2 files changed, 38 insertions(+), 31 deletions(-) diff --git a/lib/Analysis/ValueTracking.cpp b/lib/Analysis/ValueTracking.cpp index df3531ad28d..3e46e11a652 100644 --- a/lib/Analysis/ValueTracking.cpp +++ b/lib/Analysis/ValueTracking.cpp @@ -2521,31 +2521,26 @@ bool isKnownNonZero(const Value* V, unsigned Depth, const Query& Q) { return isKnownNonZero(V, DemandedElts, Depth, Q); } -/// If the pair of operators are the same invertible function, return the -/// the operands of the function corresponding to each input. Otherwise, -/// return None. An invertible function is one that is 1-to-1 and maps -/// every input value to exactly one output value. This is equivalent to -/// saying that Op1 and Op2 are equal exactly when the specified pair of -/// operands are equal, (except that Op1 and Op2 may be poison more often.) -static Optional> -getInvertibleOperands(const Operator *Op1, - const Operator *Op2) { +/// If the pair of operators are the same invertible function of a single +/// operand return the index of that operand. Otherwise, return None. An +/// invertible function is one that is 1-to-1 and maps every input value +/// to exactly one output value. This is equivalent to saying that Op1 +/// and Op2 are equal exactly when the specified pair of operands are equal, +/// (except that Op1 and Op2 may be poison more often.) +static Optional getInvertibleOperand(const Operator *Op1, + const Operator *Op2) { if (Op1->getOpcode() != Op2->getOpcode()) return None; - auto getOperands = [&](unsigned OpNum) -> auto { - return std::make_pair(Op1->getOperand(OpNum), Op2->getOperand(OpNum)); - }; - switch (Op1->getOpcode()) { default: break; case Instruction::Add: case Instruction::Sub: if (Op1->getOperand(0) == Op2->getOperand(0)) - return getOperands(1); + return 1; if (Op1->getOperand(1) == Op2->getOperand(1)) - return getOperands(0); + return 0; break; case Instruction::Mul: { // invertible if A * B == (A * B) mod 2^N where A, and B are integers @@ -2561,7 +2556,7 @@ getInvertibleOperands(const Operator *Op1, if (Op1->getOperand(1) == Op2->getOperand(1) && isa(Op1->getOperand(1)) && !cast(Op1->getOperand(1))->isZero()) - return getOperands(0); + return 0; break; } case Instruction::Shl: { @@ -2574,7 +2569,7 @@ getInvertibleOperands(const Operator *Op1, break; if (Op1->getOperand(1) == Op2->getOperand(1)) - return getOperands(0); + return 0; break; } case Instruction::AShr: @@ -2585,13 +2580,13 @@ getInvertibleOperands(const Operator *Op1, break; if (Op1->getOperand(1) == Op2->getOperand(1)) - return getOperands(0); + return 0; break; } case Instruction::SExt: case Instruction::ZExt: if (Op1->getOperand(0)->getType() == Op2->getOperand(0)->getType()) - return getOperands(0); + return 0; break; case Instruction::PHI: { const PHINode *PN1 = cast(Op1); @@ -2609,12 +2604,18 @@ getInvertibleOperands(const Operator *Op1, !matchSimpleRecurrence(PN2, BO2, Start2, Step2)) break; - auto Values = getInvertibleOperands(cast(BO1), - cast(BO2)); - if (!Values) + Optional Idx = getInvertibleOperand(cast(BO1), + cast(BO2)); + if (!Idx || *Idx != 0) break; - assert(Values->first == PN1 && Values->second == PN2); - return std::make_pair(Start1, Start2); + assert(BO1->getOperand(*Idx) == PN1 && BO2->getOperand(*Idx) == PN2); + + // Phi operands might not be in the same order. TODO: generalize + // interface to return pair of operands. + if (PN1->getOperand(0) == BO1 && PN2->getOperand(0) == BO2) + return 1; + if (PN1->getOperand(1) == BO1 && PN2->getOperand(1) == BO2) + return 0; } } return None; @@ -2711,9 +2712,11 @@ static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth, auto *O1 = dyn_cast(V1); auto *O2 = dyn_cast(V2); if (O1 && O2 && O1->getOpcode() == O2->getOpcode()) { - if (auto Values = getInvertibleOperands(O1, O2)) - return isKnownNonEqual(Values->first, Values->second, Depth + 1, Q); - + if (Optional Opt = getInvertibleOperand(O1, O2)) { + unsigned Idx = *Opt; + return isKnownNonEqual(O1->getOperand(Idx), O2->getOperand(Idx), + Depth + 1, Q); + } if (const PHINode *PN1 = dyn_cast(V1)) { const PHINode *PN2 = cast(V2); // FIXME: This is missing a generalization to handle the case where one is diff --git a/test/Analysis/ValueTracking/known-non-equal.ll b/test/Analysis/ValueTracking/known-non-equal.ll index f3e3b6dbfaf..a41d228fd40 100644 --- a/test/Analysis/ValueTracking/known-non-equal.ll +++ b/test/Analysis/ValueTracking/known-non-equal.ll @@ -736,7 +736,8 @@ define i1 @recurrence_add_op_order(i8 %A) { ; CHECK-NEXT: [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10 ; CHECK-NEXT: br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]] ; CHECK: exit: -; CHECK-NEXT: ret i1 false +; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]] +; CHECK-NEXT: ret i1 [[RES]] ; entry: %B = add i8 %A, 1 @@ -807,7 +808,8 @@ define i1 @recurrence_add_phi_different_order1(i8 %A) { ; CHECK-NEXT: [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10 ; CHECK-NEXT: br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]] ; CHECK: exit: -; CHECK-NEXT: ret i1 false +; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]] +; CHECK-NEXT: ret i1 [[RES]] ; entry: %B = add i8 %A, 1 @@ -841,7 +843,8 @@ define i1 @recurrence_add_phi_different_order2(i8 %A) { ; CHECK-NEXT: [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10 ; CHECK-NEXT: br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]] ; CHECK: exit: -; CHECK-NEXT: ret i1 false +; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]] +; CHECK-NEXT: ret i1 [[RES]] ; entry: %B = add i8 %A, 1 @@ -976,7 +979,8 @@ define i1 @recurrence_sub_op_order(i8 %A) { ; CHECK-NEXT: [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10 ; CHECK-NEXT: br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]] ; CHECK: exit: -; CHECK-NEXT: ret i1 false +; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]] +; CHECK-NEXT: ret i1 [[RES]] ; entry: %B = add i8 %A, 1