mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-11-21 18:22:53 +01:00
X86: expand detectAVGPattern()
Allow all integer widths in the pattern, allow ashr Handle signed and mixed cases, allowing to replace truncation
This commit is contained in:
parent
5ff8f4151c
commit
2ffa82223f
@ -45742,11 +45742,11 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
|
||||
unsigned NumElems = VT.getVectorNumElements();
|
||||
|
||||
EVT ScalarVT = VT.getVectorElementType();
|
||||
if (!((ScalarVT == MVT::i8 || ScalarVT == MVT::i16) && NumElems >= 2))
|
||||
if (ScalarVT == MVT::i1 || NumElems < 2)
|
||||
return SDValue();
|
||||
|
||||
// InScalarVT is the intermediate type in AVG pattern and it should be greater
|
||||
// than the original input type (i8/i16).
|
||||
// than the original input type.
|
||||
EVT InScalarVT = InVT.getVectorElementType();
|
||||
if (InScalarVT.getFixedSizeInBits() <= ScalarVT.getFixedSizeInBits())
|
||||
return SDValue();
|
||||
@ -45764,12 +45764,14 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
|
||||
// %6 = trunc <N x i32> %5 to <N x i8>
|
||||
//
|
||||
// In AVX512, the last instruction can also be a trunc store.
|
||||
if (In.getOpcode() != ISD::SRL)
|
||||
|
||||
// Shift type (lshr or ashr) doesn't affect the result, allow both.
|
||||
if (In.getOpcode() != ISD::SRL && In.getOpcode() != ISD::SRA)
|
||||
return SDValue();
|
||||
|
||||
// A lambda checking the given SDValue is a constant vector and each element
|
||||
// is in the range [Min, Max].
|
||||
auto IsConstVectorInRange = [](SDValue V, unsigned Min, unsigned Max) {
|
||||
auto IsConstVectorInRange = [](SDValue V, uint64_t Min, uint64_t Max) {
|
||||
return ISD::matchUnaryPredicate(V, [Min, Max](ConstantSDNode *C) {
|
||||
return !(C->getAPIntValue().ult(Min) || C->getAPIntValue().ugt(Max));
|
||||
});
|
||||
@ -45788,9 +45790,80 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
|
||||
Operands[0] = LHS.getOperand(0);
|
||||
Operands[1] = LHS.getOperand(1);
|
||||
|
||||
auto AVGBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
|
||||
ArrayRef<SDValue> Ops) {
|
||||
return DAG.getNode(X86ISD::AVG, DL, Ops[0].getValueType(), Ops);
|
||||
// Signed-ness of Operands[0..1] (true if sign-extended)
|
||||
bool OpSign[2]{false, false};
|
||||
|
||||
auto AVGBuilder = [&](SelectionDAG &DAG, const SDLoc &DL,
|
||||
ArrayRef<SDValue> Ops) {
|
||||
EVT VT = Ops[0].getValueType();
|
||||
EVT ScalarVT = VT.getVectorElementType();
|
||||
// Legal X86ISD::AVG types
|
||||
if (ScalarVT == MVT::i8 || ScalarVT == MVT::i16) {
|
||||
// Emulate signed or mixed AVG op via unsigned AVG
|
||||
if (OpSign[0] || OpSign[1]) {
|
||||
unsigned SignM = ScalarVT == MVT::i8 ? 0x80 : 0x8000;
|
||||
SDValue SignVal = DAG.getConstant(SignM, DL, VT);
|
||||
SDValue Op0 = Ops[0];
|
||||
SDValue Op1 = Ops[1];
|
||||
if (OpSign[0])
|
||||
Op0 = DAG.getNode(ISD::XOR, DL, VT, Ops[0], SignVal);
|
||||
if (OpSign[1])
|
||||
Op1 = DAG.getNode(ISD::XOR, DL, VT, Ops[1], SignVal);
|
||||
SDValue R = DAG.getNode(X86ISD::AVG, DL, VT, Op0, Op1);
|
||||
|
||||
// Signed AVG
|
||||
if (OpSign[0] && OpSign[1])
|
||||
return DAG.getNode(ISD::XOR, DL, VT, R, SignVal);
|
||||
|
||||
// Mixed signed-unsigned AVG
|
||||
SignVal = DAG.getConstant(SignM >> 1, DL, VT);
|
||||
return DAG.getNode(ISD::SUB, DL, VT, R, SignVal);
|
||||
}
|
||||
|
||||
// Standard unsigned AVG
|
||||
return DAG.getNode(X86ISD::AVG, DL, VT, Ops);
|
||||
}
|
||||
|
||||
// Emulate AVG with more effective in-lane in-type algorithm
|
||||
SDValue Operands[2];
|
||||
SDValue One = DAG.getConstant(1, DL, VT);
|
||||
for (int i = 0; i < 2; i++) {
|
||||
unsigned Op = OpSign[i] ? ISD::SRA : ISD::SRL;
|
||||
Operands[i] = DAG.getNode(Op, DL, VT, Ops[i], One);
|
||||
}
|
||||
|
||||
// Compute carry of the 1-bit addition of least significant bits
|
||||
SDValue Carry;
|
||||
BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Ops[1]);
|
||||
if (BV && BV->isConstant()) {
|
||||
bool AllOdds = true;
|
||||
bool AllEven = true;
|
||||
for (SDValue Op : BV->ops()) {
|
||||
if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op)) {
|
||||
if (C->getAPIntValue()[0])
|
||||
AllEven = false;
|
||||
else
|
||||
AllOdds = false;
|
||||
} else {
|
||||
BV = nullptr;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (BV && AllOdds)
|
||||
Carry = One;
|
||||
if (BV && AllEven)
|
||||
Carry = DAG.getNode(ISD::AND, DL, VT, Ops[0], One);
|
||||
}
|
||||
if (!Carry && Subtarget.hasAVX512())
|
||||
if (ScalarVT == MVT::i32 || ScalarVT == MVT::i64)
|
||||
Carry = DAG.getNode(X86ISD::VPTERNLOG, DL, VT, Ops[0], Ops[1], One,
|
||||
DAG.getTargetConstant(0xA8, DL, MVT::i8));
|
||||
if (!Carry) {
|
||||
Carry = DAG.getNode(ISD::OR, DL, VT, Ops);
|
||||
Carry = DAG.getNode(ISD::AND, DL, VT, Carry, One);
|
||||
}
|
||||
Carry = DAG.getNode(ISD::ADD, DL, VT, Carry, Operands[1]);
|
||||
return DAG.getNode(ISD::ADD, DL, VT, Operands[0], Carry);
|
||||
};
|
||||
|
||||
auto AVGSplitter = [&](SDValue Op0, SDValue Op1) {
|
||||
@ -45817,12 +45890,15 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
|
||||
};
|
||||
|
||||
// Take care of the case when one of the operands is a constant vector whose
|
||||
// element is in the range [1, 256].
|
||||
if (IsConstVectorInRange(Operands[1], 1, ScalarVT == MVT::i8 ? 256 : 65536) &&
|
||||
Operands[0].getOpcode() == ISD::ZERO_EXTEND &&
|
||||
// element is in the range [1, 256] (for i8). TODO: support negative values.
|
||||
unsigned SBW = ScalarVT.getSizeInBits();
|
||||
unsigned OpX = Operands[0].getOpcode();
|
||||
if (IsConstVectorInRange(Operands[1], 1, SBW < 64 ? 1ull << SBW : -1) &&
|
||||
(OpX == ISD::ZERO_EXTEND || OpX == ISD::SIGN_EXTEND) &&
|
||||
Operands[0].getOperand(0).getValueType() == VT) {
|
||||
// The pattern is detected. Subtract one from the constant vector, then
|
||||
// demote it and emit X86ISD::AVG instruction.
|
||||
OpSign[0] = (OpX == ISD::SIGN_EXTEND);
|
||||
SDValue VecOnes = DAG.getConstant(1, DL, InVT);
|
||||
Operands[1] = DAG.getNode(ISD::SUB, DL, InVT, Operands[1], VecOnes);
|
||||
Operands[1] = DAG.getNode(ISD::TRUNCATE, DL, VT, Operands[1]);
|
||||
@ -45866,10 +45942,12 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
|
||||
// Check if Operands[0] and Operands[1] are results of type promotion.
|
||||
for (int j = 0; j < 2; ++j)
|
||||
if (Operands[j].getValueType() != VT) {
|
||||
if (Operands[j].getOpcode() != ISD::ZERO_EXTEND ||
|
||||
unsigned Op = Operands[j].getOpcode();
|
||||
if ((Op != ISD::ZERO_EXTEND && Op != ISD::SIGN_EXTEND) ||
|
||||
Operands[j].getOperand(0).getValueType() != VT)
|
||||
return SDValue();
|
||||
Operands[j] = Operands[j].getOperand(0);
|
||||
OpSign[j] = Op == ISD::SIGN_EXTEND;
|
||||
}
|
||||
|
||||
// The pattern is detected, emit X86ISD::AVG instruction(s).
|
||||
|
Loading…
Reference in New Issue
Block a user