1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-22 18:54:02 +01:00

[CodeGen][SVE] Calculate correct type legalization for scalable vectors.

This patch updates TargetLoweringBase::computeRegisterProperties and
TargetLoweringBase::getTypeConversion to support scalable vectors,
and make the right calls on how to legalise them. These changes are required
to legalise both MVTs and EVTs.

Reviewers: efriedma, david-arm, ctetreau

Reviewed By: efriedma

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D80640
This commit is contained in:
Sander de Smalen 2020-06-05 12:54:39 +01:00
parent b602e0720e
commit b11af615da
8 changed files with 115 additions and 24 deletions

View File

@ -210,6 +210,13 @@ public:
TypeWidenVector, // This vector should be widened into a larger vector.
TypePromoteFloat, // Replace this float with a larger one.
TypeSoftPromoteHalf, // Soften half to i16 and use float to do arithmetic.
TypeScalarizeScalableVector, // This action is explicitly left unimplemented.
// While it is theoretically possible to
// legalize operations on scalable types with a
// loop that handles the vscale * #lanes of the
// vector, this is non-trivial at SelectionDAG
// level and these types are better to be
// widened or promoted.
};
/// LegalizeKind holds the legalization kind that needs to happen to EVT
@ -412,7 +419,7 @@ public:
virtual TargetLoweringBase::LegalizeTypeAction
getPreferredVectorAction(MVT VT) const {
// The default action for one element vectors is to scalarize
if (VT.getVectorNumElements() == 1)
if (VT.getVectorElementCount() == 1)
return TypeScalarizeVector;
// The default action for an odd-width vector is to widen.
if (!VT.isPow2VectorType())

View File

@ -15,6 +15,7 @@
#ifndef LLVM_SUPPORT_TYPESIZE_H
#define LLVM_SUPPORT_TYPESIZE_H
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/WithColor.h"
#include <cstdint>
@ -49,6 +50,12 @@ public:
bool operator!=(const ElementCount& RHS) const {
return !(*this == RHS);
}
bool operator==(unsigned RHS) const { return Min == RHS && !Scalable; }
bool operator!=(unsigned RHS) const { return !(*this == RHS); }
ElementCount NextPowerOf2() const {
return ElementCount(llvm::NextPowerOf2(Min), Scalable);
}
};
// This class is used to represent the size of types. If the type is of fixed

View File

@ -344,6 +344,8 @@ SDValue DAGTypeLegalizer::PromoteIntRes_BITCAST(SDNode *N) {
return DAG.getNode(ISD::ANY_EXTEND, dl, NOutVT,
BitConvertToInteger(GetScalarizedVector(InOp)));
break;
case TargetLowering::TypeScalarizeScalableVector:
report_fatal_error("Scalarization of scalable vectors is not supported.");
case TargetLowering::TypeSplitVector: {
if (!NOutVT.isVector()) {
// For example, i32 = BITCAST v2i16 on alpha. Convert the split

View File

@ -245,6 +245,9 @@ bool DAGTypeLegalizer::run() {
case TargetLowering::TypeLegal:
LLVM_DEBUG(dbgs() << "Legal result type\n");
break;
case TargetLowering::TypeScalarizeScalableVector:
report_fatal_error(
"Scalarization of scalable vectors is not supported.");
// The following calls must take care of *all* of the node's results,
// not just the illegal result they were passed (this includes results
// with a legal type). Results can be remapped using ReplaceValueWith,
@ -307,6 +310,9 @@ ScanOperands:
case TargetLowering::TypeLegal:
LLVM_DEBUG(dbgs() << "Legal operand\n");
continue;
case TargetLowering::TypeScalarizeScalableVector:
report_fatal_error(
"Scalarization of scalable vectors is not supported.");
// The following calls must either replace all of the node's results
// using ReplaceValueWith, and return "false"; or update the node's
// operands in place, and return "true".

View File

@ -83,6 +83,8 @@ void DAGTypeLegalizer::ExpandRes_BITCAST(SDNode *N, SDValue &Lo, SDValue &Hi) {
Lo = DAG.getNode(ISD::BITCAST, dl, NOutVT, Lo);
Hi = DAG.getNode(ISD::BITCAST, dl, NOutVT, Hi);
return;
case TargetLowering::TypeScalarizeScalableVector:
report_fatal_error("Scalarization of scalable vectors is not supported.");
case TargetLowering::TypeWidenVector: {
assert(!(InVT.getVectorNumElements() & 1) && "Unsupported BITCAST");
InOp = GetWidenedVector(InOp);

View File

@ -1063,6 +1063,8 @@ void DAGTypeLegalizer::SplitVecRes_BITCAST(SDNode *N, SDValue &Lo,
Lo = DAG.getNode(ISD::BITCAST, dl, LoVT, Lo);
Hi = DAG.getNode(ISD::BITCAST, dl, HiVT, Hi);
return;
case TargetLowering::TypeScalarizeScalableVector:
report_fatal_error("Scalarization of scalable vectors is not supported.");
}
// In the general case, convert the input to an integer and split it by hand.
@ -3465,6 +3467,8 @@ SDValue DAGTypeLegalizer::WidenVecRes_BITCAST(SDNode *N) {
switch (getTypeAction(InVT)) {
case TargetLowering::TypeLegal:
break;
case TargetLowering::TypeScalarizeScalableVector:
report_fatal_error("Scalarization of scalable vectors is not supported.");
case TargetLowering::TypePromoteInteger: {
// If the incoming type is a vector that is being promoted, then
// we know that the elements are arranged differently and that we

View File

@ -823,9 +823,7 @@ TargetLoweringBase::getTypeConversion(LLVMContext &Context, EVT VT) const {
"Promote may not follow Expand or Promote");
if (LA == TypeSplitVector)
return LegalizeKind(LA,
EVT::getVectorVT(Context, SVT.getVectorElementType(),
SVT.getVectorNumElements() / 2));
return LegalizeKind(LA, SVT.getHalfNumVectorElementsVT());
if (LA == TypeScalarizeVector)
return LegalizeKind(LA, SVT.getVectorElementType());
return LegalizeKind(LA, NVT);
@ -852,13 +850,16 @@ TargetLoweringBase::getTypeConversion(LLVMContext &Context, EVT VT) const {
}
// Handle vector types.
unsigned NumElts = VT.getVectorNumElements();
ElementCount NumElts = VT.getVectorElementCount();
EVT EltVT = VT.getVectorElementType();
// Vectors with only one element are always scalarized.
if (NumElts == 1)
return LegalizeKind(TypeScalarizeVector, EltVT);
if (VT.getVectorElementCount() == ElementCount(1, true))
report_fatal_error("Cannot legalize this vector");
// Try to widen vector elements until the element type is a power of two and
// promote it to a legal type later on, for example:
// <3 x i8> -> <4 x i8> -> <4 x i32>
@ -866,7 +867,7 @@ TargetLoweringBase::getTypeConversion(LLVMContext &Context, EVT VT) const {
// Vectors with a number of elements that is not a power of two are always
// widened, for example <3 x i8> -> <4 x i8>.
if (!VT.isPow2VectorType()) {
NumElts = (unsigned)NextPowerOf2(NumElts);
NumElts = NumElts.NextPowerOf2();
EVT NVT = EVT::getVectorVT(Context, EltVT, NumElts);
return LegalizeKind(TypeWidenVector, NVT);
}
@ -915,7 +916,7 @@ TargetLoweringBase::getTypeConversion(LLVMContext &Context, EVT VT) const {
// If there is no wider legal type, split the vector.
while (true) {
// Round up to the next power of 2.
NumElts = (unsigned)NextPowerOf2(NumElts);
NumElts = NumElts.NextPowerOf2();
// If there is no simple vector type with this many elements then there
// cannot be a larger legal vector type. Note that this assumes that
@ -938,7 +939,7 @@ TargetLoweringBase::getTypeConversion(LLVMContext &Context, EVT VT) const {
}
// Vectors with illegal element types are expanded.
EVT NVT = EVT::getVectorVT(Context, EltVT, VT.getVectorNumElements() / 2);
EVT NVT = EVT::getVectorVT(Context, EltVT, VT.getVectorElementCount() / 2);
return LegalizeKind(TypeSplitVector, NVT);
}
@ -1257,7 +1258,7 @@ void TargetLoweringBase::computeRegisterProperties(
continue;
MVT EltVT = VT.getVectorElementType();
unsigned NElts = VT.getVectorNumElements();
ElementCount EC = VT.getVectorElementCount();
bool IsLegalWiderType = false;
bool IsScalable = VT.isScalableVector();
LegalizeTypeAction PreferredAction = getPreferredVectorAction(VT);
@ -1274,8 +1275,7 @@ void TargetLoweringBase::computeRegisterProperties(
// Promote vectors of integers to vectors with the same number
// of elements, with a wider element type.
if (SVT.getScalarSizeInBits() > EltVT.getSizeInBits() &&
SVT.getVectorNumElements() == NElts &&
SVT.isScalableVector() == IsScalable && isTypeLegal(SVT)) {
SVT.getVectorElementCount() == EC && isTypeLegal(SVT)) {
TransformToType[i] = SVT;
RegisterTypeForVT[i] = SVT;
NumRegistersForVT[i] = 1;
@ -1290,13 +1290,13 @@ void TargetLoweringBase::computeRegisterProperties(
}
case TypeWidenVector:
if (isPowerOf2_32(NElts)) {
if (isPowerOf2_32(EC.Min)) {
// Try to widen the vector.
for (unsigned nVT = i + 1; nVT <= MVT::LAST_VECTOR_VALUETYPE; ++nVT) {
MVT SVT = (MVT::SimpleValueType) nVT;
if (SVT.getVectorElementType() == EltVT
&& SVT.getVectorNumElements() > NElts
&& SVT.isScalableVector() == IsScalable && isTypeLegal(SVT)) {
if (SVT.getVectorElementType() == EltVT &&
SVT.isScalableVector() == IsScalable &&
SVT.getVectorElementCount().Min > EC.Min && isTypeLegal(SVT)) {
TransformToType[i] = SVT;
RegisterTypeForVT[i] = SVT;
NumRegistersForVT[i] = 1;
@ -1340,10 +1340,12 @@ void TargetLoweringBase::computeRegisterProperties(
ValueTypeActions.setTypeAction(VT, TypeScalarizeVector);
else if (PreferredAction == TypeSplitVector)
ValueTypeActions.setTypeAction(VT, TypeSplitVector);
else if (EC.Min > 1)
ValueTypeActions.setTypeAction(VT, TypeSplitVector);
else
// Set type action according to the number of elements.
ValueTypeActions.setTypeAction(VT, NElts == 1 ? TypeScalarizeVector
: TypeSplitVector);
ValueTypeActions.setTypeAction(VT, EC.Scalable
? TypeScalarizeScalableVector
: TypeScalarizeVector);
} else {
TransformToType[i] = NVT;
ValueTypeActions.setTypeAction(VT, TypeWidenVector);

View File

@ -17,9 +17,7 @@
#include "llvm/Target/TargetMachine.h"
#include "gtest/gtest.h"
using namespace llvm;
namespace {
namespace llvm {
class AArch64SelectionDAGTest : public testing::Test {
protected:
@ -41,8 +39,8 @@ protected:
return;
TargetOptions Options;
TM = std::unique_ptr<LLVMTargetMachine>(static_cast<LLVMTargetMachine*>(
T->createTargetMachine("AArch64", "", "", Options, None, None,
TM = std::unique_ptr<LLVMTargetMachine>(static_cast<LLVMTargetMachine *>(
T->createTargetMachine("AArch64", "", "+sve", Options, None, None,
CodeGenOpt::Aggressive)));
if (!TM)
return;
@ -69,6 +67,14 @@ protected:
DAG->init(*MF, ORE, nullptr, nullptr, nullptr, nullptr, nullptr);
}
TargetLoweringBase::LegalizeTypeAction getTypeAction(EVT VT) {
return DAG->getTargetLoweringInfo().getTypeAction(Context, VT);
}
EVT getTypeToTransformTo(EVT VT) {
return DAG->getTargetLoweringInfo().getTypeToTransformTo(Context, VT);
}
LLVMContext Context;
std::unique_ptr<LLVMTargetMachine> TM;
std::unique_ptr<Module> M;
@ -377,4 +383,59 @@ TEST_F(AArch64SelectionDAGTest, getSplatSourceVector_Scalable_ADD_of_SPLAT_VECTO
EXPECT_EQ(SplatIdx, 0);
}
} // end anonymous namespace
TEST_F(AArch64SelectionDAGTest, getTypeConversion_SplitScalableMVT) {
if (!TM)
return;
MVT VT = MVT::nxv4i64;
EXPECT_EQ(getTypeAction(VT), TargetLoweringBase::TypeSplitVector);
ASSERT_TRUE(getTypeToTransformTo(VT).isScalableVector());
}
TEST_F(AArch64SelectionDAGTest, getTypeConversion_PromoteScalableMVT) {
if (!TM)
return;
MVT VT = MVT::nxv2i32;
EXPECT_EQ(getTypeAction(VT), TargetLoweringBase::TypePromoteInteger);
ASSERT_TRUE(getTypeToTransformTo(VT).isScalableVector());
}
TEST_F(AArch64SelectionDAGTest, getTypeConversion_NoScalarizeMVT_nxv1f32) {
if (!TM)
return;
MVT VT = MVT::nxv1f32;
EXPECT_NE(getTypeAction(VT), TargetLoweringBase::TypeScalarizeVector);
ASSERT_TRUE(getTypeToTransformTo(VT).isScalableVector());
}
TEST_F(AArch64SelectionDAGTest, getTypeConversion_SplitScalableEVT) {
if (!TM)
return;
EVT VT = EVT::getVectorVT(Context, MVT::i64, 256, true);
EXPECT_EQ(getTypeAction(VT), TargetLoweringBase::TypeSplitVector);
EXPECT_EQ(getTypeToTransformTo(VT), VT.getHalfNumVectorElementsVT(Context));
}
TEST_F(AArch64SelectionDAGTest, getTypeConversion_WidenScalableEVT) {
if (!TM)
return;
EVT FromVT = EVT::getVectorVT(Context, MVT::i64, 6, true);
EVT ToVT = EVT::getVectorVT(Context, MVT::i64, 8, true);
EXPECT_EQ(getTypeAction(FromVT), TargetLoweringBase::TypeWidenVector);
EXPECT_EQ(getTypeToTransformTo(FromVT), ToVT);
}
TEST_F(AArch64SelectionDAGTest, getTypeConversion_NoScalarizeEVT_nxv1f128) {
if (!TM)
return;
EVT FromVT = EVT::getVectorVT(Context, MVT::f128, 1, true);
EXPECT_DEATH(getTypeAction(FromVT), "Cannot legalize this vector");
}
} // end namespace llvm