mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-11-23 03:02:36 +01:00
[SVE][CodeGen] Remove performMaskedGatherScatterCombine
The AArch64 DAG combine added by D90945 & D91433 extends the index of a scalable masked gather or scatter to i32 if necessary. This patch removes the combine and instead adds shouldExtendGSIndex, which is used by visitMaskedGather/Scatter in SelectionDAGBuilder to query whether the index should be extended before calling getMaskedGather/Scatter. Reviewed By: david-arm Differential Revision: https://reviews.llvm.org/D94525
This commit is contained in:
parent
5ff3260e34
commit
267255edd9
@@ -1318,6 +1318,10 @@ public:
|
||||
getIndexedMaskedStoreAction(IdxMode, VT.getSimpleVT()) == Custom);
|
||||
}
|
||||
|
||||
/// Returns true if the index type for a masked gather/scatter requires
|
||||
/// extending
|
||||
virtual bool shouldExtendGSIndex(EVT VT, EVT &EltTy) const { return false; }
|
||||
|
||||
// Returns true if VT is a legal index type for masked gathers/scatters
|
||||
// on this target
|
||||
virtual bool shouldRemoveExtendFromGSIndex(EVT VT) const { return false; }
|
||||
|
@@ -4339,6 +4339,14 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
|
||||
IndexType = ISD::SIGNED_UNSCALED;
|
||||
Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
|
||||
}
|
||||
|
||||
EVT IdxVT = Index.getValueType();
|
||||
EVT EltTy = IdxVT.getVectorElementType();
|
||||
if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
|
||||
EVT NewIdxVT = IdxVT.changeVectorElementType(EltTy);
|
||||
Index = DAG.getNode(ISD::SIGN_EXTEND, sdl, NewIdxVT, Index);
|
||||
}
|
||||
|
||||
SDValue Ops[] = { getMemoryRoot(), Src0, Mask, Base, Index, Scale };
|
||||
SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl,
|
||||
Ops, MMO, IndexType, false);
|
||||
@@ -4450,6 +4458,14 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
|
||||
IndexType = ISD::SIGNED_UNSCALED;
|
||||
Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
|
||||
}
|
||||
|
||||
EVT IdxVT = Index.getValueType();
|
||||
EVT EltTy = IdxVT.getVectorElementType();
|
||||
if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
|
||||
EVT NewIdxVT = IdxVT.changeVectorElementType(EltTy);
|
||||
Index = DAG.getNode(ISD::SIGN_EXTEND, sdl, NewIdxVT, Index);
|
||||
}
|
||||
|
||||
SDValue Ops[] = { Root, Src0, Mask, Base, Index, Scale };
|
||||
SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl,
|
||||
Ops, MMO, IndexType, ISD::NON_EXTLOAD);
|
||||
|
@@ -873,9 +873,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
|
||||
if (Subtarget->supportsAddressTopByteIgnored())
|
||||
setTargetDAGCombine(ISD::LOAD);
|
||||
|
||||
setTargetDAGCombine(ISD::MGATHER);
|
||||
setTargetDAGCombine(ISD::MSCATTER);
|
||||
|
||||
setTargetDAGCombine(ISD::MUL);
|
||||
|
||||
setTargetDAGCombine(ISD::SELECT);
|
||||
@@ -3825,6 +3822,15 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
|
||||
}
|
||||
}
|
||||
|
||||
/// SVE gathers/scatters require index elements of at least i32, so request
/// promotion of i8/i16 index vectors to i32; SelectionDAGBuilder performs
/// the sign-extension before building the node.
bool AArch64TargetLowering::shouldExtendGSIndex(EVT VT, EVT &EltTy) const {
  if (VT.getVectorElementType() == MVT::i8 ||
      VT.getVectorElementType() == MVT::i16) {
    EltTy = MVT::i32;
    return true;
  }
  return false;
}
|
||||
|
||||
bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(EVT VT) const {
|
||||
if (VT.getVectorElementType() == MVT::i32 &&
|
||||
VT.getVectorElementCount().getKnownMinValue() >= 4)
|
||||
@@ -14395,55 +14401,6 @@ static SDValue performSTORECombine(SDNode *N,
|
||||
return SDValue();
|
||||
}
|
||||
|
||||
/// DAG combine for MGATHER/MSCATTER nodes. SVE gather/scatter instructions
/// require i32/i64 index elements, so before legalisation this widens any
/// i8/i16 index vector (preserving its signedness) and rebuilds the node
/// with the promoted index; returns SDValue() when no change is made.
static SDValue performMaskedGatherScatterCombine(SDNode *N,
                                                 TargetLowering::DAGCombinerInfo &DCI,
                                                 SelectionDAG &DAG) {
  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
  assert(MGS && "Can only combine gather load or scatter store nodes");

  SDLoc DL(MGS);
  SDValue Chain = MGS->getChain();
  SDValue Scale = MGS->getScale();
  SDValue Index = MGS->getIndex();
  SDValue Mask = MGS->getMask();
  SDValue BasePtr = MGS->getBasePtr();
  ISD::MemIndexType IndexType = MGS->getIndexType();

  EVT IdxVT = Index.getValueType();

  if (DCI.isBeforeLegalize()) {
    // SVE gather/scatter requires indices of i32/i64. Promote anything smaller
    // prior to legalisation so the result can be split if required.
    if ((IdxVT.getVectorElementType() == MVT::i8) ||
        (IdxVT.getVectorElementType() == MVT::i16)) {
      EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32);
      // Widen with the extension matching the index's signedness.
      if (MGS->isIndexSigned())
        Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
      else
        Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index);

      if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
        // Rebuild the gather with the widened index, keeping all other
        // operands and memory attributes intact.
        SDValue PassThru = MGT->getPassThru();
        SDValue Ops[] = { Chain, PassThru, Mask, BasePtr, Index, Scale };
        return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
                                   PassThru.getValueType(), DL, Ops,
                                   MGT->getMemOperand(),
                                   MGT->getIndexType(), MGT->getExtensionType());
      } else {
        // Not a gather, so it must be a scatter; rebuild it likewise.
        auto *MSC = cast<MaskedScatterSDNode>(MGS);
        SDValue Data = MSC->getValue();
        SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
        return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
                                    MSC->getMemoryVT(), DL, Ops,
                                    MSC->getMemOperand(), IndexType,
                                    MSC->isTruncatingStore());
      }
    }
  }

  return SDValue();
}
|
||||
|
||||
/// Target-specific DAG combine function for NEON load/store intrinsics
|
||||
/// to merge base address updates.
|
||||
static SDValue performNEONPostLDSTCombine(SDNode *N,
|
||||
@@ -15638,9 +15595,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
|
||||
break;
|
||||
case ISD::STORE:
|
||||
return performSTORECombine(N, DCI, DAG, Subtarget);
|
||||
case ISD::MGATHER:
|
||||
case ISD::MSCATTER:
|
||||
return performMaskedGatherScatterCombine(N, DCI, DAG);
|
||||
case AArch64ISD::BRCOND:
|
||||
return performBRCONDCombine(N, DCI, DAG);
|
||||
case AArch64ISD::TBNZ:
|
||||
|
@@ -996,6 +996,7 @@ private:
|
||||
return TargetLowering::getInlineAsmMemConstraint(ConstraintCode);
|
||||
}
|
||||
|
||||
bool shouldExtendGSIndex(EVT VT, EVT &EltTy) const override;
|
||||
bool shouldRemoveExtendFromGSIndex(EVT VT) const override;
|
||||
bool isVectorLoadExtDesirable(SDValue ExtVal) const override;
|
||||
bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override;
|
||||
|
Loading…
Reference in New Issue
Block a user