
[RISCV] Add support for VECTOR_REVERSE for scalable vector types.

I've left mask registers to a future patch as we'll need
to convert them to full vectors, shuffle, and then truncate.
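(Illustration only, not part of this patch: the "convert them to full vectors, shuffle, and then truncate" plan for mask registers can be modeled element-wise in plain C++; the helper name below is invented.)

#include <cstdint>
#include <vector>

// Editor's sketch only: widen each i1 mask element to a byte ("full vector"),
// reverse it with the same (VLMAX-1)-i index pattern a vrgather would use
// ("shuffle"), then narrow back to i1 ("truncate").
std::vector<bool> reverseMaskModel(const std::vector<bool> &Mask) {
  std::vector<uint8_t> Wide(Mask.begin(), Mask.end());
  std::vector<uint8_t> Rev(Wide.size());
  for (size_t I = 0, E = Wide.size(); I != E; ++I)
    Rev[I] = Wide[E - 1 - I];
  return std::vector<bool>(Rev.begin(), Rev.end());
}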

Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D97609
Craig Topper 2021-03-09 09:43:08 -08:00
parent f9107e0902
commit 48f016a681
5 changed files with 1279 additions and 3 deletions

View File

@@ -472,6 +472,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::VECTOR_REVERSE, VT, Custom);
}
// Expand various CCs to best match the RVV ISA, which natively supports UNE
@@ -509,6 +511,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::VECTOR_REVERSE, VT, Custom);
};
if (Subtarget.hasStdExtZfh())
@@ -1528,6 +1532,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerINSERT_SUBVECTOR(Op, DAG);
case ISD::EXTRACT_SUBVECTOR:
return lowerEXTRACT_SUBVECTOR(Op, DAG);
case ISD::VECTOR_REVERSE:
return lowerVECTOR_REVERSE(Op, DAG);
case ISD::BUILD_VECTOR:
return lowerBUILD_VECTOR(Op, DAG, Subtarget);
case ISD::VECTOR_SHUFFLE:
@@ -2793,6 +2799,84 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
return DAG.getBitcast(Op.getSimpleValueType(), Slidedown);
}
// Implement vector_reverse using vrgather.vv with indices determined by
// subtracting the id of each element from (VLMAX-1). This will convert
// the indices like so:
// (0, 1,..., VLMAX-2, VLMAX-1) -> (VLMAX-1, VLMAX-2,..., 1, 0).
// TODO: This code assumes VLMAX <= 65536 for LMUL=8 SEW=16.
SDValue RISCVTargetLowering::lowerVECTOR_REVERSE(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
MVT VecVT = Op.getSimpleValueType();
unsigned EltSize = VecVT.getScalarSizeInBits();
unsigned MinSize = VecVT.getSizeInBits().getKnownMinValue();
unsigned MaxVLMAX = 0;
unsigned VectorBitsMax = Subtarget.getMaxRVVVectorSizeInBits();
if (VectorBitsMax != 0)
MaxVLMAX = ((VectorBitsMax / EltSize) * MinSize) / RISCV::RVVBitsPerBlock;
unsigned GatherOpc = RISCVISD::VRGATHER_VV_VL;
MVT IntVT = VecVT.changeVectorElementTypeToInteger();
// If this is SEW=8 and VLMAX is unknown or more than 256, we need
// to use vrgatherei16.vv.
// TODO: It's also possible to use vrgatherei16.vv for other types to
// decrease register width for the index calculation.
if ((MaxVLMAX == 0 || MaxVLMAX > 256) && EltSize == 8) {
// If this is LMUL=8, we have to split before we can use vrgatherei16.vv.
// Reverse each half, then reassemble them in reverse order.
// NOTE: It's also possible that, after splitting, VLMAX no longer
// requires vrgatherei16.vv.
if (MinSize == (8 * RISCV::RVVBitsPerBlock)) {
SDValue Lo, Hi;
std::tie(Lo, Hi) = DAG.SplitVectorOperand(Op.getNode(), 0);
EVT LoVT, HiVT;
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VecVT);
Lo = DAG.getNode(ISD::VECTOR_REVERSE, DL, LoVT, Lo);
Hi = DAG.getNode(ISD::VECTOR_REVERSE, DL, HiVT, Hi);
// Reassemble the low and high pieces reversed.
// FIXME: This is a CONCAT_VECTORS.
SDValue Res =
DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VecVT, DAG.getUNDEF(VecVT), Hi,
DAG.getIntPtrConstant(0, DL));
return DAG.getNode(
ISD::INSERT_SUBVECTOR, DL, VecVT, Res, Lo,
DAG.getIntPtrConstant(LoVT.getVectorMinNumElements(), DL));
}
// Just promote the int type to i16 which will double the LMUL.
IntVT = MVT::getVectorVT(MVT::i16, VecVT.getVectorElementCount());
GatherOpc = RISCVISD::VRGATHEREI16_VV_VL;
}
MVT XLenVT = Subtarget.getXLenVT();
SDValue Mask, VL;
std::tie(Mask, VL) = getDefaultScalableVLOps(VecVT, DL, DAG, Subtarget);
// Calculate VLMAX-1 for the desired SEW.
unsigned MinElts = VecVT.getVectorMinNumElements();
SDValue VLMax = DAG.getNode(ISD::VSCALE, DL, XLenVT,
DAG.getConstant(MinElts, DL, XLenVT));
SDValue VLMinus1 =
DAG.getNode(ISD::SUB, DL, XLenVT, VLMax, DAG.getConstant(1, DL, XLenVT));
// Splat VLMAX-1 taking care to handle SEW==64 on RV32.
bool IsRV32E64 =
!Subtarget.is64Bit() && IntVT.getVectorElementType() == MVT::i64;
SDValue SplatVL;
if (!IsRV32E64)
SplatVL = DAG.getSplatVector(IntVT, DL, VLMinus1);
else
SplatVL = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, IntVT, VLMinus1);
SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, IntVT, Mask, VL);
SDValue Indices =
DAG.getNode(RISCVISD::SUB_VL, DL, IntVT, SplatVL, VID, Mask, VL);
return DAG.getNode(GatherOpc, DL, VecVT, Op.getOperand(0), Indices, Mask, VL);
}
SDValue
RISCVTargetLowering::lowerFixedLengthVectorLoadToRVV(SDValue Op,
SelectionDAG &DAG) const {
@@ -5907,6 +5991,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VMCLR_VL)
NODE_NAME_CASE(VMSET_VL)
NODE_NAME_CASE(VRGATHER_VX_VL)
NODE_NAME_CASE(VRGATHER_VV_VL)
NODE_NAME_CASE(VRGATHEREI16_VV_VL)
NODE_NAME_CASE(VSEXT_VL)
NODE_NAME_CASE(VZEXT_VL)
NODE_NAME_CASE(VLE_VL)
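
To make the index math in the lowerVECTOR_REVERSE hunk above concrete, here is an editor's stand-alone model (illustration only, not code from this patch): vid.v produces 0..VL-1, the gather indices are (VLMAX-1) minus that id, and with SEW=8 an index can exceed 255 whenever VLMAX is unknown or greater than 256, which is why vrgatherei16.vv is used in that case.

#include <cstdint>
#include <vector>

// Gather indices the reverse lowering feeds to vrgather: element I of the
// result reads element (VLMAX-1) - I of the source.
std::vector<uint16_t> reverseIndices(unsigned VLMax) {
  std::vector<uint16_t> Indices(VLMax);
  for (unsigned Id = 0; Id < VLMax; ++Id)
    Indices[Id] = static_cast<uint16_t>(VLMax - 1 - Id);
  return Indices;
}

// Mirrors the SEW==8 check in the patch: with no known upper bound, or with
// VLMAX > 256, an index may not fit in 8 bits, so 16-bit indices are needed.
bool sew8NeedsEI16(unsigned MaxVLMax) {
  return MaxVLMax == 0 || MaxVLMax > 256;
}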

View File

@@ -203,8 +203,11 @@ enum NodeType : unsigned {
VMCLR_VL,
VMSET_VL,
// Matches the semantics of vrgather.vx with an extra operand for VL.
// Matches the semantics of vrgather.vx and vrgather.vv with an extra operand
// for VL.
VRGATHER_VX_VL,
VRGATHER_VV_VL,
VRGATHEREI16_VV_VL,
// Vector sign/zero extend with additional mask & VL operands.
VSEXT_VL,
@@ -446,6 +449,7 @@ private:
SDValue lowerFPVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerINSERT_SUBVECTOR(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerEXTRACT_SUBVECTOR(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVECTOR_REVERSE(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerABS(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const;

View File

@@ -142,6 +142,24 @@ def riscv_vrgather_vx_vl : SDNode<"RISCVISD::VRGATHER_VX_VL",
SDTCVecEltisVT<3, i1>,
SDTCisSameNumEltsAs<0, 3>,
SDTCisVT<4, XLenVT>]>>;
def riscv_vrgather_vv_vl : SDNode<"RISCVISD::VRGATHER_VV_VL",
SDTypeProfile<1, 4, [SDTCisVec<0>,
SDTCisSameAs<0, 1>,
SDTCisInt<2>,
SDTCisSameNumEltsAs<0, 2>,
SDTCisSameSizeAs<0, 2>,
SDTCVecEltisVT<3, i1>,
SDTCisSameNumEltsAs<0, 3>,
SDTCisVT<4, XLenVT>]>>;
def riscv_vrgatherei16_vv_vl : SDNode<"RISCVISD::VRGATHEREI16_VV_VL",
SDTypeProfile<1, 4, [SDTCisVec<0>,
SDTCisSameAs<0, 1>,
SDTCisInt<2>,
SDTCVecEltisVT<2, i16>,
SDTCisSameNumEltsAs<0, 2>,
SDTCVecEltisVT<3, i1>,
SDTCisSameNumEltsAs<0, 3>,
SDTCisVT<4, XLenVT>]>>;
def riscv_vselect_vl : SDNode<"RISCVISD::VSELECT_VL",
SDTypeProfile<1, 4, [SDTCisVec<0>,
@@ -995,6 +1013,12 @@ foreach vti = AllIntegerVectors in {
(!cast<Instruction>("PseudoVMV_S_X_"#vti.LMul.MX)
vti.RegClass:$merge,
(vti.Scalar vti.ScalarRegClass:$rs1), GPR:$vl, vti.SEW)>;
def : Pat<(vti.Vector (riscv_vrgather_vv_vl vti.RegClass:$rs2,
(vti.Vector vti.RegClass:$rs1),
(vti.Mask true_mask),
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVRGATHER_VV_"# vti.LMul.MX)
vti.RegClass:$rs2, vti.RegClass:$rs1, GPR:$vl, vti.SEW)>;
def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$rs2, GPR:$rs1,
(vti.Mask true_mask),
(XLenVT (VLOp GPR:$vl)))),
@@ -1005,6 +1029,22 @@ foreach vti = AllIntegerVectors in {
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVRGATHER_VI_"# vti.LMul.MX)
vti.RegClass:$rs2, uimm5:$imm, GPR:$vl, vti.SEW)>;
// emul = lmul * 16 / sew
defvar vlmul = vti.LMul;
defvar octuple_lmul = octuple_from_str<vlmul.MX>.ret;
defvar octuple_emul = !srl(!mul(octuple_lmul, 16), shift_amount<vti.SEW>.val);
if !and(!ge(octuple_emul, 1), !le(octuple_emul, 64)) then {
defvar emul_str = octuple_to_str<octuple_emul>.ret;
defvar ivti = !cast<VTypeInfo>("VI16" # emul_str);
defvar inst = "PseudoVRGATHEREI16_VV_" # vti.LMul.MX # "_" # emul_str;
def : Pat<(vti.Vector (riscv_vrgatherei16_vv_vl vti.RegClass:$rs2,
(ivti.Vector ivti.RegClass:$rs1),
(vti.Mask true_mask),
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>(inst)
vti.RegClass:$rs2, ivti.RegClass:$rs1, GPR:$vl, vti.SEW)>;
}
}
} // Predicates = [HasStdExtV]
@@ -1019,6 +1059,13 @@ foreach vti = AllFloatVectors in {
(!cast<Instruction>("PseudoVFMV_S_"#vti.ScalarSuffix#"_"#vti.LMul.MX)
vti.RegClass:$merge,
(vti.Scalar vti.ScalarRegClass:$rs1), GPR:$vl, vti.SEW)>;
defvar ivti = GetIntVTypeInfo<vti>.Vti;
def : Pat<(vti.Vector (riscv_vrgather_vv_vl vti.RegClass:$rs2,
(ivti.Vector vti.RegClass:$rs1),
(vti.Mask true_mask),
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVRGATHER_VV_"# vti.LMul.MX)
vti.RegClass:$rs2, vti.RegClass:$rs1, GPR:$vl, vti.SEW)>;
def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$rs2, GPR:$rs1,
(vti.Mask true_mask),
(XLenVT (VLOp GPR:$vl)))),
@@ -1029,6 +1076,21 @@ foreach vti = AllFloatVectors in {
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVRGATHER_VI_"# vti.LMul.MX)
vti.RegClass:$rs2, uimm5:$imm, GPR:$vl, vti.SEW)>;
defvar vlmul = vti.LMul;
defvar octuple_lmul = octuple_from_str<vlmul.MX>.ret;
defvar octuple_emul = !srl(!mul(octuple_lmul, 16), shift_amount<vti.SEW>.val);
if !and(!ge(octuple_emul, 1), !le(octuple_emul, 64)) then {
defvar emul_str = octuple_to_str<octuple_emul>.ret;
defvar ivti = !cast<VTypeInfo>("VI16" # emul_str);
defvar inst = "PseudoVRGATHEREI16_VV_" # vti.LMul.MX # "_" # emul_str;
def : Pat<(vti.Vector (riscv_vrgatherei16_vv_vl vti.RegClass:$rs2,
(ivti.Vector ivti.RegClass:$rs1),
(vti.Mask true_mask),
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>(inst)
vti.RegClass:$rs2, ivti.RegClass:$rs1, GPR:$vl, vti.SEW)>;
}
}
} // Predicates = [HasStdExtV, HasStdExtF]
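
(Editor's note, illustration only.) The "emul = lmul * 16 / sew" computation in the vrgatherei16 patterns above is done in octuple units, i.e. eighths of an LMUL, so that fractional LMULs stay integral; the guard keeps EMUL within the legal 1/8..8 range. A minimal C++ model of that arithmetic, with an invented helper name:

#include <cassert>

// Returns EMUL scaled by 8 (octuple form), or 0 when the result is outside
// the legal 1/8..8 range (octuple 1..64), mirroring the !ge/!le guard above.
// SEW is assumed to be a power of two, as in RVV, so dividing by SEW matches
// the !srl-by-log2(SEW) in the TableGen code.
unsigned octupleEMulForEI16(unsigned OctupleLMul, unsigned SEW) {
  unsigned OctupleEMul = OctupleLMul * 16 / SEW;
  return (OctupleEMul >= 1 && OctupleEMul <= 64) ? OctupleEMul : 0;
}

int main() {
  // LMUL=1 (octuple 8), SEW=32: octuple EMUL 4, i.e. the i16 index vector
  // uses EMUL=1/2. LMUL=8 (octuple 64), SEW=8: 128 is out of range, so no
  // vrgatherei16 pattern exists and the lowering splits that case instead.
  assert(octupleEMulForEI16(8, 32) == 4);
  assert(octupleEMulForEI16(64, 8) == 0);
  return 0;
}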

File diff suppressed because it is too large.

View File

@@ -714,6 +714,15 @@ bool TypeInfer::EnforceSameNumElts(TypeSetByHwMode &V, TypeSetByHwMode &W) {
return Changed;
}
namespace {
struct TypeSizeComparator {
bool operator()(const TypeSize &LHS, const TypeSize &RHS) const {
return std::make_tuple(LHS.isScalable(), LHS.getKnownMinValue()) <
std::make_tuple(RHS.isScalable(), RHS.getKnownMinValue());
}
};
} // end anonymous namespace
/// 1. Ensure that for each type T in A, there exists a type U in B,
/// such that T and U have equal size in bits.
/// 2. Ensure that for each type U in B, there exists a type T in A
@@ -728,14 +737,16 @@ bool TypeInfer::EnforceSameSize(TypeSetByHwMode &A, TypeSetByHwMode &B) {
if (B.empty())
Changed |= EnforceAny(B);
auto NoSize = [](const SmallSet<TypeSize, 2> &Sizes, MVT T) -> bool {
typedef SmallSet<TypeSize, 2, TypeSizeComparator> TypeSizeSet;
auto NoSize = [](const TypeSizeSet &Sizes, MVT T) -> bool {
return !Sizes.count(T.getSizeInBits());
};
for (unsigned M : union_modes(A, B)) {
TypeSetByHwMode::SetType &AS = A.get(M);
TypeSetByHwMode::SetType &BS = B.get(M);
SmallSet<TypeSize, 2> AN, BN;
TypeSizeSet AN, BN;
for (MVT T : AS)
AN.insert(T.getSizeInBits());
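
(Editor's note, illustration only.) The TypeSizeComparator added above gives SmallSet, which falls back to a std::set once it grows, a strict ordering over TypeSize without relying on TypeSize's own comparison operators: it orders the pair (isScalable, knownMinValue), so fixed and scalable sizes stay distinct keys. A minimal stand-alone model, with std::pair standing in for TypeSize:

#include <cstdint>
#include <set>
#include <utility>

int main() {
  // (isScalable, known-min size in bits): a fixed 64-bit size and a scalable
  // 64-bit size compare as different keys, so both stay in the set.
  std::set<std::pair<bool, uint64_t>> Sizes;
  Sizes.insert({false, 64}); // e.g. a v2i32-style fixed size
  Sizes.insert({true, 64});  // e.g. an nxv2i32-style scalable size
  return Sizes.size() == 2 ? 0 : 1;
}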