From d809395fde3ac00d9c163c8ed265114e1c8a202e Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Tue, 27 Jul 2021 16:47:52 +0200 Subject: [PATCH] Revert "Revert "[X86][AVX] Add getBROADCAST_LOAD helper function. NFCI."" This reverts commit d7bbb1230a94cb239aa4a8cb896c45571444675d. There were follow up uses of a deleted method and I didn't run the tests. Undo the revert, so I can do it properly. --- lib/Target/X86/X86ISelLowering.cpp | 107 +++++++++++++---------------- 1 file changed, 48 insertions(+), 59 deletions(-) diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index 067b56e205e..344bf73b2c5 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -7988,6 +7988,30 @@ static bool getTargetShuffleInputs(SDValue Op, SmallVectorImpl<SDValue> &Inputs, KnownZero, DAG, Depth, ResolveKnownElts); } +// Attempt to create a scalar/subvector broadcast from the base MemSDNode. +static SDValue getBROADCAST_LOAD(unsigned Opcode, const SDLoc &DL, EVT VT, + EVT MemVT, MemSDNode *Mem, unsigned Offset, + SelectionDAG &DAG) { + assert((Opcode == X86ISD::VBROADCAST_LOAD || + Opcode == X86ISD::SUBV_BROADCAST_LOAD) && + "Unknown broadcast load type"); + + // Ensure this is a simple (non-atomic, non-volatile), temporal read memop. + if (!Mem || !Mem->readMem() || !Mem->isSimple() || Mem->isNonTemporal()) + return SDValue(); + + SDValue Ptr = + DAG.getMemBasePlusOffset(Mem->getBasePtr(), TypeSize::Fixed(Offset), DL); + SDVTList Tys = DAG.getVTList(VT, MVT::Other); + SDValue Ops[] = {Mem->getChain(), Ptr}; + SDValue BcstLd = DAG.getMemIntrinsicNode( + Opcode, DL, Tys, Ops, MemVT, + DAG.getMachineFunction().getMachineMemOperand( + Mem->getMemOperand(), Offset, MemVT.getStoreSize())); + DAG.makeEquivalentMemoryOrdering(SDValue(Mem, 1), BcstLd.getValue(1)); + return BcstLd; +} + /// Returns the scalar element that will make up the i'th /// element of the result of the vector shuffle. 
static SDValue getShuffleScalarElt(SDValue Op, unsigned Index, @@ -16060,21 +16084,12 @@ static SDValue lowerV2X128Shuffle(const SDLoc &DL, MVT VT, SDValue V1, bool SplatHi = isShuffleEquivalent(Mask, {2, 3, 2, 3}, V1); if ((SplatLo || SplatHi) && !Subtarget.hasAVX512() && V1.hasOneUse() && MayFoldLoad(peekThroughOneUseBitcasts(V1))) { + MVT MemVT = VT.getHalfNumVectorElementsVT(); + unsigned Ofs = SplatLo ? 0 : MemVT.getStoreSize(); auto *Ld = cast<LoadSDNode>(peekThroughOneUseBitcasts(V1)); - if (!Ld->isNonTemporal()) { - MVT MemVT = VT.getHalfNumVectorElementsVT(); - unsigned Ofs = SplatLo ? 0 : MemVT.getStoreSize(); - SDVTList Tys = DAG.getVTList(VT, MVT::Other); - SDValue Ptr = DAG.getMemBasePlusOffset(Ld->getBasePtr(), - TypeSize::Fixed(Ofs), DL); - SDValue Ops[] = {Ld->getChain(), Ptr}; - SDValue BcastLd = DAG.getMemIntrinsicNode( - X86ISD::SUBV_BROADCAST_LOAD, DL, Tys, Ops, MemVT, - DAG.getMachineFunction().getMachineMemOperand( - Ld->getMemOperand(), Ofs, MemVT.getStoreSize())); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), BcastLd.getValue(1)); - return BcastLd; - } + if (SDValue BcstLd = getBROADCAST_LOAD(X86ISD::SUBV_BROADCAST_LOAD, DL, + VT, MemVT, Ld, Ofs, DAG)) + return BcstLd; } // With AVX2, use VPERMQ/VPERMPD for unary shuffles to allow memory folding. @@ -38977,10 +38992,10 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode( } // Subvector broadcast. 
case X86ISD::SUBV_BROADCAST_LOAD: { + SDLoc DL(Op); auto *MemIntr = cast<MemIntrinsicSDNode>(Op); EVT MemVT = MemIntr->getMemoryVT(); if (ExtSizeInBits == MemVT.getStoreSizeInBits()) { - SDLoc DL(Op); SDValue Ld = TLO.DAG.getLoad(MemVT, DL, MemIntr->getChain(), MemIntr->getBasePtr(), MemIntr->getMemOperand()); @@ -38989,18 +39004,13 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode( return TLO.CombineTo(Op, insertSubVector(TLO.DAG.getUNDEF(VT), Ld, 0, TLO.DAG, DL, ExtSizeInBits)); } else if ((ExtSizeInBits % MemVT.getStoreSizeInBits()) == 0) { - SDLoc DL(Op); EVT BcstVT = EVT::getVectorVT(*TLO.DAG.getContext(), VT.getScalarType(), ExtSizeInBits / VT.getScalarSizeInBits()); - SDVTList Tys = TLO.DAG.getVTList(BcstVT, MVT::Other); - SDValue Ops[] = {MemIntr->getOperand(0), MemIntr->getOperand(1)}; - SDValue Bcst = - TLO.DAG.getMemIntrinsicNode(X86ISD::SUBV_BROADCAST_LOAD, DL, Tys, - Ops, MemVT, MemIntr->getMemOperand()); - TLO.DAG.makeEquivalentMemoryOrdering(SDValue(MemIntr, 1), - Bcst.getValue(1)); - return TLO.CombineTo(Op, insertSubVector(TLO.DAG.getUNDEF(VT), Bcst, 0, - TLO.DAG, DL, ExtSizeInBits)); + if (SDValue BcstLd = + getBROADCAST_LOAD(Opc, DL, BcstVT, MemVT, MemIntr, 0, TLO.DAG)) + return TLO.CombineTo(Op, + insertSubVector(TLO.DAG.getUNDEF(VT), BcstLd, 0, + TLO.DAG, DL, ExtSizeInBits)); } break; } @@ -50073,36 +50083,21 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT, if (Op0.getOpcode() == X86ISD::VBROADCAST) return DAG.getNode(Op0.getOpcode(), DL, VT, Op0.getOperand(0)); - // If this scalar/subvector broadcast_load is inserted into both halves, use - // a larger broadcast_load. Update other uses to use an extracted subvector. - if (Op0.getOpcode() == X86ISD::VBROADCAST_LOAD || + // If this simple subvector or scalar/subvector broadcast_load is inserted + // into both halves, use a larger broadcast_load. Update other uses to use + // an extracted subvector. 
+ if (Op0.getOpcode() == ISD::LOAD || + Op0.getOpcode() == X86ISD::VBROADCAST_LOAD || Op0.getOpcode() == X86ISD::SUBV_BROADCAST_LOAD) { - auto *MemIntr = cast<MemIntrinsicSDNode>(Op0); - SDVTList Tys = DAG.getVTList(VT, MVT::Other); - SDValue Ops[] = {MemIntr->getChain(), MemIntr->getBasePtr()}; - SDValue BcastLd = DAG.getMemIntrinsicNode(Op0.getOpcode(), DL, Tys, Ops, - MemIntr->getMemoryVT(), - MemIntr->getMemOperand()); - DAG.ReplaceAllUsesOfValueWith( - Op0, extractSubVector(BcastLd, 0, DAG, DL, Op0.getValueSizeInBits())); - DAG.ReplaceAllUsesOfValueWith(SDValue(MemIntr, 1), BcastLd.getValue(1)); - return BcastLd; - } - - // If this is a simple subvector load repeated across multiple lanes, then - // broadcast the load. Update other uses to use an extracted subvector. - if (auto *Ld = dyn_cast<LoadSDNode>(Op0)) { - if (Ld->isSimple() && !Ld->isNonTemporal() && - Ld->getExtensionType() == ISD::NON_EXTLOAD) { - SDVTList Tys = DAG.getVTList(VT, MVT::Other); - SDValue Ops[] = {Ld->getChain(), Ld->getBasePtr()}; - SDValue BcastLd = - DAG.getMemIntrinsicNode(X86ISD::SUBV_BROADCAST_LOAD, DL, Tys, Ops, - Ld->getMemoryVT(), Ld->getMemOperand()); + auto *Mem = cast<MemSDNode>(Op0); + unsigned Opcode = Op0.getOpcode() == X86ISD::VBROADCAST_LOAD + ? 
X86ISD::VBROADCAST_LOAD + : X86ISD::SUBV_BROADCAST_LOAD; + if (SDValue BcastLd = getBROADCAST_LOAD( + Opcode, DL, VT, Mem->getMemoryVT(), Mem, 0, DAG)) { DAG.ReplaceAllUsesOfValueWith( Op0, extractSubVector(BcastLd, 0, DAG, DL, Op0.getValueSizeInBits())); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), BcastLd.getValue(1)); return BcastLd; } } @@ -50466,14 +50461,8 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG, if (Vec.isUndef() && IdxVal != 0 && SubVec.hasOneUse() && SubVec.getOpcode() == X86ISD::VBROADCAST_LOAD) { auto *MemIntr = cast<MemIntrinsicSDNode>(SubVec); - SDVTList Tys = DAG.getVTList(OpVT, MVT::Other); - SDValue Ops[] = { MemIntr->getChain(), MemIntr->getBasePtr() }; - SDValue BcastLd = - DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, dl, Tys, Ops, - MemIntr->getMemoryVT(), - MemIntr->getMemOperand()); - DAG.ReplaceAllUsesOfValueWith(SDValue(MemIntr, 1), BcastLd.getValue(1)); - return BcastLd; + return getBROADCAST_LOAD(X86ISD::VBROADCAST_LOAD, dl, OpVT, + MemIntr->getMemoryVT(), MemIntr, 0, DAG); } // If we're splatting the lower half subvector of a full vector load into the