1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-22 02:33:06 +01:00

X86: expand detectAVGPattern()

Allow all integer widths in the pattern, allow ashr
Handle signed and mixed cases, allowing to replace truncation
This commit is contained in:
Nekotekina 2018-06-27 15:00:34 +03:00
parent 5ff8f4151c
commit 2ffa82223f

View File

@ -45742,11 +45742,11 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
unsigned NumElems = VT.getVectorNumElements(); unsigned NumElems = VT.getVectorNumElements();
EVT ScalarVT = VT.getVectorElementType(); EVT ScalarVT = VT.getVectorElementType();
if (!((ScalarVT == MVT::i8 || ScalarVT == MVT::i16) && NumElems >= 2)) if (ScalarVT == MVT::i1 || NumElems < 2)
return SDValue(); return SDValue();
// InScalarVT is the intermediate type in AVG pattern and it should be greater // InScalarVT is the intermediate type in AVG pattern and it should be greater
// than the original input type (i8/i16). // than the original input type.
EVT InScalarVT = InVT.getVectorElementType(); EVT InScalarVT = InVT.getVectorElementType();
if (InScalarVT.getFixedSizeInBits() <= ScalarVT.getFixedSizeInBits()) if (InScalarVT.getFixedSizeInBits() <= ScalarVT.getFixedSizeInBits())
return SDValue(); return SDValue();
@ -45764,12 +45764,14 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
// %6 = trunc <N x i32> %5 to <N x i8> // %6 = trunc <N x i32> %5 to <N x i8>
// //
// In AVX512, the last instruction can also be a trunc store. // In AVX512, the last instruction can also be a trunc store.
if (In.getOpcode() != ISD::SRL)
// Shift type (lshr or ashr) doesn't affect the result, allow both.
if (In.getOpcode() != ISD::SRL && In.getOpcode() != ISD::SRA)
return SDValue(); return SDValue();
// A lambda checking the given SDValue is a constant vector and each element // A lambda checking the given SDValue is a constant vector and each element
// is in the range [Min, Max]. // is in the range [Min, Max].
auto IsConstVectorInRange = [](SDValue V, unsigned Min, unsigned Max) { auto IsConstVectorInRange = [](SDValue V, uint64_t Min, uint64_t Max) {
return ISD::matchUnaryPredicate(V, [Min, Max](ConstantSDNode *C) { return ISD::matchUnaryPredicate(V, [Min, Max](ConstantSDNode *C) {
return !(C->getAPIntValue().ult(Min) || C->getAPIntValue().ugt(Max)); return !(C->getAPIntValue().ult(Min) || C->getAPIntValue().ugt(Max));
}); });
@ -45788,9 +45790,80 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
Operands[0] = LHS.getOperand(0); Operands[0] = LHS.getOperand(0);
Operands[1] = LHS.getOperand(1); Operands[1] = LHS.getOperand(1);
auto AVGBuilder = [](SelectionDAG &DAG, const SDLoc &DL, // Signed-ness of Operands[0..1] (true if sign-extended)
ArrayRef<SDValue> Ops) { bool OpSign[2]{false, false};
return DAG.getNode(X86ISD::AVG, DL, Ops[0].getValueType(), Ops);
auto AVGBuilder = [&](SelectionDAG &DAG, const SDLoc &DL,
ArrayRef<SDValue> Ops) {
EVT VT = Ops[0].getValueType();
EVT ScalarVT = VT.getVectorElementType();
// Legal X86ISD::AVG types
if (ScalarVT == MVT::i8 || ScalarVT == MVT::i16) {
// Emulate signed or mixed AVG op via unsigned AVG
if (OpSign[0] || OpSign[1]) {
unsigned SignM = ScalarVT == MVT::i8 ? 0x80 : 0x8000;
SDValue SignVal = DAG.getConstant(SignM, DL, VT);
SDValue Op0 = Ops[0];
SDValue Op1 = Ops[1];
if (OpSign[0])
Op0 = DAG.getNode(ISD::XOR, DL, VT, Ops[0], SignVal);
if (OpSign[1])
Op1 = DAG.getNode(ISD::XOR, DL, VT, Ops[1], SignVal);
SDValue R = DAG.getNode(X86ISD::AVG, DL, VT, Op0, Op1);
// Signed AVG
if (OpSign[0] && OpSign[1])
return DAG.getNode(ISD::XOR, DL, VT, R, SignVal);
// Mixed signed-unsigned AVG
SignVal = DAG.getConstant(SignM >> 1, DL, VT);
return DAG.getNode(ISD::SUB, DL, VT, R, SignVal);
}
// Standard unsigned AVG
return DAG.getNode(X86ISD::AVG, DL, VT, Ops);
}
// Emulate AVG with more effective in-lane in-type algorithm
SDValue Operands[2];
SDValue One = DAG.getConstant(1, DL, VT);
for (int i = 0; i < 2; i++) {
unsigned Op = OpSign[i] ? ISD::SRA : ISD::SRL;
Operands[i] = DAG.getNode(Op, DL, VT, Ops[i], One);
}
// Compute carry of the 1-bit addition of least significant bits
SDValue Carry;
BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Ops[1]);
if (BV && BV->isConstant()) {
bool AllOdds = true;
bool AllEven = true;
for (SDValue Op : BV->ops()) {
if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op)) {
if (C->getAPIntValue()[0])
AllEven = false;
else
AllOdds = false;
} else {
BV = nullptr;
break;
}
}
if (BV && AllOdds)
Carry = One;
if (BV && AllEven)
Carry = DAG.getNode(ISD::AND, DL, VT, Ops[0], One);
}
if (!Carry && Subtarget.hasAVX512())
if (ScalarVT == MVT::i32 || ScalarVT == MVT::i64)
Carry = DAG.getNode(X86ISD::VPTERNLOG, DL, VT, Ops[0], Ops[1], One,
DAG.getTargetConstant(0xA8, DL, MVT::i8));
if (!Carry) {
Carry = DAG.getNode(ISD::OR, DL, VT, Ops);
Carry = DAG.getNode(ISD::AND, DL, VT, Carry, One);
}
Carry = DAG.getNode(ISD::ADD, DL, VT, Carry, Operands[1]);
return DAG.getNode(ISD::ADD, DL, VT, Operands[0], Carry);
}; };
auto AVGSplitter = [&](SDValue Op0, SDValue Op1) { auto AVGSplitter = [&](SDValue Op0, SDValue Op1) {
@ -45817,12 +45890,15 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
}; };
// Take care of the case when one of the operands is a constant vector whose // Take care of the case when one of the operands is a constant vector whose
// element is in the range [1, 256]. // element is in the range [1, 256] (for i8). TODO: support negative values.
if (IsConstVectorInRange(Operands[1], 1, ScalarVT == MVT::i8 ? 256 : 65536) && unsigned SBW = ScalarVT.getSizeInBits();
Operands[0].getOpcode() == ISD::ZERO_EXTEND && unsigned OpX = Operands[0].getOpcode();
if (IsConstVectorInRange(Operands[1], 1, SBW < 64 ? 1ull << SBW : -1) &&
(OpX == ISD::ZERO_EXTEND || OpX == ISD::SIGN_EXTEND) &&
Operands[0].getOperand(0).getValueType() == VT) { Operands[0].getOperand(0).getValueType() == VT) {
// The pattern is detected. Subtract one from the constant vector, then // The pattern is detected. Subtract one from the constant vector, then
// demote it and emit X86ISD::AVG instruction. // demote it and emit X86ISD::AVG instruction.
OpSign[0] = (OpX == ISD::SIGN_EXTEND);
SDValue VecOnes = DAG.getConstant(1, DL, InVT); SDValue VecOnes = DAG.getConstant(1, DL, InVT);
Operands[1] = DAG.getNode(ISD::SUB, DL, InVT, Operands[1], VecOnes); Operands[1] = DAG.getNode(ISD::SUB, DL, InVT, Operands[1], VecOnes);
Operands[1] = DAG.getNode(ISD::TRUNCATE, DL, VT, Operands[1]); Operands[1] = DAG.getNode(ISD::TRUNCATE, DL, VT, Operands[1]);
@ -45866,10 +45942,12 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
// Check if Operands[0] and Operands[1] are results of type promotion. // Check if Operands[0] and Operands[1] are results of type promotion.
for (int j = 0; j < 2; ++j) for (int j = 0; j < 2; ++j)
if (Operands[j].getValueType() != VT) { if (Operands[j].getValueType() != VT) {
if (Operands[j].getOpcode() != ISD::ZERO_EXTEND || unsigned Op = Operands[j].getOpcode();
if ((Op != ISD::ZERO_EXTEND && Op != ISD::SIGN_EXTEND) ||
Operands[j].getOperand(0).getValueType() != VT) Operands[j].getOperand(0).getValueType() != VT)
return SDValue(); return SDValue();
Operands[j] = Operands[j].getOperand(0); Operands[j] = Operands[j].getOperand(0);
OpSign[j] = Op == ISD::SIGN_EXTEND;
} }
// The pattern is detected, emit X86ISD::AVG instruction(s). // The pattern is detected, emit X86ISD::AVG instruction(s).