diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index 0a095c9fa23..6c944403ecb 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -45742,11 +45742,11 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, unsigned NumElems = VT.getVectorNumElements(); EVT ScalarVT = VT.getVectorElementType(); - if (!((ScalarVT == MVT::i8 || ScalarVT == MVT::i16) && NumElems >= 2)) + if (ScalarVT == MVT::i1 || NumElems < 2) return SDValue(); // InScalarVT is the intermediate type in AVG pattern and it should be greater - // than the original input type (i8/i16). + // than the original input type. EVT InScalarVT = InVT.getVectorElementType(); if (InScalarVT.getFixedSizeInBits() <= ScalarVT.getFixedSizeInBits()) return SDValue(); @@ -45764,12 +45764,14 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, // %6 = trunc %5 to // // In AVX512, the last instruction can also be a trunc store. - if (In.getOpcode() != ISD::SRL) + + // Shift type (lshr or ashr) doesn't affect the result, allow both. + if (In.getOpcode() != ISD::SRL && In.getOpcode() != ISD::SRA) return SDValue(); // A lambda checking the given SDValue is a constant vector and each element // is in the range [Min, Max]. - auto IsConstVectorInRange = [](SDValue V, unsigned Min, unsigned Max) { + auto IsConstVectorInRange = [](SDValue V, uint64_t Min, uint64_t Max) { return ISD::matchUnaryPredicate(V, [Min, Max](ConstantSDNode *C) { return !(C->getAPIntValue().ult(Min) || C->getAPIntValue().ugt(Max)); }); @@ -45788,9 +45790,80 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, Operands[0] = LHS.getOperand(0); Operands[1] = LHS.getOperand(1); - auto AVGBuilder = [](SelectionDAG &DAG, const SDLoc &DL, - ArrayRef Ops) { - return DAG.getNode(X86ISD::AVG, DL, Ops[0].getValueType(), Ops); + // Signed-ness of Operands[0..1] (true if sign-extended) + bool OpSign[2]{false, false}; + + auto AVGBuilder = [&](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef Ops) { + EVT VT = Ops[0].getValueType(); + EVT ScalarVT = VT.getVectorElementType(); + // Legal X86ISD::AVG types + if (ScalarVT == MVT::i8 || ScalarVT == MVT::i16) { + // Emulate signed or mixed AVG op via unsigned AVG + if (OpSign[0] || OpSign[1]) { + unsigned SignM = ScalarVT == MVT::i8 ? 0x80 : 0x8000; + SDValue SignVal = DAG.getConstant(SignM, DL, VT); + SDValue Op0 = Ops[0]; + SDValue Op1 = Ops[1]; + if (OpSign[0]) + Op0 = DAG.getNode(ISD::XOR, DL, VT, Ops[0], SignVal); + if (OpSign[1]) + Op1 = DAG.getNode(ISD::XOR, DL, VT, Ops[1], SignVal); + SDValue R = DAG.getNode(X86ISD::AVG, DL, VT, Op0, Op1); + + // Signed AVG + if (OpSign[0] && OpSign[1]) + return DAG.getNode(ISD::XOR, DL, VT, R, SignVal); + + // Mixed signed-unsigned AVG + SignVal = DAG.getConstant(SignM >> 1, DL, VT); + return DAG.getNode(ISD::SUB, DL, VT, R, SignVal); + } + + // Standard unsigned AVG + return DAG.getNode(X86ISD::AVG, DL, VT, Ops); + } + + // Emulate AVG with more effective in-lane in-type algorithm + SDValue Operands[2]; + SDValue One = DAG.getConstant(1, DL, VT); + for (int i = 0; i < 2; i++) { + unsigned Op = OpSign[i] ? ISD::SRA : ISD::SRL; + Operands[i] = DAG.getNode(Op, DL, VT, Ops[i], One); + } + + // Compute carry of the 1-bit addition of least significant bits + SDValue Carry; + BuildVectorSDNode *BV = dyn_cast(Ops[1]); + if (BV && BV->isConstant()) { + bool AllOdds = true; + bool AllEven = true; + for (SDValue Op : BV->ops()) { + if (ConstantSDNode *C = dyn_cast(Op)) { + if (C->getAPIntValue()[0]) + AllEven = false; + else + AllOdds = false; + } else { + BV = nullptr; + break; + } + } + if (BV && AllOdds) + Carry = One; + if (BV && AllEven) + Carry = DAG.getNode(ISD::AND, DL, VT, Ops[0], One); + } + if (!Carry && Subtarget.hasAVX512()) + if (ScalarVT == MVT::i32 || ScalarVT == MVT::i64) + Carry = DAG.getNode(X86ISD::VPTERNLOG, DL, VT, Ops[0], Ops[1], One, + DAG.getTargetConstant(0xA8, DL, MVT::i8)); + if (!Carry) { + Carry = DAG.getNode(ISD::OR, DL, VT, Ops); + Carry = DAG.getNode(ISD::AND, DL, VT, Carry, One); + } + Carry = DAG.getNode(ISD::ADD, DL, VT, Carry, Operands[1]); + return DAG.getNode(ISD::ADD, DL, VT, Operands[0], Carry); }; auto AVGSplitter = [&](SDValue Op0, SDValue Op1) { @@ -45817,12 +45890,15 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, }; // Take care of the case when one of the operands is a constant vector whose - // element is in the range [1, 256]. - if (IsConstVectorInRange(Operands[1], 1, ScalarVT == MVT::i8 ? 256 : 65536) && - Operands[0].getOpcode() == ISD::ZERO_EXTEND && + // element is in the range [1, 256] (for i8). TODO: support negative values. + unsigned SBW = ScalarVT.getSizeInBits(); + unsigned OpX = Operands[0].getOpcode(); + if (IsConstVectorInRange(Operands[1], 1, SBW < 64 ? 1ull << SBW : -1) && + (OpX == ISD::ZERO_EXTEND || OpX == ISD::SIGN_EXTEND) && Operands[0].getOperand(0).getValueType() == VT) { // The pattern is detected. Subtract one from the constant vector, then // demote it and emit X86ISD::AVG instruction. + OpSign[0] = (OpX == ISD::SIGN_EXTEND); SDValue VecOnes = DAG.getConstant(1, DL, InVT); Operands[1] = DAG.getNode(ISD::SUB, DL, InVT, Operands[1], VecOnes); Operands[1] = DAG.getNode(ISD::TRUNCATE, DL, VT, Operands[1]); @@ -45866,10 +45942,12 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, // Check if Operands[0] and Operands[1] are results of type promotion. for (int j = 0; j < 2; ++j) if (Operands[j].getValueType() != VT) { - if (Operands[j].getOpcode() != ISD::ZERO_EXTEND || + unsigned Op = Operands[j].getOpcode(); + if ((Op != ISD::ZERO_EXTEND && Op != ISD::SIGN_EXTEND) || Operands[j].getOperand(0).getValueType() != VT) return SDValue(); Operands[j] = Operands[j].getOperand(0); + OpSign[j] = Op == ISD::SIGN_EXTEND; } // The pattern is detected, emit X86ISD::AVG instruction(s).