
[CodeGen][SVE] Legalisation of extends with scalable types

Summary:
This patch adds legalisation of extensions where the operand
of the extend is a legal scalable type but the result is not.

EXTRACT_SUBVECTOR is used to split the result; the resulting
extract nodes are then replaced by target-specific [S|U]UNPK[HI|LO] operations.

For example:

```
zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
```
should emit:

```
uunpklo z2.h, z0.b
uunpkhi z1.h, z0.b
```

Reviewers: sdesmalen, efriedma, david-arm

Reviewed By: efriedma

Subscribers: tschuett, hiraditya, rkruppe, psnobl, huihuiz, cfe-commits, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D79587
Kerry McLaughlin 2020-06-05 11:11:50 +01:00
parent 2a463f17d8
commit 5e3af5dc50
6 changed files with 299 additions and 1 deletion


@@ -103,6 +103,17 @@ namespace llvm {
return VecTy;
}
/// Return a VT for a vector type whose attributes match ourselves
/// with the exception of the element type that is chosen by the caller.
EVT changeVectorElementType(EVT EltVT) const {
if (!isSimple())
return changeExtendedVectorElementType(EltVT);
MVT VecTy = MVT::getVectorVT(EltVT.V, getVectorElementCount());
assert(VecTy.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE &&
"Simple vector VT not representable by simple integer vector VT!");
return VecTy;
}
/// Return the type converted to an equivalently sized integer or vector
/// with integer element type. Similar to changeVectorElementTypeToInteger,
/// but also handles scalars.
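As an aside, a minimal usage sketch of the new helper (illustrative only, not part of the patch):

```
// Minimal usage sketch (assumed helper function, not part of the patch):
// changeVectorElementType keeps the element count, including scalability,
// and only swaps the element type; non-simple VTs fall back to the new
// changeExtendedVectorElementType declared below.
#include "llvm/CodeGen/ValueTypes.h"

static llvm::EVT widenEltToI16(llvm::EVT VT) {
  // e.g. nxv16i8 -> nxv16i16
  return VT.changeVectorElementType(llvm::MVT::i16);
}
```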
@@ -432,6 +443,7 @@ namespace llvm {
// These are all out-of-line to prevent users of this header file
// from having a dependency on Type.h.
EVT changeExtendedTypeToInteger() const;
EVT changeExtendedVectorElementType(EVT EltVT) const;
EVT changeExtendedVectorElementTypeToInteger() const;
static EVT getExtendedIntegerVT(LLVMContext &C, unsigned BitWidth);
static EVT getExtendedVectorVT(LLVMContext &C, EVT VT, unsigned NumElements,


@@ -4324,6 +4324,31 @@ SDValue DAGTypeLegalizer::PromoteIntRes_EXTRACT_SUBVECTOR(SDNode *N) {
SDLoc dl(N);
SDValue BaseIdx = N->getOperand(1);
// TODO: We may be able to use this for types other than scalable
// vectors and fix those tests that expect BUILD_VECTOR to be used
if (OutVT.isScalableVector()) {
SDValue InOp0 = N->getOperand(0);
EVT InVT = InOp0.getValueType();
// Promote operands and see if this is handled by target lowering;
// otherwise, use the BUILD_VECTOR approach below
if (getTypeAction(InVT) == TargetLowering::TypePromoteInteger) {
// Collect the (promoted) operands
SDValue Ops[] = { GetPromotedInteger(InOp0), BaseIdx };
EVT PromEltVT = Ops[0].getValueType().getVectorElementType();
assert(PromEltVT.bitsLE(NOutVTElem) &&
"Promoted operand has an element type greater than result");
EVT ExtVT = NOutVT.changeVectorElementType(PromEltVT);
SDValue Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), ExtVT, Ops);
return DAG.getNode(ISD::ANY_EXTEND, dl, NOutVT, Ext);
}
}
if (OutVT.isScalableVector())
report_fatal_error("Unable to promote scalable types using BUILD_VECTOR");
SDValue InOp0 = N->getOperand(0);
if (getTypeAction(InOp0.getValueType()) == TargetLowering::TypePromoteInteger)
InOp0 = GetPromotedInteger(N->getOperand(0));
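To make the types concrete, a small sketch of the ExtVT computation above (the nxv4i32/i16 pairing is assumed for exposition, not taken from the patch):

```
#include "llvm/CodeGen/ValueTypes.h"
using namespace llvm;

// Illustrative only: with a promoted result type (NOutVT) of nxv4i32 and a
// promoted operand element type of i16, the intermediate extract is done at
// nxv4i16 and the result is then ANY_EXTENDed up to nxv4i32.
static EVT exampleExtVT() {
  EVT NOutVT = MVT::nxv4i32;
  return NOutVT.changeVectorElementType(MVT::i16); // nxv4i16
}
```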


@@ -26,6 +26,11 @@ EVT EVT::changeExtendedVectorElementTypeToInteger() const {
isScalableVector());
}
EVT EVT::changeExtendedVectorElementType(EVT EltVT) const {
LLVMContext &Context = LLVMTy->getContext();
return getVectorVT(Context, EltVT, getVectorElementCount());
}
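A sketch of when this out-of-line path is reached (the odd element width is an assumption chosen to force a non-simple VT):

```
#include "llvm/CodeGen/ValueTypes.h"
using namespace llvm;

// Illustrative only: a vector of i7 has no simple MVT, so
// changeVectorElementType dispatches to changeExtendedVectorElementType,
// which rebuilds the VT through the LLVMContext.
static EVT exampleExtendedPath(LLVMContext &Ctx) {
  EVT OddVT = EVT::getVectorVT(Ctx, EVT::getIntegerVT(Ctx, 7), 16,
                               /*IsScalable=*/true); // nxv16i7 (non-simple)
  return OddVT.changeVectorElementType(MVT::i16);    // nxv16i16
}
```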
EVT EVT::getExtendedIntegerVT(LLVMContext &Context, unsigned BitWidth) {
EVT VT;
VT.LLVMTy = IntegerType::get(Context, BitWidth);


@@ -901,6 +901,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SRA, VT, Custom);
if (VT.getScalarType() == MVT::i1)
setOperationAction(ISD::SETCC, VT, Custom);
} else {
for (auto VT : { MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32 })
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
}
}
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -8560,6 +8563,9 @@ AArch64TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
SDValue AArch64TargetLowering::LowerEXTRACT_SUBVECTOR(SDValue Op,
SelectionDAG &DAG) const {
assert(!Op.getValueType().isScalableVector() &&
"Unexpected scalable type for custom lowering EXTRACT_SUBVECTOR");
EVT VT = Op.getOperand(0).getValueType();
SDLoc dl(Op);
// Just in case...
@@ -10662,7 +10668,45 @@ static SDValue performSVEAndCombine(SDNode *N,
if (DCI.isBeforeLegalizeOps())
return SDValue();
SelectionDAG &DAG = DCI.DAG;
SDValue Src = N->getOperand(0);
unsigned Opc = Src->getOpcode();
// Zero/any extend of an unsigned unpack
if (Opc == AArch64ISD::UUNPKHI || Opc == AArch64ISD::UUNPKLO) {
SDValue UnpkOp = Src->getOperand(0);
SDValue Dup = N->getOperand(1);
if (Dup.getOpcode() != AArch64ISD::DUP)
return SDValue();
SDLoc DL(N);
ConstantSDNode *C = dyn_cast<ConstantSDNode>(Dup->getOperand(0));
uint64_t ExtVal = C->getZExtValue();
// If the mask is fully covered by the unpack, we don't need to push
// a new AND onto the operand
EVT EltTy = UnpkOp->getValueType(0).getVectorElementType();
if ((ExtVal == 0xFF && EltTy == MVT::i8) ||
(ExtVal == 0xFFFF && EltTy == MVT::i16) ||
(ExtVal == 0xFFFFFFFF && EltTy == MVT::i32))
return Src;
// Truncate to prevent a DUP with an over-wide constant
APInt Mask = C->getAPIntValue().trunc(EltTy.getSizeInBits());
// Otherwise, make sure we propagate the AND to the operand
// of the unpack
Dup = DAG.getNode(AArch64ISD::DUP, DL,
UnpkOp->getValueType(0),
DAG.getConstant(Mask.zextOrTrunc(32), DL, MVT::i32));
SDValue And = DAG.getNode(ISD::AND, DL,
UnpkOp->getValueType(0), UnpkOp, Dup);
return DAG.getNode(Opc, DL, N->getValueType(0), And);
}
SDValue Mask = N->getOperand(1);
if (!Src.hasOneUse())
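For illustration, the effect of the mask narrowing in the unpack combine above, with an assumed splat constant (not taken from the patch):

```
#include "llvm/ADT/APInt.h"
using namespace llvm;

// Illustrative only: a 0x3FF mask applied to the unpacked i16 lanes is
// truncated to the i8 element width (0xFF) before being pushed onto the
// unpack's operand; the unpack's implicit zero-extension already clears the
// remaining high bits.
static APInt narrowedUnpackMask() {
  APInt Mask(/*numBits=*/16, /*val=*/0x3FF);
  return Mask.trunc(8); // 0xFF
}
```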
@@ -10672,7 +10716,7 @@ static SDValue performSVEAndCombine(SDNode *N,
// SVE load instructions perform an implicit zero-extend, which makes them
// perfect candidates for combining.
switch (Src->getOpcode()) {
switch (Opc) {
case AArch64ISD::LD1:
case AArch64ISD::LDNF1:
case AArch64ISD::LDFF1:
@@ -13256,9 +13300,41 @@ performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
if (DCI.isBeforeLegalizeOps())
return SDValue();
SDLoc DL(N);
SDValue Src = N->getOperand(0);
unsigned Opc = Src->getOpcode();
// Sign extend of an unsigned unpack -> signed unpack
if (Opc == AArch64ISD::UUNPKHI || Opc == AArch64ISD::UUNPKLO) {
unsigned SOpc = Opc == AArch64ISD::UUNPKHI ? AArch64ISD::SUNPKHI
: AArch64ISD::SUNPKLO;
// Push the sign extend to the operand of the unpack
// This is necessary where, for example, the operand of the unpack
// is another unpack:
// 4i32 sign_extend_inreg (4i32 uunpklo(8i16 uunpklo (16i8 opnd)), from 4i8)
// ->
// 4i32 sunpklo (8i16 sign_extend_inreg(8i16 uunpklo (16i8 opnd), from 8i8))
// ->
// 4i32 sunpklo(8i16 sunpklo(16i8 opnd))
SDValue ExtOp = Src->getOperand(0);
auto VT = cast<VTSDNode>(N->getOperand(1))->getVT();
EVT EltTy = VT.getVectorElementType();
assert((EltTy == MVT::i8 || EltTy == MVT::i16 || EltTy == MVT::i32) &&
"Sign extending from an invalid type");
EVT ExtVT = EVT::getVectorVT(*DAG.getContext(),
VT.getVectorElementType(),
VT.getVectorElementCount() * 2);
SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, ExtOp.getValueType(),
ExtOp, DAG.getValueType(ExtVT));
return DAG.getNode(SOpc, DL, N->getValueType(0), Ext);
}
// SVE load nodes (e.g. AArch64ISD::GLD1) are straightforward candidates
// for DAG Combine with SIGN_EXTEND_INREG. Bail out for all other nodes.
unsigned NewOpc;
@@ -13747,6 +13823,40 @@ static std::pair<SDValue, SDValue> splitInt128(SDValue N, SelectionDAG &DAG) {
return std::make_pair(Lo, Hi);
}
void AArch64TargetLowering::ReplaceExtractSubVectorResults(
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
SDValue In = N->getOperand(0);
EVT InVT = In.getValueType();
// Common code will handle these just fine.
if (!InVT.isScalableVector() || !InVT.isInteger())
return;
SDLoc DL(N);
EVT VT = N->getValueType(0);
// The following checks bail if this is not a halving operation.
ElementCount ResEC = VT.getVectorElementCount();
if (InVT.getVectorElementCount().Min != (ResEC.Min * 2))
return;
auto *CIndex = dyn_cast<ConstantSDNode>(N->getOperand(1));
if (!CIndex)
return;
unsigned Index = CIndex->getZExtValue();
if ((Index != 0) && (Index != ResEC.Min))
return;
unsigned Opcode = (Index == 0) ? AArch64ISD::UUNPKLO : AArch64ISD::UUNPKHI;
EVT ExtendedHalfVT = VT.widenIntegerVectorElementType(*DAG.getContext());
SDValue Half = DAG.getNode(Opcode, DL, ExtendedHalfVT, N->getOperand(0));
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Half));
}
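As a concrete (assumed) example of the above: requesting the low nxv8i8 half of an nxv16i8 input widens to nxv8i16 for the unpack, and the trailing TRUNCATE is expected to fold with the surrounding extend during later combines, leaving the unpack-only sequences shown in the tests below.

```
#include "llvm/CodeGen/ValueTypes.h"
using namespace llvm;

// Illustrative only: the element-widened half type fed to UUNPKLO/UUNPKHI
// when the requested (illegal) half is nxv8i8.
static EVT exampleUnpackResultVT(LLVMContext &Ctx) {
  EVT HalfVT = MVT::nxv8i8;
  return HalfVT.widenIntegerVectorElementType(Ctx); // nxv8i16
}
```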
// Create an even/odd pair of X registers holding integer value V.
static SDValue createGPRPairNode(SelectionDAG &DAG, SDValue V) {
SDLoc dl(V.getNode());
@@ -13899,6 +14009,9 @@ void AArch64TargetLowering::ReplaceNodeResults(
Results.append({Pair, Result.getValue(2) /* Chain */});
return;
}
case ISD::EXTRACT_SUBVECTOR:
ReplaceExtractSubVectorResults(N, Results, DAG);
return;
case ISD::INTRINSIC_WO_CHAIN: {
EVT VT = N->getValueType(0);
assert((VT == MVT::i8 || VT == MVT::i16) &&


@@ -889,6 +889,9 @@ private:
void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
SelectionDAG &DAG) const override;
void ReplaceExtractSubVectorResults(SDNode *N,
SmallVectorImpl<SDValue> &Results,
SelectionDAG &DAG) const;
bool shouldNormalizeToSelectSequence(LLVMContext &, EVT) const override;


@@ -186,3 +186,143 @@ define <vscale x 2 x i64> @zext_i32_i64(<vscale x 2 x i32> %a) {
%r = zext <vscale x 2 x i32> %a to <vscale x 2 x i64>
ret <vscale x 2 x i64> %r
}
; Extending to illegal types
define <vscale x 16 x i16> @sext_b_to_h(<vscale x 16 x i8> %a) {
; CHECK-LABEL: sext_b_to_h:
; CHECK: // %bb.0:
; CHECK-NEXT: sunpklo z2.h, z0.b
; CHECK-NEXT: sunpkhi z1.h, z0.b
; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: ret
%ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i16>
ret <vscale x 16 x i16> %ext
}
define <vscale x 8 x i32> @sext_h_to_s(<vscale x 8 x i16> %a) {
; CHECK-LABEL: sext_h_to_s:
; CHECK: // %bb.0:
; CHECK-NEXT: sunpklo z2.s, z0.h
; CHECK-NEXT: sunpkhi z1.s, z0.h
; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: ret
%ext = sext <vscale x 8 x i16> %a to <vscale x 8 x i32>
ret <vscale x 8 x i32> %ext
}
define <vscale x 4 x i64> @sext_s_to_d(<vscale x 4 x i32> %a) {
; CHECK-LABEL: sext_s_to_d:
; CHECK: // %bb.0:
; CHECK-NEXT: sunpklo z2.d, z0.s
; CHECK-NEXT: sunpkhi z1.d, z0.s
; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: ret
%ext = sext <vscale x 4 x i32> %a to <vscale x 4 x i64>
ret <vscale x 4 x i64> %ext
}
define <vscale x 16 x i32> @sext_b_to_s(<vscale x 16 x i8> %a) {
; CHECK-LABEL: sext_b_to_s:
; CHECK: // %bb.0:
; CHECK-NEXT: sunpklo z1.h, z0.b
; CHECK-NEXT: sunpkhi z3.h, z0.b
; CHECK-NEXT: sunpklo z0.s, z1.h
; CHECK-NEXT: sunpkhi z1.s, z1.h
; CHECK-NEXT: sunpklo z2.s, z3.h
; CHECK-NEXT: sunpkhi z3.s, z3.h
; CHECK-NEXT: ret
%ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
ret <vscale x 16 x i32> %ext
}
define <vscale x 16 x i64> @sext_b_to_d(<vscale x 16 x i8> %a) {
; CHECK-LABEL: sext_b_to_d:
; CHECK: // %bb.0:
; CHECK-NEXT: sunpklo z1.h, z0.b
; CHECK-NEXT: sunpkhi z0.h, z0.b
; CHECK-NEXT: sunpklo z2.s, z1.h
; CHECK-NEXT: sunpkhi z3.s, z1.h
; CHECK-NEXT: sunpklo z5.s, z0.h
; CHECK-NEXT: sunpkhi z7.s, z0.h
; CHECK-NEXT: sunpklo z0.d, z2.s
; CHECK-NEXT: sunpkhi z1.d, z2.s
; CHECK-NEXT: sunpklo z2.d, z3.s
; CHECK-NEXT: sunpkhi z3.d, z3.s
; CHECK-NEXT: sunpklo z4.d, z5.s
; CHECK-NEXT: sunpkhi z5.d, z5.s
; CHECK-NEXT: sunpklo z6.d, z7.s
; CHECK-NEXT: sunpkhi z7.d, z7.s
; CHECK-NEXT: ret
%ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
ret <vscale x 16 x i64> %ext
}
define <vscale x 16 x i16> @zext_b_to_h(<vscale x 16 x i8> %a) {
; CHECK-LABEL: zext_b_to_h:
; CHECK: // %bb.0:
; CHECK-NEXT: uunpklo z2.h, z0.b
; CHECK-NEXT: uunpkhi z1.h, z0.b
; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: ret
%ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
ret <vscale x 16 x i16> %ext
}
define <vscale x 8 x i32> @zext_h_to_s(<vscale x 8 x i16> %a) {
; CHECK-LABEL: zext_h_to_s:
; CHECK: // %bb.0:
; CHECK-NEXT: uunpklo z2.s, z0.h
; CHECK-NEXT: uunpkhi z1.s, z0.h
; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: ret
%ext = zext <vscale x 8 x i16> %a to <vscale x 8 x i32>
ret <vscale x 8 x i32> %ext
}
define <vscale x 4 x i64> @zext_s_to_d(<vscale x 4 x i32> %a) {
; CHECK-LABEL: zext_s_to_d:
; CHECK: // %bb.0:
; CHECK-NEXT: uunpklo z2.d, z0.s
; CHECK-NEXT: uunpkhi z1.d, z0.s
; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: ret
%ext = zext <vscale x 4 x i32> %a to <vscale x 4 x i64>
ret <vscale x 4 x i64> %ext
}
define <vscale x 16 x i32> @zext_b_to_s(<vscale x 16 x i8> %a) {
; CHECK-LABEL: zext_b_to_s:
; CHECK: // %bb.0:
; CHECK-NEXT: uunpklo z1.h, z0.b
; CHECK-NEXT: uunpkhi z3.h, z0.b
; CHECK-NEXT: uunpklo z0.s, z1.h
; CHECK-NEXT: uunpkhi z1.s, z1.h
; CHECK-NEXT: uunpklo z2.s, z3.h
; CHECK-NEXT: uunpkhi z3.s, z3.h
; CHECK-NEXT: ret
%ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
ret <vscale x 16 x i32> %ext
}
define <vscale x 16 x i64> @zext_b_to_d(<vscale x 16 x i8> %a) {
; CHECK-LABEL: zext_b_to_d:
; CHECK: // %bb.0:
; CHECK-NEXT: uunpklo z1.h, z0.b
; CHECK-NEXT: uunpkhi z0.h, z0.b
; CHECK-NEXT: uunpklo z2.s, z1.h
; CHECK-NEXT: uunpkhi z3.s, z1.h
; CHECK-NEXT: uunpklo z5.s, z0.h
; CHECK-NEXT: uunpkhi z7.s, z0.h
; CHECK-NEXT: uunpklo z0.d, z2.s
; CHECK-NEXT: uunpkhi z1.d, z2.s
; CHECK-NEXT: uunpklo z2.d, z3.s
; CHECK-NEXT: uunpkhi z3.d, z3.s
; CHECK-NEXT: uunpklo z4.d, z5.s
; CHECK-NEXT: uunpkhi z5.d, z5.s
; CHECK-NEXT: uunpklo z6.d, z7.s
; CHECK-NEXT: uunpkhi z7.d, z7.s
; CHECK-NEXT: ret
%ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
ret <vscale x 16 x i64> %ext
}