
[SVE][CodeGen] Remove performMaskedGatherScatterCombine

The AArch64 DAG combine added by D90945 & D91433 extends the index
of a scalable masked gather or scatter to i32 if necessary.

This patch removes the combine and instead adds shouldExtendGSIndex, which
is used by visitMaskedGather/Scatter in SelectionDAGBuilder to query whether
the index should be extended before calling getMaskedGather/Scatter.

Reviewed By: david-arm

Differential Revision: https://reviews.llvm.org/D94525
Kerry McLaughlin 2021-02-01 11:04:36 +00:00
parent 5ff3260e34
commit 267255edd9
4 changed files with 30 additions and 55 deletions


@@ -1318,6 +1318,10 @@ public:
getIndexedMaskedStoreAction(IdxMode, VT.getSimpleVT()) == Custom);
}
/// Returns true if the index type for a masked gather/scatter requires
/// extending
virtual bool shouldExtendGSIndex(EVT VT, EVT &EltTy) const { return false; }
// Returns true if VT is a legal index type for masked gathers/scatters
// on this target
virtual bool shouldRemoveExtendFromGSIndex(EVT VT) const { return false; }
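As an illustration of the new hook's contract, here is a minimal sketch of an out-of-tree target override (the MyTargetLowering name is hypothetical; the in-tree AArch64 override appears in a later hunk). Returning true and writing the desired element type into EltTy asks SelectionDAGBuilder to sign-extend the index vector to that element type before it builds the gather/scatter node.

// Hypothetical override (sketch only): request that i8/i16 gather/scatter
// index vectors be widened, with i32 as the new element type.
bool MyTargetLowering::shouldExtendGSIndex(EVT VT, EVT &EltTy) const {
  if (VT.isVector() && (VT.getVectorElementType() == MVT::i8 ||
                        VT.getVectorElementType() == MVT::i16)) {
    EltTy = MVT::i32; // SelectionDAGBuilder sign-extends the index to this type.
    return true;
  }
  return false;
}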


@@ -4339,6 +4339,14 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
IndexType = ISD::SIGNED_UNSCALED;
Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
}
EVT IdxVT = Index.getValueType();
EVT EltTy = IdxVT.getVectorElementType();
if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
EVT NewIdxVT = IdxVT.changeVectorElementType(EltTy);
Index = DAG.getNode(ISD::SIGN_EXTEND, sdl, NewIdxVT, Index);
}
SDValue Ops[] = { getMemoryRoot(), Src0, Mask, Base, Index, Scale };
SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl,
Ops, MMO, IndexType, false);
@@ -4450,6 +4458,14 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
IndexType = ISD::SIGNED_UNSCALED;
Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
}
EVT IdxVT = Index.getValueType();
EVT EltTy = IdxVT.getVectorElementType();
if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
EVT NewIdxVT = IdxVT.changeVectorElementType(EltTy);
Index = DAG.getNode(ISD::SIGN_EXTEND, sdl, NewIdxVT, Index);
}
SDValue Ops[] = { Root, Src0, Mask, Base, Index, Scale };
SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl,
Ops, MMO, IndexType, ISD::NON_EXTLOAD);
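As a worked example with hypothetical types, assume a target whose shouldExtendGSIndex widens i8/i16 index elements to i32 (as the AArch64 override below does). For a gather whose index operand has type nxv4i16, the code added above proceeds as follows before the node is created:

// IdxVT    = nxv4i16                               (original index type)
// shouldExtendGSIndex(nxv4i16, EltTy) returns true and sets EltTy = i32
// NewIdxVT = IdxVT.changeVectorElementType(i32)    -> nxv4i32
// Index    = SIGN_EXTEND(Index) : nxv4i32
// getMaskedGather is then called with the widened index operand.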


@@ -873,9 +873,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
if (Subtarget->supportsAddressTopByteIgnored())
setTargetDAGCombine(ISD::LOAD);
setTargetDAGCombine(ISD::MGATHER);
setTargetDAGCombine(ISD::MSCATTER);
setTargetDAGCombine(ISD::MUL);
setTargetDAGCombine(ISD::SELECT);
@@ -3825,6 +3822,15 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
}
}
bool AArch64TargetLowering::shouldExtendGSIndex(EVT VT, EVT &EltTy) const {
if (VT.getVectorElementType() == MVT::i8 ||
VT.getVectorElementType() == MVT::i16) {
EltTy = MVT::i32;
return true;
}
return false;
}
bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(EVT VT) const {
if (VT.getVectorElementType() == MVT::i32 &&
VT.getVectorElementCount().getKnownMinValue() >= 4)
@@ -14395,55 +14401,6 @@ static SDValue performSTORECombine(SDNode *N,
return SDValue();
}
static SDValue performMaskedGatherScatterCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
assert(MGS && "Can only combine gather load or scatter store nodes");
SDLoc DL(MGS);
SDValue Chain = MGS->getChain();
SDValue Scale = MGS->getScale();
SDValue Index = MGS->getIndex();
SDValue Mask = MGS->getMask();
SDValue BasePtr = MGS->getBasePtr();
ISD::MemIndexType IndexType = MGS->getIndexType();
EVT IdxVT = Index.getValueType();
if (DCI.isBeforeLegalize()) {
// SVE gather/scatter requires indices of i32/i64. Promote anything smaller
// prior to legalisation so the result can be split if required.
if ((IdxVT.getVectorElementType() == MVT::i8) ||
(IdxVT.getVectorElementType() == MVT::i16)) {
EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32);
if (MGS->isIndexSigned())
Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
else
Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index);
if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
SDValue PassThru = MGT->getPassThru();
SDValue Ops[] = { Chain, PassThru, Mask, BasePtr, Index, Scale };
return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
PassThru.getValueType(), DL, Ops,
MGT->getMemOperand(),
MGT->getIndexType(), MGT->getExtensionType());
} else {
auto *MSC = cast<MaskedScatterSDNode>(MGS);
SDValue Data = MSC->getValue();
SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
MSC->getMemoryVT(), DL, Ops,
MSC->getMemOperand(), IndexType,
MSC->isTruncatingStore());
}
}
}
return SDValue();
}
/// Target-specific DAG combine function for NEON load/store intrinsics
/// to merge base address updates.
static SDValue performNEONPostLDSTCombine(SDNode *N,
@@ -15638,9 +15595,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
break;
case ISD::STORE:
return performSTORECombine(N, DCI, DAG, Subtarget);
case ISD::MGATHER:
case ISD::MSCATTER:
return performMaskedGatherScatterCombine(N, DCI, DAG);
case AArch64ISD::BRCOND:
return performBRCONDCombine(N, DCI, DAG);
case AArch64ISD::TBNZ:


@@ -996,6 +996,7 @@ private:
return TargetLowering::getInlineAsmMemConstraint(ConstraintCode);
}
bool shouldExtendGSIndex(EVT VT, EVT &EltTy) const override;
bool shouldRemoveExtendFromGSIndex(EVT VT) const override;
bool isVectorLoadExtDesirable(SDValue ExtVal) const override;
bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override;