1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-22 02:33:06 +01:00

X86: improve (V)PMADDWD detection (2)

Implement "full" pattern.
This commit is contained in:
Nekotekina 2021-11-16 13:50:49 +03:00
parent 610c27aa1c
commit 1cc7bdd501

View File

@@ -43550,6 +43550,18 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
}
}
if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
N1.getOpcode() == ISD::SIGN_EXTEND && N1.hasOneUse() &&
N0.getOperand(0).getScalarValueSizeInBits() == 16 &&
N1.getOperand(0).getScalarValueSizeInBits() == 16) {
// If both arguments are sign-extended, try to replace sign extends
// with zero extends, which should qualify for the optimization.
// Otherwise just fallback to zero-extension check.
Mask17 = 0;
N0 = DAG.getNode(ISD::ZERO_EXTEND, N0.getNode(), VT, N0.getOperand(0));
N1 = DAG.getNode(ISD::ZERO_EXTEND, N1.getNode(), VT, N1.getOperand(0));
}
if (!!Mask17 && N0.getOpcode() == ISD::SRA) {
if (isa<ConstantSDNode>(N0.getOperand(1).getOperand(0)) &&
DAG.ComputeNumSignBits(N1) >= 17 &&
@@ -50186,6 +50198,114 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
PMADDBuilder);
}
// Attempt to turn various patterns into PMADDWD when applicable.
// (add (mul (...), (...)), (mul (...), (...)))
//
// Tries to fold an add of two i32-element multiplies into X86ISD::VPMADDWD,
// which computes lo16*lo16 + hi16*hi16 per i32 lane. Returns the replacement
// node, or an empty SDValue if the pattern does not apply.
static SDValue matchPMADDWD_3(SelectionDAG &DAG, SDValue N0, SDValue N1,
                              const SDLoc &DL, EVT VT,
                              const X86Subtarget &Subtarget) {
  if (!Subtarget.hasSSE2() || Subtarget.isPMADDWDSlow())
    return SDValue();
  // Both addends must themselves be multiplies.
  if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
    return SDValue();
  if (!VT.isVector() || VT.getVectorElementType() != MVT::i32)
    return SDValue();
  // Make sure the type is legal or will be widened to a legal type.
  if (VT != MVT::v2i32 && !DAG.getTargetLoweringInfo().isTypeLegal(VT))
    return SDValue();
  // The i16 vector type the i32 operands are bitcast to for VPMADDWD.
  MVT WVT = MVT::getVectorVT(MVT::i16, 2 * VT.getVectorNumElements());
  // Without BWI, we would need to split v32i16.
  if (WVT == MVT::v32i16 && !Subtarget.hasBWI())
    return SDValue();

  SDValue N00 = N0.getOperand(0);
  SDValue N01 = N0.getOperand(1);
  SDValue N10 = N1.getOperand(0);
  SDValue N11 = N1.getOperand(1);

  // Unless cleared by one of the recognized patterns below, fall back to
  // requiring that all four multiply operands have their top 17 bits known
  // zero (zero-extended i16 values, so each product fits in i32).
  APInt Mask17 = APInt::getHighBitsSet(32, 17);
  if (N00.getOpcode() == ISD::SRA && N01.getOpcode() == ISD::SRA &&
      N10.getOpcode() == ISD::SRA && N11.getOpcode() == ISD::SRA) {
    // If both arguments are sign-extended, try to replace sign extends
    // with zero extends, which should qualify for the optimization.
    // Otherwise just fallback to zero-extension check.
    //
    // All four operands are arithmetic shifts; require splat shifts by 16.
    // (sra (shl X, 16), 16) extracts the sign-extended low i16 half of each
    // lane, and (sra X, 16) extracts the high i16 half.
    if (isa<ConstantSDNode>(N00.getOperand(1).getOperand(0)) &&
        isa<ConstantSDNode>(N01.getOperand(1).getOperand(0)) &&
        isa<ConstantSDNode>(N10.getOperand(1).getOperand(0)) &&
        isa<ConstantSDNode>(N11.getOperand(1).getOperand(0)) &&
        N00.getOperand(1).getConstantOperandVal(0) == 16 &&
        N01.getOperand(1).getConstantOperandVal(0) == 16 &&
        N10.getOperand(1).getConstantOperandVal(0) == 16 &&
        N11.getOperand(1).getConstantOperandVal(0) == 16 &&
        DAG.isSplatValue(N00.getOperand(1)) &&
        DAG.isSplatValue(N01.getOperand(1)) &&
        DAG.isSplatValue(N10.getOperand(1)) &&
        DAG.isSplatValue(N11.getOperand(1))) {
      SDValue S00 = N00.getOperand(0);
      SDValue S01 = N01.getOperand(0);
      SDValue S10 = N10.getOperand(0);
      SDValue S11 = N11.getOperand(0);
      // Canonicalize so that the low-half (shl) operands, if present, end up
      // in S00/S01.
      if (S10.getOpcode() == ISD::SHL && S11.getOpcode() == ISD::SHL) {
        std::swap(S00, S10);
        std::swap(S01, S11);
        std::swap(N00, N10);
        std::swap(N01, N11);
      }
      if (S00.getOpcode() == ISD::SHL && S01.getOpcode() == ISD::SHL) {
        if (S00.getOperand(0) == S10 && S01.getOperand(0) == S11) {
          // Multiplication components are of the same sources:
          // lo(X)*lo(Y) + hi(X)*hi(Y), which is exactly what VPMADDWD
          // computes on the raw i32 vectors.
          Mask17 = 0;
          N0 = S10;
          N1 = S11;
        }
        // Otherwise the low and high halves come from different sources.
        // Handling that (e.g. via known-bits reasoning to OR the halves
        // together) is not implemented; fall through with Mask17 still set,
        // so the zero-extension check below rejects the pattern.
      } else {
        // hi(A)*hi(A') + hi(B)*hi(B'): move the high halves of A/A' into the
        // low i16 lanes (logical shift right by the same splat amount) and
        // keep the high halves of B/B' in place (mask), then OR so VPMADDWD
        // sees {hi(A), hi(B)} x {hi(A'), hi(B')} per i32 lane.
        Mask17 = 0;
        N00 = DAG.getNode(ISD::SRL, DL, VT, N00.getOperand(0),
                          N00.getOperand(1));
        N01 = DAG.getNode(ISD::SRL, DL, VT, N01.getOperand(0),
                          N01.getOperand(1));
        N10 = DAG.getNode(ISD::AND, DL, VT, N10.getOperand(0),
                          DAG.getConstant(0xffff0000u, DL, VT));
        N11 = DAG.getNode(ISD::AND, DL, VT, N11.getOperand(0),
                          DAG.getConstant(0xffff0000u, DL, VT));
        N0 = DAG.getNode(ISD::OR, DL, VT, N00, N10);
        N1 = DAG.getNode(ISD::OR, DL, VT, N01, N11);
      }
    }
  }
  // Fallback qualification: all four original operands must be provably
  // zero-extended 16-bit values.
  if (!!Mask17 && (!DAG.MaskedValueIsZero(N00, Mask17) ||
                   !DAG.MaskedValueIsZero(N01, Mask17) ||
                   !DAG.MaskedValueIsZero(N10, Mask17) ||
                   !DAG.MaskedValueIsZero(N11, Mask17)))
    return SDValue();
  // Use SplitOpsAndApply to handle AVX splitting.
  auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
                           ArrayRef<SDValue> Ops) {
    MVT OpVT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
    return DAG.getNode(X86ISD::VPMADDWD, DL, OpVT, Ops);
  };
  return SplitOpsAndApply(DAG, Subtarget, DL, VT,
                          {DAG.getBitcast(WVT, N0), DAG.getBitcast(WVT, N1)},
                          PMADDWDBuilder);
}
/// CMOV of constants requires materializing constant operands in registers.
/// Try to fold those constants into an 'add' instruction to reduce instruction
/// count. We do this with CMOV rather the generic 'select' because there are
@@ -50240,6 +50360,8 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
return MAdd;
if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1, SDLoc(N), VT, Subtarget))
return MAdd;
if (SDValue MAdd = matchPMADDWD_3(DAG, Op0, Op1, SDLoc(N), VT, Subtarget))
return MAdd;
// Try to synthesize horizontal adds from adds of shuffles.
if (SDValue V = combineToHorizontalAddSub(N, DAG, Subtarget))