1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2025-02-01 13:11:39 +01:00

[X86] Stop using UpdateNodeOperands in combineGatherScatter. Create new nodes like most other DAG combines.

Creating new nodes is what we usually do. Have to explicitly
check that we don't update to an existing node and having
to manually manage the worklist is unusual.

We can probably add a helper function to reduce the duplication
of having to check if we should create a gather or scatter, but
I wanted to just get the simple thing done.

llvm-svn: 373137
This commit is contained in:
Craig Topper 2019-09-28 01:08:46 +00:00
parent a379ee1929
commit e3b69f4175

View File

@ -43381,26 +43381,36 @@ static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
SDLoc DL(N);
auto *GorS = cast<MaskedGatherScatterSDNode>(N);
SDValue Chain = GorS->getChain();
SDValue Index = GorS->getIndex();
SDValue Mask = GorS->getMask();
SDValue Base = GorS->getBasePtr();
SDValue Scale = GorS->getScale();
if (DCI.isBeforeLegalizeOps()) {
SDValue Index = N->getOperand(4);
// Remove any sign extends from 32 or smaller to larger than 32.
// Only do this before LegalizeOps in case we need the sign extend for
// legalization.
if (Index.getOpcode() == ISD::SIGN_EXTEND) {
if (Index.getScalarValueSizeInBits() > 32 &&
Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
NewOps[4] = Index.getOperand(0);
SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
if (Res == N) {
// The original sign extend has less users, add back to worklist in
// case it needs to be removed
DCI.AddToWorklist(Index.getNode());
DCI.AddToWorklist(N);
}
return SDValue(Res, 0);
if (Index.getOpcode() == ISD::SIGN_EXTEND &&
Index.getScalarValueSizeInBits() > 32 &&
Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
Index = Index.getOperand(0);
if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
SDValue Ops[] = { Chain, Gather->getPassThru(),
Mask, Base, Index, Scale } ;
return DAG.getMaskedGather(Gather->getVTList(),
Gather->getMemoryVT(), DL, Ops,
Gather->getMemOperand(),
Gather->getIndexType());
}
auto *Scatter = cast<MaskedScatterSDNode>(GorS);
SDValue Ops[] = { Chain, Scatter->getValue(),
Mask, Base, Index, Scale };
return DAG.getMaskedScatter(Scatter->getVTList(),
Scatter->getMemoryVT(), DL,
Ops, Scatter->getMemOperand(),
Scatter->getIndexType());
}
// Make sure the index is either i32 or i64
@ -43410,36 +43420,49 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), EltVT,
Index.getValueType().getVectorNumElements());
Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
NewOps[4] = Index;
SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
if (Res == N)
DCI.AddToWorklist(N);
return SDValue(Res, 0);
if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
SDValue Ops[] = { Chain, Gather->getPassThru(),
Mask, Base, Index, Scale } ;
return DAG.getMaskedGather(Gather->getVTList(),
Gather->getMemoryVT(), DL, Ops,
Gather->getMemOperand(),
Gather->getIndexType());
}
auto *Scatter = cast<MaskedScatterSDNode>(GorS);
SDValue Ops[] = { Chain, Scatter->getValue(),
Mask, Base, Index, Scale };
return DAG.getMaskedScatter(Scatter->getVTList(),
Scatter->getMemoryVT(), DL,
Ops, Scatter->getMemOperand(),
Scatter->getIndexType());
}
// Try to remove zero extends from 32->64 if we know the sign bit of
// the input is zero.
if (Index.getOpcode() == ISD::ZERO_EXTEND &&
Index.getScalarValueSizeInBits() == 64 &&
Index.getOperand(0).getScalarValueSizeInBits() == 32) {
if (DAG.SignBitIsZero(Index.getOperand(0))) {
SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
NewOps[4] = Index.getOperand(0);
SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
if (Res == N) {
// The original sign extend has less users, add back to worklist in
// case it needs to be removed
DCI.AddToWorklist(Index.getNode());
DCI.AddToWorklist(N);
}
return SDValue(Res, 0);
Index.getOperand(0).getScalarValueSizeInBits() == 32 &&
DAG.SignBitIsZero(Index.getOperand(0))) {
Index = Index.getOperand(0);
if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
SDValue Ops[] = { Chain, Gather->getPassThru(),
Mask, Base, Index, Scale } ;
return DAG.getMaskedGather(Gather->getVTList(),
Gather->getMemoryVT(), DL, Ops,
Gather->getMemOperand(),
Gather->getIndexType());
}
auto *Scatter = cast<MaskedScatterSDNode>(GorS);
SDValue Ops[] = { Chain, Scatter->getValue(),
Mask, Base, Index, Scale };
return DAG.getMaskedScatter(Scatter->getVTList(),
Scatter->getMemoryVT(), DL,
Ops, Scatter->getMemOperand(),
Scatter->getIndexType());
}
}
// With vector masks we only demand the upper bit of the mask.
SDValue Mask = cast<MaskedGatherScatterSDNode>(N)->getMask();
if (Mask.getScalarValueSizeInBits() != 1) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));