mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-11-22 10:42:39 +01:00
[CostModel] Unify getCastInstrCost
Add the remaining cast instruction opcodes to the base implementation of getUserCost and directly return the result. This allows getInstructionThroughput to return getUserCost for the casts. This has required changes to PPC and SystemZ because they implement getUserCost and/or getCastInstrCost with adjustments for vector operations. Adjusts have also been made in the remaining backends that implement the method so that they still produce a cost of zero or one for cost kinds other than throughput. Differential Revision: https://reviews.llvm.org/D79848
This commit is contained in:
parent
1239800045
commit
8420f6860c
@ -826,18 +826,18 @@ public:
|
||||
return TTI::TCC_Expensive;
|
||||
case Instruction::IntToPtr:
|
||||
case Instruction::PtrToInt:
|
||||
case Instruction::SIToFP:
|
||||
case Instruction::UIToFP:
|
||||
case Instruction::FPToUI:
|
||||
case Instruction::FPToSI:
|
||||
case Instruction::Trunc:
|
||||
case Instruction::FPTrunc:
|
||||
case Instruction::BitCast:
|
||||
if (TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I) ==
|
||||
TTI::TCC_Free)
|
||||
return TTI::TCC_Free;
|
||||
break;
|
||||
case Instruction::FPExt:
|
||||
case Instruction::SExt:
|
||||
case Instruction::ZExt:
|
||||
if (TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I) == TTI::TCC_Free)
|
||||
return TTI::TCC_Free;
|
||||
break;
|
||||
case Instruction::AddrSpaceCast:
|
||||
return TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I);
|
||||
}
|
||||
// By default, just classify everything as 'basic'.
|
||||
return TTI::TCC_Basic;
|
||||
|
@ -1325,10 +1325,8 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const {
|
||||
case Instruction::Trunc:
|
||||
case Instruction::FPTrunc:
|
||||
case Instruction::BitCast:
|
||||
case Instruction::AddrSpaceCast: {
|
||||
Type *SrcTy = I->getOperand(0)->getType();
|
||||
return getCastInstrCost(I->getOpcode(), I->getType(), SrcTy, CostKind, I);
|
||||
}
|
||||
case Instruction::AddrSpaceCast:
|
||||
return getUserCost(I, CostKind);
|
||||
case Instruction::ExtractElement: {
|
||||
const ExtractElementInst *EEI = cast<ExtractElementInst>(I);
|
||||
ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));
|
||||
|
@ -295,11 +295,18 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Allow non-throughput costs that aren't binary.
|
||||
auto AdjustCost = [&CostKind](int Cost) {
|
||||
if (CostKind != TTI::TCK_RecipThroughput)
|
||||
return Cost == 0 ? 0 : 1;
|
||||
return Cost;
|
||||
};
|
||||
|
||||
EVT SrcTy = TLI->getValueType(DL, Src);
|
||||
EVT DstTy = TLI->getValueType(DL, Dst);
|
||||
|
||||
if (!SrcTy.isSimple() || !DstTy.isSimple())
|
||||
return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
|
||||
return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
|
||||
|
||||
static const TypeConversionCostTblEntry
|
||||
ConversionTbl[] = {
|
||||
@ -401,9 +408,9 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
|
||||
DstTy.getSimpleVT(),
|
||||
SrcTy.getSimpleVT()))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
|
||||
return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
|
||||
return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
|
||||
}
|
||||
|
||||
int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
|
||||
|
@ -173,6 +173,13 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
int ISD = TLI->InstructionOpcodeToISD(Opcode);
|
||||
assert(ISD && "Invalid opcode");
|
||||
|
||||
// TODO: Allow non-throughput costs that aren't binary.
|
||||
auto AdjustCost = [&CostKind](int Cost) {
|
||||
if (CostKind != TTI::TCK_RecipThroughput)
|
||||
return Cost == 0 ? 0 : 1;
|
||||
return Cost;
|
||||
};
|
||||
|
||||
// Single to/from double precision conversions.
|
||||
static const CostTblEntry NEONFltDblTbl[] = {
|
||||
// Vector fptrunc/fpext conversions.
|
||||
@ -185,14 +192,14 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
ISD == ISD::FP_EXTEND)) {
|
||||
std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Src);
|
||||
if (const auto *Entry = CostTableLookup(NEONFltDblTbl, ISD, LT.second))
|
||||
return LT.first * Entry->Cost;
|
||||
return AdjustCost(LT.first * Entry->Cost);
|
||||
}
|
||||
|
||||
EVT SrcTy = TLI->getValueType(DL, Src);
|
||||
EVT DstTy = TLI->getValueType(DL, Dst);
|
||||
|
||||
if (!SrcTy.isSimple() || !DstTy.isSimple())
|
||||
return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
|
||||
return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
|
||||
|
||||
// The extend of a load is free
|
||||
if (I && isa<LoadInst>(I->getOperand(0))) {
|
||||
@ -212,7 +219,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
};
|
||||
if (const auto *Entry = ConvertCostTableLookup(
|
||||
LoadConversionTbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
|
||||
static const TypeConversionCostTblEntry MVELoadConversionTbl[] = {
|
||||
{ISD::SIGN_EXTEND, MVT::v4i32, MVT::v4i16, 0},
|
||||
@ -226,7 +233,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
if (const auto *Entry =
|
||||
ConvertCostTableLookup(MVELoadConversionTbl, ISD,
|
||||
DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
}
|
||||
}
|
||||
|
||||
@ -253,7 +260,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
if (auto *Entry = ConvertCostTableLookup(NEONDoubleWidthTbl, UserISD,
|
||||
DstTy.getSimpleVT(),
|
||||
SrcTy.getSimpleVT())) {
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
}
|
||||
}
|
||||
|
||||
@ -347,7 +354,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
if (const auto *Entry = ConvertCostTableLookup(NEONVectorConversionTbl, ISD,
|
||||
DstTy.getSimpleVT(),
|
||||
SrcTy.getSimpleVT()))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
}
|
||||
|
||||
// Scalar float to integer conversions.
|
||||
@ -377,7 +384,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
if (const auto *Entry = ConvertCostTableLookup(NEONFloatConversionTbl, ISD,
|
||||
DstTy.getSimpleVT(),
|
||||
SrcTy.getSimpleVT()))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
}
|
||||
|
||||
// Scalar integer to float conversions.
|
||||
@ -408,7 +415,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
if (const auto *Entry = ConvertCostTableLookup(NEONIntegerConversionTbl,
|
||||
ISD, DstTy.getSimpleVT(),
|
||||
SrcTy.getSimpleVT()))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
}
|
||||
|
||||
// MVE extend costs, taken from codegen tests. i8->i16 or i16->i32 is one
|
||||
@ -433,7 +440,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
if (const auto *Entry = ConvertCostTableLookup(MVEVectorConversionTbl,
|
||||
ISD, DstTy.getSimpleVT(),
|
||||
SrcTy.getSimpleVT()))
|
||||
return Entry->Cost * ST->getMVEVectorCostFactor();
|
||||
return AdjustCost(Entry->Cost * ST->getMVEVectorCostFactor());
|
||||
}
|
||||
|
||||
// Scalar integer conversion costs.
|
||||
@ -452,13 +459,14 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
if (const auto *Entry = ConvertCostTableLookup(ARMIntegerConversionTbl, ISD,
|
||||
DstTy.getSimpleVT(),
|
||||
SrcTy.getSimpleVT()))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
}
|
||||
|
||||
int BaseCost = ST->hasMVEIntegerOps() && Src->isVectorTy()
|
||||
? ST->getMVEVectorCostFactor()
|
||||
: 1;
|
||||
return BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
|
||||
return AdjustCost(
|
||||
BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
|
||||
}
|
||||
|
||||
int ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,
|
||||
|
@ -263,7 +263,11 @@ unsigned HexagonTTIImpl::getCastInstrCost(unsigned Opcode, Type *DstTy,
|
||||
|
||||
std::pair<int, MVT> SrcLT = TLI.getTypeLegalizationCost(DL, SrcTy);
|
||||
std::pair<int, MVT> DstLT = TLI.getTypeLegalizationCost(DL, DstTy);
|
||||
return std::max(SrcLT.first, DstLT.first) + FloatFactor * (SrcN + DstN);
|
||||
unsigned Cost = std::max(SrcLT.first, DstLT.first) + FloatFactor * (SrcN + DstN);
|
||||
// TODO: Allow non-throughput costs that aren't binary.
|
||||
if (CostKind != TTI::TCK_RecipThroughput)
|
||||
return Cost == 0 ? 0 : 1;
|
||||
return Cost;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
@ -212,7 +212,8 @@ int PPCTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
|
||||
unsigned
|
||||
PPCTTIImpl::getUserCost(const User *U, ArrayRef<const Value *> Operands,
|
||||
TTI::TargetCostKind CostKind) {
|
||||
if (U->getType()->isVectorTy()) {
|
||||
// We already implement getCastInstrCost and perform the vector adjustment there.
|
||||
if (!isa<CastInst>(U) && U->getType()->isVectorTy()) {
|
||||
// Instructions that need to be split should cost more.
|
||||
std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, U->getType());
|
||||
return LT.first * BaseT::getUserCost(U, Operands, CostKind);
|
||||
@ -760,7 +761,11 @@ int PPCTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
assert(TLI->InstructionOpcodeToISD(Opcode) && "Invalid opcode");
|
||||
|
||||
int Cost = BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
|
||||
return vectorCostAdjustment(Cost, Opcode, Dst, Src);
|
||||
Cost = vectorCostAdjustment(Cost, Opcode, Dst, Src);
|
||||
// TODO: Allow non-throughput costs that aren't binary.
|
||||
if (CostKind != TTI::TCK_RecipThroughput)
|
||||
return Cost == 0 ? 0 : 1;
|
||||
return Cost;
|
||||
}
|
||||
|
||||
int PPCTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
|
||||
|
@ -691,6 +691,12 @@ getBoolVecToIntConversionCost(unsigned Opcode, Type *Dst,
|
||||
int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
TTI::TargetCostKind CostKind,
|
||||
const Instruction *I) {
|
||||
// FIXME: Can the logic below also be used for these cost kinds?
|
||||
if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency) {
|
||||
int BaseCost = BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
|
||||
return BaseCost == 0 ? BaseCost : 1;
|
||||
}
|
||||
|
||||
unsigned DstScalarBits = Dst->getScalarSizeInBits();
|
||||
unsigned SrcScalarBits = Src->getScalarSizeInBits();
|
||||
|
||||
|
@ -1368,6 +1368,13 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
int ISD = TLI->InstructionOpcodeToISD(Opcode);
|
||||
assert(ISD && "Invalid opcode");
|
||||
|
||||
// TODO: Allow non-throughput costs that aren't binary.
|
||||
auto AdjustCost = [&CostKind](int Cost) {
|
||||
if (CostKind != TTI::TCK_RecipThroughput)
|
||||
return Cost == 0 ? 0 : 1;
|
||||
return Cost;
|
||||
};
|
||||
|
||||
// FIXME: Need a better design of the cost table to handle non-simple types of
|
||||
// potential massive combinations (elem_num x src_type x dst_type).
|
||||
|
||||
@ -1969,7 +1976,7 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
if (ST->hasSSE2() && !ST->hasAVX()) {
|
||||
if (const auto *Entry = ConvertCostTableLookup(SSE2ConversionTbl, ISD,
|
||||
LTDest.second, LTSrc.second))
|
||||
return LTSrc.first * Entry->Cost;
|
||||
return AdjustCost(LTSrc.first * Entry->Cost);
|
||||
}
|
||||
|
||||
EVT SrcTy = TLI->getValueType(DL, Src);
|
||||
@ -1977,7 +1984,7 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
|
||||
// The function getSimpleVT only handles simple value types.
|
||||
if (!SrcTy.isSimple() || !DstTy.isSimple())
|
||||
return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind);
|
||||
return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind));
|
||||
|
||||
MVT SimpleSrcTy = SrcTy.getSimpleVT();
|
||||
MVT SimpleDstTy = DstTy.getSimpleVT();
|
||||
@ -1986,59 +1993,59 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
|
||||
if (ST->hasBWI())
|
||||
if (const auto *Entry = ConvertCostTableLookup(AVX512BWConversionTbl, ISD,
|
||||
SimpleDstTy, SimpleSrcTy))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
|
||||
if (ST->hasDQI())
|
||||
if (const auto *Entry = ConvertCostTableLookup(AVX512DQConversionTbl, ISD,
|
||||
SimpleDstTy, SimpleSrcTy))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
|
||||
if (ST->hasAVX512())
|
||||
if (const auto *Entry = ConvertCostTableLookup(AVX512FConversionTbl, ISD,
|
||||
SimpleDstTy, SimpleSrcTy))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
}
|
||||
|
||||
if (ST->hasBWI())
|
||||
if (const auto *Entry = ConvertCostTableLookup(AVX512BWVLConversionTbl, ISD,
|
||||
SimpleDstTy, SimpleSrcTy))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
|
||||
if (ST->hasDQI())
|
||||
if (const auto *Entry = ConvertCostTableLookup(AVX512DQVLConversionTbl, ISD,
|
||||
SimpleDstTy, SimpleSrcTy))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
|
||||
if (ST->hasAVX512())
|
||||
if (const auto *Entry = ConvertCostTableLookup(AVX512VLConversionTbl, ISD,
|
||||
SimpleDstTy, SimpleSrcTy))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
|
||||
if (ST->hasAVX2()) {
|
||||
if (const auto *Entry = ConvertCostTableLookup(AVX2ConversionTbl, ISD,
|
||||
SimpleDstTy, SimpleSrcTy))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
}
|
||||
|
||||
if (ST->hasAVX()) {
|
||||
if (const auto *Entry = ConvertCostTableLookup(AVXConversionTbl, ISD,
|
||||
SimpleDstTy, SimpleSrcTy))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
}
|
||||
|
||||
if (ST->hasSSE41()) {
|
||||
if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD,
|
||||
SimpleDstTy, SimpleSrcTy))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
}
|
||||
|
||||
if (ST->hasSSE2()) {
|
||||
if (const auto *Entry = ConvertCostTableLookup(SSE2ConversionTbl, ISD,
|
||||
SimpleDstTy, SimpleSrcTy))
|
||||
return Entry->Cost;
|
||||
return AdjustCost(Entry->Cost);
|
||||
}
|
||||
|
||||
return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
|
||||
return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
|
||||
}
|
||||
|
||||
int X86TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
|
||||
|
Loading…
Reference in New Issue
Block a user