mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-11-21 18:22:53 +01:00
X86: improve (V)PMADDWD detection (2)
Implement "full" pattern.
This commit is contained in:
parent
610c27aa1c
commit
1cc7bdd501
@ -43550,6 +43550,18 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
|
||||||
|
N1.getOpcode() == ISD::SIGN_EXTEND && N1.hasOneUse() &&
|
||||||
|
N0.getOperand(0).getScalarValueSizeInBits() == 16 &&
|
||||||
|
N1.getOperand(0).getScalarValueSizeInBits() == 16) {
|
||||||
|
// If both arguments are sign-extended, try to replace sign extends
|
||||||
|
// with zero extends, which should qualify for the optimization.
|
||||||
|
// Otherwise just fallback to zero-extension check.
|
||||||
|
Mask17 = 0;
|
||||||
|
N0 = DAG.getNode(ISD::ZERO_EXTEND, N0.getNode(), VT, N0.getOperand(0));
|
||||||
|
N1 = DAG.getNode(ISD::ZERO_EXTEND, N1.getNode(), VT, N1.getOperand(0));
|
||||||
|
}
|
||||||
|
|
||||||
if (!!Mask17 && N0.getOpcode() == ISD::SRA) {
|
if (!!Mask17 && N0.getOpcode() == ISD::SRA) {
|
||||||
if (isa<ConstantSDNode>(N0.getOperand(1).getOperand(0)) &&
|
if (isa<ConstantSDNode>(N0.getOperand(1).getOperand(0)) &&
|
||||||
DAG.ComputeNumSignBits(N1) >= 17 &&
|
DAG.ComputeNumSignBits(N1) >= 17 &&
|
||||||
@ -50186,6 +50198,114 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
|
|||||||
PMADDBuilder);
|
PMADDBuilder);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Attempt to turn various patterns into PMADDWD when applicable.
|
||||||
|
// (add (mul (...), (...)), (mul (...), (...))
|
||||||
|
static SDValue matchPMADDWD_3(SelectionDAG &DAG, SDValue N0, SDValue N1,
|
||||||
|
const SDLoc &DL, EVT VT,
|
||||||
|
const X86Subtarget &Subtarget) {
|
||||||
|
if (!Subtarget.hasSSE2() || Subtarget.isPMADDWDSlow())
|
||||||
|
return SDValue();
|
||||||
|
|
||||||
|
if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
|
||||||
|
return SDValue();
|
||||||
|
|
||||||
|
if (!VT.isVector() || VT.getVectorElementType() != MVT::i32)
|
||||||
|
return SDValue();
|
||||||
|
|
||||||
|
// Make sure the type is legal or will be widened to a legal type.
|
||||||
|
if (VT != MVT::v2i32 && !DAG.getTargetLoweringInfo().isTypeLegal(VT))
|
||||||
|
return SDValue();
|
||||||
|
|
||||||
|
MVT WVT = MVT::getVectorVT(MVT::i16, 2 * VT.getVectorNumElements());
|
||||||
|
|
||||||
|
// Without BWI, we would need to split v32i16.
|
||||||
|
if (WVT == MVT::v32i16 && !Subtarget.hasBWI())
|
||||||
|
return SDValue();
|
||||||
|
|
||||||
|
SDValue N00 = N0.getOperand(0);
|
||||||
|
SDValue N01 = N0.getOperand(1);
|
||||||
|
SDValue N10 = N1.getOperand(0);
|
||||||
|
SDValue N11 = N1.getOperand(1);
|
||||||
|
|
||||||
|
APInt Mask17 = APInt::getHighBitsSet(32, 17);
|
||||||
|
if (N00.getOpcode() == ISD::SRA && N01.getOpcode() == ISD::SRA &&
|
||||||
|
N10.getOpcode() == ISD::SRA && N11.getOpcode() == ISD::SRA) {
|
||||||
|
// If both arguments are sign-extended, try to replace sign extends
|
||||||
|
// with zero extends, which should qualify for the optimization.
|
||||||
|
// Otherwise just fallback to zero-extension check.
|
||||||
|
if (isa<ConstantSDNode>(N00.getOperand(1).getOperand(0)) &&
|
||||||
|
isa<ConstantSDNode>(N01.getOperand(1).getOperand(0)) &&
|
||||||
|
isa<ConstantSDNode>(N10.getOperand(1).getOperand(0)) &&
|
||||||
|
isa<ConstantSDNode>(N11.getOperand(1).getOperand(0)) &&
|
||||||
|
N00.getOperand(1).getConstantOperandVal(0) == 16 &&
|
||||||
|
N01.getOperand(1).getConstantOperandVal(0) == 16 &&
|
||||||
|
N10.getOperand(1).getConstantOperandVal(0) == 16 &&
|
||||||
|
N11.getOperand(1).getConstantOperandVal(0) == 16 &&
|
||||||
|
DAG.isSplatValue(N00.getOperand(1)) &&
|
||||||
|
DAG.isSplatValue(N01.getOperand(1)) &&
|
||||||
|
DAG.isSplatValue(N10.getOperand(1)) &&
|
||||||
|
DAG.isSplatValue(N11.getOperand(1))) {
|
||||||
|
|
||||||
|
SDValue S00 = N00.getOperand(0);
|
||||||
|
SDValue S01 = N01.getOperand(0);
|
||||||
|
SDValue S10 = N10.getOperand(0);
|
||||||
|
SDValue S11 = N11.getOperand(0);
|
||||||
|
|
||||||
|
if (S10.getOpcode() == ISD::SHL && S11.getOpcode() == ISD::SHL) {
|
||||||
|
std::swap(S00, S10);
|
||||||
|
std::swap(S01, S11);
|
||||||
|
std::swap(N00, N10);
|
||||||
|
std::swap(N01, N11);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (S00.getOpcode() == ISD::SHL && S01.getOpcode() == ISD::SHL) {
|
||||||
|
if (S00.getOperand(0) == S10 && S01.getOperand(0) == S11) {
|
||||||
|
// Multiplication components are of the same sources
|
||||||
|
Mask17 = 0;
|
||||||
|
N0 = S10;
|
||||||
|
N1 = S11;
|
||||||
|
} else {
|
||||||
|
KnownBits k00, k01, k10, k11;
|
||||||
|
k00 = DAG.computeKnownBits(S00);
|
||||||
|
k01 = DAG.computeKnownBits(S01);
|
||||||
|
k10 = DAG.computeKnownBits(S10);
|
||||||
|
k11 = DAG.computeKnownBits(S11);
|
||||||
|
|
||||||
|
// N00 = N00.getOperand(0);
|
||||||
|
// N01 = N01.getOperand(0);
|
||||||
|
|
||||||
|
// N0 = DAG.getNode(ISD::OR, DL, VT, N00, N10);
|
||||||
|
// N1 = DAG.getNode(ISD::OR, DL, VT, N01, N11);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Mask17 = 0;
|
||||||
|
N00 = DAG.getNode(ISD::SRL, DL, VT, N00.getOperand(0), N00.getOperand(1));
|
||||||
|
N01 = DAG.getNode(ISD::SRL, DL, VT, N01.getOperand(0), N01.getOperand(1));
|
||||||
|
N10 = DAG.getNode(ISD::AND, DL, VT, N10.getOperand(0), DAG.getConstant(0xffff0000u, DL, VT));
|
||||||
|
N11 = DAG.getNode(ISD::AND, DL, VT, N11.getOperand(0), DAG.getConstant(0xffff0000u, DL, VT));
|
||||||
|
N0 = DAG.getNode(ISD::OR, DL, VT, N00, N10);
|
||||||
|
N1 = DAG.getNode(ISD::OR, DL, VT, N01, N11);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!!Mask17 && (!DAG.MaskedValueIsZero(N00, Mask17) ||
|
||||||
|
!DAG.MaskedValueIsZero(N01, Mask17) ||
|
||||||
|
!DAG.MaskedValueIsZero(N10, Mask17) ||
|
||||||
|
!DAG.MaskedValueIsZero(N11, Mask17)))
|
||||||
|
return SDValue();
|
||||||
|
|
||||||
|
// Use SplitOpsAndApply to handle AVX splitting.
|
||||||
|
auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
|
||||||
|
ArrayRef<SDValue> Ops) {
|
||||||
|
MVT OpVT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
|
||||||
|
return DAG.getNode(X86ISD::VPMADDWD, DL, OpVT, Ops);
|
||||||
|
};
|
||||||
|
return SplitOpsAndApply(DAG, Subtarget, DL, VT,
|
||||||
|
{ DAG.getBitcast(WVT, N0), DAG.getBitcast(WVT, N1) },
|
||||||
|
PMADDWDBuilder);
|
||||||
|
}
|
||||||
|
|
||||||
/// CMOV of constants requires materializing constant operands in registers.
|
/// CMOV of constants requires materializing constant operands in registers.
|
||||||
/// Try to fold those constants into an 'add' instruction to reduce instruction
|
/// Try to fold those constants into an 'add' instruction to reduce instruction
|
||||||
/// count. We do this with CMOV rather the generic 'select' because there are
|
/// count. We do this with CMOV rather the generic 'select' because there are
|
||||||
@ -50240,6 +50360,8 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
|
|||||||
return MAdd;
|
return MAdd;
|
||||||
if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1, SDLoc(N), VT, Subtarget))
|
if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1, SDLoc(N), VT, Subtarget))
|
||||||
return MAdd;
|
return MAdd;
|
||||||
|
if (SDValue MAdd = matchPMADDWD_3(DAG, Op0, Op1, SDLoc(N), VT, Subtarget))
|
||||||
|
return MAdd;
|
||||||
|
|
||||||
// Try to synthesize horizontal adds from adds of shuffles.
|
// Try to synthesize horizontal adds from adds of shuffles.
|
||||||
if (SDValue V = combineToHorizontalAddSub(N, DAG, Subtarget))
|
if (SDValue V = combineToHorizontalAddSub(N, DAG, Subtarget))
|
||||||
|
Loading…
Reference in New Issue
Block a user