1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-21 18:22:53 +01:00

X86: expand detectAVGPattern()

Allow all integer widths in the pattern, allow ashr
Handle signed and mixed cases, allowing to replace truncation
This commit is contained in:
Nekotekina 2018-06-27 15:00:34 +03:00
parent 5ff8f4151c
commit 2ffa82223f

View File

@ -45742,11 +45742,11 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
unsigned NumElems = VT.getVectorNumElements();
EVT ScalarVT = VT.getVectorElementType();
if (!((ScalarVT == MVT::i8 || ScalarVT == MVT::i16) && NumElems >= 2))
if (ScalarVT == MVT::i1 || NumElems < 2)
return SDValue();
// InScalarVT is the intermediate type in AVG pattern and it should be greater
// than the original input type (i8/i16).
// than the original input type.
EVT InScalarVT = InVT.getVectorElementType();
if (InScalarVT.getFixedSizeInBits() <= ScalarVT.getFixedSizeInBits())
return SDValue();
@ -45764,12 +45764,14 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
// %6 = trunc <N x i32> %5 to <N x i8>
//
// In AVX512, the last instruction can also be a trunc store.
if (In.getOpcode() != ISD::SRL)
// Shift type (lshr or ashr) doesn't affect the result, allow both.
if (In.getOpcode() != ISD::SRL && In.getOpcode() != ISD::SRA)
return SDValue();
// A lambda checking the given SDValue is a constant vector and each element
// is in the range [Min, Max].
auto IsConstVectorInRange = [](SDValue V, unsigned Min, unsigned Max) {
auto IsConstVectorInRange = [](SDValue V, uint64_t Min, uint64_t Max) {
return ISD::matchUnaryPredicate(V, [Min, Max](ConstantSDNode *C) {
return !(C->getAPIntValue().ult(Min) || C->getAPIntValue().ugt(Max));
});
@ -45788,9 +45790,80 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
Operands[0] = LHS.getOperand(0);
Operands[1] = LHS.getOperand(1);
auto AVGBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
ArrayRef<SDValue> Ops) {
return DAG.getNode(X86ISD::AVG, DL, Ops[0].getValueType(), Ops);
// Signed-ness of Operands[0..1] (true if sign-extended)
bool OpSign[2]{false, false};
auto AVGBuilder = [&](SelectionDAG &DAG, const SDLoc &DL,
ArrayRef<SDValue> Ops) {
EVT VT = Ops[0].getValueType();
EVT ScalarVT = VT.getVectorElementType();
// Legal X86ISD::AVG types
if (ScalarVT == MVT::i8 || ScalarVT == MVT::i16) {
// Emulate signed or mixed AVG op via unsigned AVG
if (OpSign[0] || OpSign[1]) {
unsigned SignM = ScalarVT == MVT::i8 ? 0x80 : 0x8000;
SDValue SignVal = DAG.getConstant(SignM, DL, VT);
SDValue Op0 = Ops[0];
SDValue Op1 = Ops[1];
if (OpSign[0])
Op0 = DAG.getNode(ISD::XOR, DL, VT, Ops[0], SignVal);
if (OpSign[1])
Op1 = DAG.getNode(ISD::XOR, DL, VT, Ops[1], SignVal);
SDValue R = DAG.getNode(X86ISD::AVG, DL, VT, Op0, Op1);
// Signed AVG
if (OpSign[0] && OpSign[1])
return DAG.getNode(ISD::XOR, DL, VT, R, SignVal);
// Mixed signed-unsigned AVG
SignVal = DAG.getConstant(SignM >> 1, DL, VT);
return DAG.getNode(ISD::SUB, DL, VT, R, SignVal);
}
// Standard unsigned AVG
return DAG.getNode(X86ISD::AVG, DL, VT, Ops);
}
// Emulate AVG with more effective in-lane in-type algorithm
SDValue Operands[2];
SDValue One = DAG.getConstant(1, DL, VT);
for (int i = 0; i < 2; i++) {
unsigned Op = OpSign[i] ? ISD::SRA : ISD::SRL;
Operands[i] = DAG.getNode(Op, DL, VT, Ops[i], One);
}
// Compute carry of the 1-bit addition of least significant bits
SDValue Carry;
BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Ops[1]);
if (BV && BV->isConstant()) {
bool AllOdds = true;
bool AllEven = true;
for (SDValue Op : BV->ops()) {
if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op)) {
if (C->getAPIntValue()[0])
AllEven = false;
else
AllOdds = false;
} else {
BV = nullptr;
break;
}
}
if (BV && AllOdds)
Carry = One;
if (BV && AllEven)
Carry = DAG.getNode(ISD::AND, DL, VT, Ops[0], One);
}
if (!Carry && Subtarget.hasAVX512())
if (ScalarVT == MVT::i32 || ScalarVT == MVT::i64)
Carry = DAG.getNode(X86ISD::VPTERNLOG, DL, VT, Ops[0], Ops[1], One,
DAG.getTargetConstant(0xA8, DL, MVT::i8));
if (!Carry) {
Carry = DAG.getNode(ISD::OR, DL, VT, Ops);
Carry = DAG.getNode(ISD::AND, DL, VT, Carry, One);
}
Carry = DAG.getNode(ISD::ADD, DL, VT, Carry, Operands[1]);
return DAG.getNode(ISD::ADD, DL, VT, Operands[0], Carry);
};
auto AVGSplitter = [&](SDValue Op0, SDValue Op1) {
@ -45817,12 +45890,15 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
};
// Take care of the case when one of the operands is a constant vector whose
// element is in the range [1, 256].
if (IsConstVectorInRange(Operands[1], 1, ScalarVT == MVT::i8 ? 256 : 65536) &&
Operands[0].getOpcode() == ISD::ZERO_EXTEND &&
// element is in the range [1, 256] (for i8). TODO: support negative values.
unsigned SBW = ScalarVT.getSizeInBits();
unsigned OpX = Operands[0].getOpcode();
if (IsConstVectorInRange(Operands[1], 1, SBW < 64 ? 1ull << SBW : -1) &&
(OpX == ISD::ZERO_EXTEND || OpX == ISD::SIGN_EXTEND) &&
Operands[0].getOperand(0).getValueType() == VT) {
// The pattern is detected. Subtract one from the constant vector, then
// demote it and emit X86ISD::AVG instruction.
OpSign[0] = (OpX == ISD::SIGN_EXTEND);
SDValue VecOnes = DAG.getConstant(1, DL, InVT);
Operands[1] = DAG.getNode(ISD::SUB, DL, InVT, Operands[1], VecOnes);
Operands[1] = DAG.getNode(ISD::TRUNCATE, DL, VT, Operands[1]);
@ -45866,10 +45942,12 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
// Check if Operands[0] and Operands[1] are results of type promotion.
for (int j = 0; j < 2; ++j)
if (Operands[j].getValueType() != VT) {
if (Operands[j].getOpcode() != ISD::ZERO_EXTEND ||
unsigned Op = Operands[j].getOpcode();
if ((Op != ISD::ZERO_EXTEND && Op != ISD::SIGN_EXTEND) ||
Operands[j].getOperand(0).getValueType() != VT)
return SDValue();
Operands[j] = Operands[j].getOperand(0);
OpSign[j] = Op == ISD::SIGN_EXTEND;
}
// The pattern is detected, emit X86ISD::AVG instruction(s).