mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-11-21 18:22:53 +01:00
X86: improve (V)PMADDWD detection (2)
Implement "full" pattern.
This commit is contained in:
parent
610c27aa1c
commit
1cc7bdd501
@ -43550,6 +43550,18 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
|
||||
}
|
||||
}
|
||||
|
||||
if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
|
||||
N1.getOpcode() == ISD::SIGN_EXTEND && N1.hasOneUse() &&
|
||||
N0.getOperand(0).getScalarValueSizeInBits() == 16 &&
|
||||
N1.getOperand(0).getScalarValueSizeInBits() == 16) {
|
||||
// If both arguments are sign-extended, try to replace sign extends
|
||||
// with zero extends, which should qualify for the optimization.
|
||||
// Otherwise just fallback to zero-extension check.
|
||||
Mask17 = 0;
|
||||
N0 = DAG.getNode(ISD::ZERO_EXTEND, N0.getNode(), VT, N0.getOperand(0));
|
||||
N1 = DAG.getNode(ISD::ZERO_EXTEND, N1.getNode(), VT, N1.getOperand(0));
|
||||
}
|
||||
|
||||
if (!!Mask17 && N0.getOpcode() == ISD::SRA) {
|
||||
if (isa<ConstantSDNode>(N0.getOperand(1).getOperand(0)) &&
|
||||
DAG.ComputeNumSignBits(N1) >= 17 &&
|
||||
@ -50186,6 +50198,114 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
|
||||
PMADDBuilder);
|
||||
}
|
||||
|
||||
// Attempt to turn various patterns into PMADDWD when applicable.
|
||||
// (add (mul (...), (...)), (mul (...), (...))
|
||||
static SDValue matchPMADDWD_3(SelectionDAG &DAG, SDValue N0, SDValue N1,
|
||||
const SDLoc &DL, EVT VT,
|
||||
const X86Subtarget &Subtarget) {
|
||||
if (!Subtarget.hasSSE2() || Subtarget.isPMADDWDSlow())
|
||||
return SDValue();
|
||||
|
||||
if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
|
||||
return SDValue();
|
||||
|
||||
if (!VT.isVector() || VT.getVectorElementType() != MVT::i32)
|
||||
return SDValue();
|
||||
|
||||
// Make sure the type is legal or will be widened to a legal type.
|
||||
if (VT != MVT::v2i32 && !DAG.getTargetLoweringInfo().isTypeLegal(VT))
|
||||
return SDValue();
|
||||
|
||||
MVT WVT = MVT::getVectorVT(MVT::i16, 2 * VT.getVectorNumElements());
|
||||
|
||||
// Without BWI, we would need to split v32i16.
|
||||
if (WVT == MVT::v32i16 && !Subtarget.hasBWI())
|
||||
return SDValue();
|
||||
|
||||
SDValue N00 = N0.getOperand(0);
|
||||
SDValue N01 = N0.getOperand(1);
|
||||
SDValue N10 = N1.getOperand(0);
|
||||
SDValue N11 = N1.getOperand(1);
|
||||
|
||||
APInt Mask17 = APInt::getHighBitsSet(32, 17);
|
||||
if (N00.getOpcode() == ISD::SRA && N01.getOpcode() == ISD::SRA &&
|
||||
N10.getOpcode() == ISD::SRA && N11.getOpcode() == ISD::SRA) {
|
||||
// If both arguments are sign-extended, try to replace sign extends
|
||||
// with zero extends, which should qualify for the optimization.
|
||||
// Otherwise just fallback to zero-extension check.
|
||||
if (isa<ConstantSDNode>(N00.getOperand(1).getOperand(0)) &&
|
||||
isa<ConstantSDNode>(N01.getOperand(1).getOperand(0)) &&
|
||||
isa<ConstantSDNode>(N10.getOperand(1).getOperand(0)) &&
|
||||
isa<ConstantSDNode>(N11.getOperand(1).getOperand(0)) &&
|
||||
N00.getOperand(1).getConstantOperandVal(0) == 16 &&
|
||||
N01.getOperand(1).getConstantOperandVal(0) == 16 &&
|
||||
N10.getOperand(1).getConstantOperandVal(0) == 16 &&
|
||||
N11.getOperand(1).getConstantOperandVal(0) == 16 &&
|
||||
DAG.isSplatValue(N00.getOperand(1)) &&
|
||||
DAG.isSplatValue(N01.getOperand(1)) &&
|
||||
DAG.isSplatValue(N10.getOperand(1)) &&
|
||||
DAG.isSplatValue(N11.getOperand(1))) {
|
||||
|
||||
SDValue S00 = N00.getOperand(0);
|
||||
SDValue S01 = N01.getOperand(0);
|
||||
SDValue S10 = N10.getOperand(0);
|
||||
SDValue S11 = N11.getOperand(0);
|
||||
|
||||
if (S10.getOpcode() == ISD::SHL && S11.getOpcode() == ISD::SHL) {
|
||||
std::swap(S00, S10);
|
||||
std::swap(S01, S11);
|
||||
std::swap(N00, N10);
|
||||
std::swap(N01, N11);
|
||||
}
|
||||
|
||||
if (S00.getOpcode() == ISD::SHL && S01.getOpcode() == ISD::SHL) {
|
||||
if (S00.getOperand(0) == S10 && S01.getOperand(0) == S11) {
|
||||
// Multiplication components are of the same sources
|
||||
Mask17 = 0;
|
||||
N0 = S10;
|
||||
N1 = S11;
|
||||
} else {
|
||||
KnownBits k00, k01, k10, k11;
|
||||
k00 = DAG.computeKnownBits(S00);
|
||||
k01 = DAG.computeKnownBits(S01);
|
||||
k10 = DAG.computeKnownBits(S10);
|
||||
k11 = DAG.computeKnownBits(S11);
|
||||
|
||||
// N00 = N00.getOperand(0);
|
||||
// N01 = N01.getOperand(0);
|
||||
|
||||
// N0 = DAG.getNode(ISD::OR, DL, VT, N00, N10);
|
||||
// N1 = DAG.getNode(ISD::OR, DL, VT, N01, N11);
|
||||
}
|
||||
} else {
|
||||
Mask17 = 0;
|
||||
N00 = DAG.getNode(ISD::SRL, DL, VT, N00.getOperand(0), N00.getOperand(1));
|
||||
N01 = DAG.getNode(ISD::SRL, DL, VT, N01.getOperand(0), N01.getOperand(1));
|
||||
N10 = DAG.getNode(ISD::AND, DL, VT, N10.getOperand(0), DAG.getConstant(0xffff0000u, DL, VT));
|
||||
N11 = DAG.getNode(ISD::AND, DL, VT, N11.getOperand(0), DAG.getConstant(0xffff0000u, DL, VT));
|
||||
N0 = DAG.getNode(ISD::OR, DL, VT, N00, N10);
|
||||
N1 = DAG.getNode(ISD::OR, DL, VT, N01, N11);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!!Mask17 && (!DAG.MaskedValueIsZero(N00, Mask17) ||
|
||||
!DAG.MaskedValueIsZero(N01, Mask17) ||
|
||||
!DAG.MaskedValueIsZero(N10, Mask17) ||
|
||||
!DAG.MaskedValueIsZero(N11, Mask17)))
|
||||
return SDValue();
|
||||
|
||||
// Use SplitOpsAndApply to handle AVX splitting.
|
||||
auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
|
||||
ArrayRef<SDValue> Ops) {
|
||||
MVT OpVT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
|
||||
return DAG.getNode(X86ISD::VPMADDWD, DL, OpVT, Ops);
|
||||
};
|
||||
return SplitOpsAndApply(DAG, Subtarget, DL, VT,
|
||||
{ DAG.getBitcast(WVT, N0), DAG.getBitcast(WVT, N1) },
|
||||
PMADDWDBuilder);
|
||||
}
|
||||
|
||||
/// CMOV of constants requires materializing constant operands in registers.
|
||||
/// Try to fold those constants into an 'add' instruction to reduce instruction
|
||||
/// count. We do this with CMOV rather the generic 'select' because there are
|
||||
@ -50240,6 +50360,8 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
|
||||
return MAdd;
|
||||
if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1, SDLoc(N), VT, Subtarget))
|
||||
return MAdd;
|
||||
if (SDValue MAdd = matchPMADDWD_3(DAG, Op0, Op1, SDLoc(N), VT, Subtarget))
|
||||
return MAdd;
|
||||
|
||||
// Try to synthesize horizontal adds from adds of shuffles.
|
||||
if (SDValue V = combineToHorizontalAddSub(N, DAG, Subtarget))
|
||||
|
Loading…
Reference in New Issue
Block a user