mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-11-22 18:54:02 +01:00
[CodeGen][SVE] Calculate correct type legalization for scalable vectors.
This patch updates TargetLoweringBase::computeRegisterProperties and TargetLoweringBase::getTypeConversion to support scalable vectors, and make the right calls on how to legalise them. These changes are required to legalise both MVTs and EVTs. Reviewers: efriedma, david-arm, ctetreau Reviewed By: efriedma Tags: #llvm Differential Revision: https://reviews.llvm.org/D80640
This commit is contained in:
parent
b602e0720e
commit
b11af615da
@ -210,6 +210,13 @@ public:
|
||||
TypeWidenVector, // This vector should be widened into a larger vector.
|
||||
TypePromoteFloat, // Replace this float with a larger one.
|
||||
TypeSoftPromoteHalf, // Soften half to i16 and use float to do arithmetic.
|
||||
TypeScalarizeScalableVector, // This action is explicitly left unimplemented.
|
||||
// While it is theoretically possible to
|
||||
// legalize operations on scalable types with a
|
||||
// loop that handles the vscale * #lanes of the
|
||||
// vector, this is non-trivial at SelectionDAG
|
||||
// level and these types are better to be
|
||||
// widened or promoted.
|
||||
};
|
||||
|
||||
/// LegalizeKind holds the legalization kind that needs to happen to EVT
|
||||
@ -412,7 +419,7 @@ public:
|
||||
virtual TargetLoweringBase::LegalizeTypeAction
|
||||
getPreferredVectorAction(MVT VT) const {
|
||||
// The default action for one element vectors is to scalarize
|
||||
if (VT.getVectorNumElements() == 1)
|
||||
if (VT.getVectorElementCount() == 1)
|
||||
return TypeScalarizeVector;
|
||||
// The default action for an odd-width vector is to widen.
|
||||
if (!VT.isPow2VectorType())
|
||||
|
@ -15,6 +15,7 @@
|
||||
#ifndef LLVM_SUPPORT_TYPESIZE_H
|
||||
#define LLVM_SUPPORT_TYPESIZE_H
|
||||
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
#include "llvm/Support/WithColor.h"
|
||||
|
||||
#include <cstdint>
|
||||
@ -49,6 +50,12 @@ public:
|
||||
bool operator!=(const ElementCount& RHS) const {
|
||||
return !(*this == RHS);
|
||||
}
|
||||
bool operator==(unsigned RHS) const { return Min == RHS && !Scalable; }
|
||||
bool operator!=(unsigned RHS) const { return !(*this == RHS); }
|
||||
|
||||
ElementCount NextPowerOf2() const {
|
||||
return ElementCount(llvm::NextPowerOf2(Min), Scalable);
|
||||
}
|
||||
};
|
||||
|
||||
// This class is used to represent the size of types. If the type is of fixed
|
||||
|
@ -344,6 +344,8 @@ SDValue DAGTypeLegalizer::PromoteIntRes_BITCAST(SDNode *N) {
|
||||
return DAG.getNode(ISD::ANY_EXTEND, dl, NOutVT,
|
||||
BitConvertToInteger(GetScalarizedVector(InOp)));
|
||||
break;
|
||||
case TargetLowering::TypeScalarizeScalableVector:
|
||||
report_fatal_error("Scalarization of scalable vectors is not supported.");
|
||||
case TargetLowering::TypeSplitVector: {
|
||||
if (!NOutVT.isVector()) {
|
||||
// For example, i32 = BITCAST v2i16 on alpha. Convert the split
|
||||
|
@ -245,6 +245,9 @@ bool DAGTypeLegalizer::run() {
|
||||
case TargetLowering::TypeLegal:
|
||||
LLVM_DEBUG(dbgs() << "Legal result type\n");
|
||||
break;
|
||||
case TargetLowering::TypeScalarizeScalableVector:
|
||||
report_fatal_error(
|
||||
"Scalarization of scalable vectors is not supported.");
|
||||
// The following calls must take care of *all* of the node's results,
|
||||
// not just the illegal result they were passed (this includes results
|
||||
// with a legal type). Results can be remapped using ReplaceValueWith,
|
||||
@ -307,6 +310,9 @@ ScanOperands:
|
||||
case TargetLowering::TypeLegal:
|
||||
LLVM_DEBUG(dbgs() << "Legal operand\n");
|
||||
continue;
|
||||
case TargetLowering::TypeScalarizeScalableVector:
|
||||
report_fatal_error(
|
||||
"Scalarization of scalable vectors is not supported.");
|
||||
// The following calls must either replace all of the node's results
|
||||
// using ReplaceValueWith, and return "false"; or update the node's
|
||||
// operands in place, and return "true".
|
||||
|
@ -83,6 +83,8 @@ void DAGTypeLegalizer::ExpandRes_BITCAST(SDNode *N, SDValue &Lo, SDValue &Hi) {
|
||||
Lo = DAG.getNode(ISD::BITCAST, dl, NOutVT, Lo);
|
||||
Hi = DAG.getNode(ISD::BITCAST, dl, NOutVT, Hi);
|
||||
return;
|
||||
case TargetLowering::TypeScalarizeScalableVector:
|
||||
report_fatal_error("Scalarization of scalable vectors is not supported.");
|
||||
case TargetLowering::TypeWidenVector: {
|
||||
assert(!(InVT.getVectorNumElements() & 1) && "Unsupported BITCAST");
|
||||
InOp = GetWidenedVector(InOp);
|
||||
|
@ -1063,6 +1063,8 @@ void DAGTypeLegalizer::SplitVecRes_BITCAST(SDNode *N, SDValue &Lo,
|
||||
Lo = DAG.getNode(ISD::BITCAST, dl, LoVT, Lo);
|
||||
Hi = DAG.getNode(ISD::BITCAST, dl, HiVT, Hi);
|
||||
return;
|
||||
case TargetLowering::TypeScalarizeScalableVector:
|
||||
report_fatal_error("Scalarization of scalable vectors is not supported.");
|
||||
}
|
||||
|
||||
// In the general case, convert the input to an integer and split it by hand.
|
||||
@ -3465,6 +3467,8 @@ SDValue DAGTypeLegalizer::WidenVecRes_BITCAST(SDNode *N) {
|
||||
switch (getTypeAction(InVT)) {
|
||||
case TargetLowering::TypeLegal:
|
||||
break;
|
||||
case TargetLowering::TypeScalarizeScalableVector:
|
||||
report_fatal_error("Scalarization of scalable vectors is not supported.");
|
||||
case TargetLowering::TypePromoteInteger: {
|
||||
// If the incoming type is a vector that is being promoted, then
|
||||
// we know that the elements are arranged differently and that we
|
||||
|
@ -823,9 +823,7 @@ TargetLoweringBase::getTypeConversion(LLVMContext &Context, EVT VT) const {
|
||||
"Promote may not follow Expand or Promote");
|
||||
|
||||
if (LA == TypeSplitVector)
|
||||
return LegalizeKind(LA,
|
||||
EVT::getVectorVT(Context, SVT.getVectorElementType(),
|
||||
SVT.getVectorNumElements() / 2));
|
||||
return LegalizeKind(LA, SVT.getHalfNumVectorElementsVT());
|
||||
if (LA == TypeScalarizeVector)
|
||||
return LegalizeKind(LA, SVT.getVectorElementType());
|
||||
return LegalizeKind(LA, NVT);
|
||||
@ -852,13 +850,16 @@ TargetLoweringBase::getTypeConversion(LLVMContext &Context, EVT VT) const {
|
||||
}
|
||||
|
||||
// Handle vector types.
|
||||
unsigned NumElts = VT.getVectorNumElements();
|
||||
ElementCount NumElts = VT.getVectorElementCount();
|
||||
EVT EltVT = VT.getVectorElementType();
|
||||
|
||||
// Vectors with only one element are always scalarized.
|
||||
if (NumElts == 1)
|
||||
return LegalizeKind(TypeScalarizeVector, EltVT);
|
||||
|
||||
if (VT.getVectorElementCount() == ElementCount(1, true))
|
||||
report_fatal_error("Cannot legalize this vector");
|
||||
|
||||
// Try to widen vector elements until the element type is a power of two and
|
||||
// promote it to a legal type later on, for example:
|
||||
// <3 x i8> -> <4 x i8> -> <4 x i32>
|
||||
@ -866,7 +867,7 @@ TargetLoweringBase::getTypeConversion(LLVMContext &Context, EVT VT) const {
|
||||
// Vectors with a number of elements that is not a power of two are always
|
||||
// widened, for example <3 x i8> -> <4 x i8>.
|
||||
if (!VT.isPow2VectorType()) {
|
||||
NumElts = (unsigned)NextPowerOf2(NumElts);
|
||||
NumElts = NumElts.NextPowerOf2();
|
||||
EVT NVT = EVT::getVectorVT(Context, EltVT, NumElts);
|
||||
return LegalizeKind(TypeWidenVector, NVT);
|
||||
}
|
||||
@ -915,7 +916,7 @@ TargetLoweringBase::getTypeConversion(LLVMContext &Context, EVT VT) const {
|
||||
// If there is no wider legal type, split the vector.
|
||||
while (true) {
|
||||
// Round up to the next power of 2.
|
||||
NumElts = (unsigned)NextPowerOf2(NumElts);
|
||||
NumElts = NumElts.NextPowerOf2();
|
||||
|
||||
// If there is no simple vector type with this many elements then there
|
||||
// cannot be a larger legal vector type. Note that this assumes that
|
||||
@ -938,7 +939,7 @@ TargetLoweringBase::getTypeConversion(LLVMContext &Context, EVT VT) const {
|
||||
}
|
||||
|
||||
// Vectors with illegal element types are expanded.
|
||||
EVT NVT = EVT::getVectorVT(Context, EltVT, VT.getVectorNumElements() / 2);
|
||||
EVT NVT = EVT::getVectorVT(Context, EltVT, VT.getVectorElementCount() / 2);
|
||||
return LegalizeKind(TypeSplitVector, NVT);
|
||||
}
|
||||
|
||||
@ -1257,7 +1258,7 @@ void TargetLoweringBase::computeRegisterProperties(
|
||||
continue;
|
||||
|
||||
MVT EltVT = VT.getVectorElementType();
|
||||
unsigned NElts = VT.getVectorNumElements();
|
||||
ElementCount EC = VT.getVectorElementCount();
|
||||
bool IsLegalWiderType = false;
|
||||
bool IsScalable = VT.isScalableVector();
|
||||
LegalizeTypeAction PreferredAction = getPreferredVectorAction(VT);
|
||||
@ -1274,8 +1275,7 @@ void TargetLoweringBase::computeRegisterProperties(
|
||||
// Promote vectors of integers to vectors with the same number
|
||||
// of elements, with a wider element type.
|
||||
if (SVT.getScalarSizeInBits() > EltVT.getSizeInBits() &&
|
||||
SVT.getVectorNumElements() == NElts &&
|
||||
SVT.isScalableVector() == IsScalable && isTypeLegal(SVT)) {
|
||||
SVT.getVectorElementCount() == EC && isTypeLegal(SVT)) {
|
||||
TransformToType[i] = SVT;
|
||||
RegisterTypeForVT[i] = SVT;
|
||||
NumRegistersForVT[i] = 1;
|
||||
@ -1290,13 +1290,13 @@ void TargetLoweringBase::computeRegisterProperties(
|
||||
}
|
||||
|
||||
case TypeWidenVector:
|
||||
if (isPowerOf2_32(NElts)) {
|
||||
if (isPowerOf2_32(EC.Min)) {
|
||||
// Try to widen the vector.
|
||||
for (unsigned nVT = i + 1; nVT <= MVT::LAST_VECTOR_VALUETYPE; ++nVT) {
|
||||
MVT SVT = (MVT::SimpleValueType) nVT;
|
||||
if (SVT.getVectorElementType() == EltVT
|
||||
&& SVT.getVectorNumElements() > NElts
|
||||
&& SVT.isScalableVector() == IsScalable && isTypeLegal(SVT)) {
|
||||
if (SVT.getVectorElementType() == EltVT &&
|
||||
SVT.isScalableVector() == IsScalable &&
|
||||
SVT.getVectorElementCount().Min > EC.Min && isTypeLegal(SVT)) {
|
||||
TransformToType[i] = SVT;
|
||||
RegisterTypeForVT[i] = SVT;
|
||||
NumRegistersForVT[i] = 1;
|
||||
@ -1340,10 +1340,12 @@ void TargetLoweringBase::computeRegisterProperties(
|
||||
ValueTypeActions.setTypeAction(VT, TypeScalarizeVector);
|
||||
else if (PreferredAction == TypeSplitVector)
|
||||
ValueTypeActions.setTypeAction(VT, TypeSplitVector);
|
||||
else if (EC.Min > 1)
|
||||
ValueTypeActions.setTypeAction(VT, TypeSplitVector);
|
||||
else
|
||||
// Set type action according to the number of elements.
|
||||
ValueTypeActions.setTypeAction(VT, NElts == 1 ? TypeScalarizeVector
|
||||
: TypeSplitVector);
|
||||
ValueTypeActions.setTypeAction(VT, EC.Scalable
|
||||
? TypeScalarizeScalableVector
|
||||
: TypeScalarizeVector);
|
||||
} else {
|
||||
TransformToType[i] = NVT;
|
||||
ValueTypeActions.setTypeAction(VT, TypeWidenVector);
|
||||
|
@ -17,9 +17,7 @@
|
||||
#include "llvm/Target/TargetMachine.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
namespace {
|
||||
namespace llvm {
|
||||
|
||||
class AArch64SelectionDAGTest : public testing::Test {
|
||||
protected:
|
||||
@ -41,8 +39,8 @@ protected:
|
||||
return;
|
||||
|
||||
TargetOptions Options;
|
||||
TM = std::unique_ptr<LLVMTargetMachine>(static_cast<LLVMTargetMachine*>(
|
||||
T->createTargetMachine("AArch64", "", "", Options, None, None,
|
||||
TM = std::unique_ptr<LLVMTargetMachine>(static_cast<LLVMTargetMachine *>(
|
||||
T->createTargetMachine("AArch64", "", "+sve", Options, None, None,
|
||||
CodeGenOpt::Aggressive)));
|
||||
if (!TM)
|
||||
return;
|
||||
@ -69,6 +67,14 @@ protected:
|
||||
DAG->init(*MF, ORE, nullptr, nullptr, nullptr, nullptr, nullptr);
|
||||
}
|
||||
|
||||
TargetLoweringBase::LegalizeTypeAction getTypeAction(EVT VT) {
|
||||
return DAG->getTargetLoweringInfo().getTypeAction(Context, VT);
|
||||
}
|
||||
|
||||
EVT getTypeToTransformTo(EVT VT) {
|
||||
return DAG->getTargetLoweringInfo().getTypeToTransformTo(Context, VT);
|
||||
}
|
||||
|
||||
LLVMContext Context;
|
||||
std::unique_ptr<LLVMTargetMachine> TM;
|
||||
std::unique_ptr<Module> M;
|
||||
@ -377,4 +383,59 @@ TEST_F(AArch64SelectionDAGTest, getSplatSourceVector_Scalable_ADD_of_SPLAT_VECTO
|
||||
EXPECT_EQ(SplatIdx, 0);
|
||||
}
|
||||
|
||||
} // end anonymous namespace
|
||||
TEST_F(AArch64SelectionDAGTest, getTypeConversion_SplitScalableMVT) {
|
||||
if (!TM)
|
||||
return;
|
||||
|
||||
MVT VT = MVT::nxv4i64;
|
||||
EXPECT_EQ(getTypeAction(VT), TargetLoweringBase::TypeSplitVector);
|
||||
ASSERT_TRUE(getTypeToTransformTo(VT).isScalableVector());
|
||||
}
|
||||
|
||||
TEST_F(AArch64SelectionDAGTest, getTypeConversion_PromoteScalableMVT) {
|
||||
if (!TM)
|
||||
return;
|
||||
|
||||
MVT VT = MVT::nxv2i32;
|
||||
EXPECT_EQ(getTypeAction(VT), TargetLoweringBase::TypePromoteInteger);
|
||||
ASSERT_TRUE(getTypeToTransformTo(VT).isScalableVector());
|
||||
}
|
||||
|
||||
TEST_F(AArch64SelectionDAGTest, getTypeConversion_NoScalarizeMVT_nxv1f32) {
|
||||
if (!TM)
|
||||
return;
|
||||
|
||||
MVT VT = MVT::nxv1f32;
|
||||
EXPECT_NE(getTypeAction(VT), TargetLoweringBase::TypeScalarizeVector);
|
||||
ASSERT_TRUE(getTypeToTransformTo(VT).isScalableVector());
|
||||
}
|
||||
|
||||
TEST_F(AArch64SelectionDAGTest, getTypeConversion_SplitScalableEVT) {
|
||||
if (!TM)
|
||||
return;
|
||||
|
||||
EVT VT = EVT::getVectorVT(Context, MVT::i64, 256, true);
|
||||
EXPECT_EQ(getTypeAction(VT), TargetLoweringBase::TypeSplitVector);
|
||||
EXPECT_EQ(getTypeToTransformTo(VT), VT.getHalfNumVectorElementsVT(Context));
|
||||
}
|
||||
|
||||
TEST_F(AArch64SelectionDAGTest, getTypeConversion_WidenScalableEVT) {
|
||||
if (!TM)
|
||||
return;
|
||||
|
||||
EVT FromVT = EVT::getVectorVT(Context, MVT::i64, 6, true);
|
||||
EVT ToVT = EVT::getVectorVT(Context, MVT::i64, 8, true);
|
||||
|
||||
EXPECT_EQ(getTypeAction(FromVT), TargetLoweringBase::TypeWidenVector);
|
||||
EXPECT_EQ(getTypeToTransformTo(FromVT), ToVT);
|
||||
}
|
||||
|
||||
TEST_F(AArch64SelectionDAGTest, getTypeConversion_NoScalarizeEVT_nxv1f128) {
|
||||
if (!TM)
|
||||
return;
|
||||
|
||||
EVT FromVT = EVT::getVectorVT(Context, MVT::f128, 1, true);
|
||||
EXPECT_DEATH(getTypeAction(FromVT), "Cannot legalize this vector");
|
||||
}
|
||||
|
||||
} // end namespace llvm
|
||||
|
Loading…
Reference in New Issue
Block a user