1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-22 10:42:39 +01:00

[CostModel] Unify getCastInstrCost

Add the remaining cast instruction opcodes to the base implementation
of getUserCost and directly return the result. This allows
getInstructionThroughput to return getUserCost for the casts. This
has required changes to PPC and SystemZ because they implement
getUserCost and/or getCastInstrCost with adjustments for vector
operations. Adjusts have also been made in the remaining backends
that implement the method so that they still produce a cost of zero
or one for cost kinds other than throughput.

Differential Revision: https://reviews.llvm.org/D79848
This commit is contained in:
Sam Parker 2020-05-26 11:27:57 +01:00
parent 1239800045
commit 8420f6860c
8 changed files with 76 additions and 41 deletions

View File

@ -826,18 +826,18 @@ public:
return TTI::TCC_Expensive;
case Instruction::IntToPtr:
case Instruction::PtrToInt:
case Instruction::SIToFP:
case Instruction::UIToFP:
case Instruction::FPToUI:
case Instruction::FPToSI:
case Instruction::Trunc:
case Instruction::FPTrunc:
case Instruction::BitCast:
if (TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I) ==
TTI::TCC_Free)
return TTI::TCC_Free;
break;
case Instruction::FPExt:
case Instruction::SExt:
case Instruction::ZExt:
if (TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I) == TTI::TCC_Free)
return TTI::TCC_Free;
break;
case Instruction::AddrSpaceCast:
return TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I);
}
// By default, just classify everything as 'basic'.
return TTI::TCC_Basic;

View File

@ -1325,10 +1325,8 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const {
case Instruction::Trunc:
case Instruction::FPTrunc:
case Instruction::BitCast:
case Instruction::AddrSpaceCast: {
Type *SrcTy = I->getOperand(0)->getType();
return getCastInstrCost(I->getOpcode(), I->getType(), SrcTy, CostKind, I);
}
case Instruction::AddrSpaceCast:
return getUserCost(I, CostKind);
case Instruction::ExtractElement: {
const ExtractElementInst *EEI = cast<ExtractElementInst>(I);
ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));

View File

@ -295,11 +295,18 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
}
}
// TODO: Allow non-throughput costs that aren't binary.
auto AdjustCost = [&CostKind](int Cost) {
if (CostKind != TTI::TCK_RecipThroughput)
return Cost == 0 ? 0 : 1;
return Cost;
};
EVT SrcTy = TLI->getValueType(DL, Src);
EVT DstTy = TLI->getValueType(DL, Dst);
if (!SrcTy.isSimple() || !DstTy.isSimple())
return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
static const TypeConversionCostTblEntry
ConversionTbl[] = {
@ -401,9 +408,9 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
DstTy.getSimpleVT(),
SrcTy.getSimpleVT()))
return Entry->Cost;
return AdjustCost(Entry->Cost);
return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
}
int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,

View File

@ -173,6 +173,13 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
int ISD = TLI->InstructionOpcodeToISD(Opcode);
assert(ISD && "Invalid opcode");
// TODO: Allow non-throughput costs that aren't binary.
auto AdjustCost = [&CostKind](int Cost) {
if (CostKind != TTI::TCK_RecipThroughput)
return Cost == 0 ? 0 : 1;
return Cost;
};
// Single to/from double precision conversions.
static const CostTblEntry NEONFltDblTbl[] = {
// Vector fptrunc/fpext conversions.
@ -185,14 +192,14 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
ISD == ISD::FP_EXTEND)) {
std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Src);
if (const auto *Entry = CostTableLookup(NEONFltDblTbl, ISD, LT.second))
return LT.first * Entry->Cost;
return AdjustCost(LT.first * Entry->Cost);
}
EVT SrcTy = TLI->getValueType(DL, Src);
EVT DstTy = TLI->getValueType(DL, Dst);
if (!SrcTy.isSimple() || !DstTy.isSimple())
return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
// The extend of a load is free
if (I && isa<LoadInst>(I->getOperand(0))) {
@ -212,7 +219,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
};
if (const auto *Entry = ConvertCostTableLookup(
LoadConversionTbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
return Entry->Cost;
return AdjustCost(Entry->Cost);
static const TypeConversionCostTblEntry MVELoadConversionTbl[] = {
{ISD::SIGN_EXTEND, MVT::v4i32, MVT::v4i16, 0},
@ -226,7 +233,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
if (const auto *Entry =
ConvertCostTableLookup(MVELoadConversionTbl, ISD,
DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
return Entry->Cost;
return AdjustCost(Entry->Cost);
}
}
@ -253,7 +260,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
if (auto *Entry = ConvertCostTableLookup(NEONDoubleWidthTbl, UserISD,
DstTy.getSimpleVT(),
SrcTy.getSimpleVT())) {
return Entry->Cost;
return AdjustCost(Entry->Cost);
}
}
@ -347,7 +354,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
if (const auto *Entry = ConvertCostTableLookup(NEONVectorConversionTbl, ISD,
DstTy.getSimpleVT(),
SrcTy.getSimpleVT()))
return Entry->Cost;
return AdjustCost(Entry->Cost);
}
// Scalar float to integer conversions.
@ -377,7 +384,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
if (const auto *Entry = ConvertCostTableLookup(NEONFloatConversionTbl, ISD,
DstTy.getSimpleVT(),
SrcTy.getSimpleVT()))
return Entry->Cost;
return AdjustCost(Entry->Cost);
}
// Scalar integer to float conversions.
@ -408,7 +415,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
if (const auto *Entry = ConvertCostTableLookup(NEONIntegerConversionTbl,
ISD, DstTy.getSimpleVT(),
SrcTy.getSimpleVT()))
return Entry->Cost;
return AdjustCost(Entry->Cost);
}
// MVE extend costs, taken from codegen tests. i8->i16 or i16->i32 is one
@ -433,7 +440,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
if (const auto *Entry = ConvertCostTableLookup(MVEVectorConversionTbl,
ISD, DstTy.getSimpleVT(),
SrcTy.getSimpleVT()))
return Entry->Cost * ST->getMVEVectorCostFactor();
return AdjustCost(Entry->Cost * ST->getMVEVectorCostFactor());
}
// Scalar integer conversion costs.
@ -452,13 +459,14 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
if (const auto *Entry = ConvertCostTableLookup(ARMIntegerConversionTbl, ISD,
DstTy.getSimpleVT(),
SrcTy.getSimpleVT()))
return Entry->Cost;
return AdjustCost(Entry->Cost);
}
int BaseCost = ST->hasMVEIntegerOps() && Src->isVectorTy()
? ST->getMVEVectorCostFactor()
: 1;
return BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
return AdjustCost(
BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
}
int ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,

View File

@ -263,7 +263,11 @@ unsigned HexagonTTIImpl::getCastInstrCost(unsigned Opcode, Type *DstTy,
std::pair<int, MVT> SrcLT = TLI.getTypeLegalizationCost(DL, SrcTy);
std::pair<int, MVT> DstLT = TLI.getTypeLegalizationCost(DL, DstTy);
return std::max(SrcLT.first, DstLT.first) + FloatFactor * (SrcN + DstN);
unsigned Cost = std::max(SrcLT.first, DstLT.first) + FloatFactor * (SrcN + DstN);
// TODO: Allow non-throughput costs that aren't binary.
if (CostKind != TTI::TCK_RecipThroughput)
return Cost == 0 ? 0 : 1;
return Cost;
}
return 1;
}

View File

@ -212,7 +212,8 @@ int PPCTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
unsigned
PPCTTIImpl::getUserCost(const User *U, ArrayRef<const Value *> Operands,
TTI::TargetCostKind CostKind) {
if (U->getType()->isVectorTy()) {
// We already implement getCastInstrCost and perform the vector adjustment there.
if (!isa<CastInst>(U) && U->getType()->isVectorTy()) {
// Instructions that need to be split should cost more.
std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, U->getType());
return LT.first * BaseT::getUserCost(U, Operands, CostKind);
@ -760,7 +761,11 @@ int PPCTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
assert(TLI->InstructionOpcodeToISD(Opcode) && "Invalid opcode");
int Cost = BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
return vectorCostAdjustment(Cost, Opcode, Dst, Src);
Cost = vectorCostAdjustment(Cost, Opcode, Dst, Src);
// TODO: Allow non-throughput costs that aren't binary.
if (CostKind != TTI::TCK_RecipThroughput)
return Cost == 0 ? 0 : 1;
return Cost;
}
int PPCTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,

View File

@ -691,6 +691,12 @@ getBoolVecToIntConversionCost(unsigned Opcode, Type *Dst,
int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
TTI::TargetCostKind CostKind,
const Instruction *I) {
// FIXME: Can the logic below also be used for these cost kinds?
if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency) {
int BaseCost = BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
return BaseCost == 0 ? BaseCost : 1;
}
unsigned DstScalarBits = Dst->getScalarSizeInBits();
unsigned SrcScalarBits = Src->getScalarSizeInBits();

View File

@ -1368,6 +1368,13 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
int ISD = TLI->InstructionOpcodeToISD(Opcode);
assert(ISD && "Invalid opcode");
// TODO: Allow non-throughput costs that aren't binary.
auto AdjustCost = [&CostKind](int Cost) {
if (CostKind != TTI::TCK_RecipThroughput)
return Cost == 0 ? 0 : 1;
return Cost;
};
// FIXME: Need a better design of the cost table to handle non-simple types of
// potential massive combinations (elem_num x src_type x dst_type).
@ -1969,7 +1976,7 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
if (ST->hasSSE2() && !ST->hasAVX()) {
if (const auto *Entry = ConvertCostTableLookup(SSE2ConversionTbl, ISD,
LTDest.second, LTSrc.second))
return LTSrc.first * Entry->Cost;
return AdjustCost(LTSrc.first * Entry->Cost);
}
EVT SrcTy = TLI->getValueType(DL, Src);
@ -1977,7 +1984,7 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
// The function getSimpleVT only handles simple value types.
if (!SrcTy.isSimple() || !DstTy.isSimple())
return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind);
return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind));
MVT SimpleSrcTy = SrcTy.getSimpleVT();
MVT SimpleDstTy = DstTy.getSimpleVT();
@ -1986,59 +1993,59 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
if (ST->hasBWI())
if (const auto *Entry = ConvertCostTableLookup(AVX512BWConversionTbl, ISD,
SimpleDstTy, SimpleSrcTy))
return Entry->Cost;
return AdjustCost(Entry->Cost);
if (ST->hasDQI())
if (const auto *Entry = ConvertCostTableLookup(AVX512DQConversionTbl, ISD,
SimpleDstTy, SimpleSrcTy))
return Entry->Cost;
return AdjustCost(Entry->Cost);
if (ST->hasAVX512())
if (const auto *Entry = ConvertCostTableLookup(AVX512FConversionTbl, ISD,
SimpleDstTy, SimpleSrcTy))
return Entry->Cost;
return AdjustCost(Entry->Cost);
}
if (ST->hasBWI())
if (const auto *Entry = ConvertCostTableLookup(AVX512BWVLConversionTbl, ISD,
SimpleDstTy, SimpleSrcTy))
return Entry->Cost;
return AdjustCost(Entry->Cost);
if (ST->hasDQI())
if (const auto *Entry = ConvertCostTableLookup(AVX512DQVLConversionTbl, ISD,
SimpleDstTy, SimpleSrcTy))
return Entry->Cost;
return AdjustCost(Entry->Cost);
if (ST->hasAVX512())
if (const auto *Entry = ConvertCostTableLookup(AVX512VLConversionTbl, ISD,
SimpleDstTy, SimpleSrcTy))
return Entry->Cost;
return AdjustCost(Entry->Cost);
if (ST->hasAVX2()) {
if (const auto *Entry = ConvertCostTableLookup(AVX2ConversionTbl, ISD,
SimpleDstTy, SimpleSrcTy))
return Entry->Cost;
return AdjustCost(Entry->Cost);
}
if (ST->hasAVX()) {
if (const auto *Entry = ConvertCostTableLookup(AVXConversionTbl, ISD,
SimpleDstTy, SimpleSrcTy))
return Entry->Cost;
return AdjustCost(Entry->Cost);
}
if (ST->hasSSE41()) {
if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD,
SimpleDstTy, SimpleSrcTy))
return Entry->Cost;
return AdjustCost(Entry->Cost);
}
if (ST->hasSSE2()) {
if (const auto *Entry = ConvertCostTableLookup(SSE2ConversionTbl, ISD,
SimpleDstTy, SimpleSrcTy))
return Entry->Cost;
return AdjustCost(Entry->Cost);
}
return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
}
int X86TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,