mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2025-01-31 20:51:52 +01:00
[NVPTX] Added support for half-precision floating point.
Only scalar half-precision operations are supported at the moment. - Adds general support for 'half' type in NVPTX. - fp16 math operations are supported on sm_53+ GPUs only (can be disabled with --nvptx-no-f16-math). - Type conversions to/from fp16 are supported on all GPU variants. - On GPU variants that do not have full fp16 support (or if it's disabled), fp16 operations are promoted to fp32 and results are converted back to fp16 for storage. Differential Revision: https://reviews.llvm.org/D28540 llvm-svn: 291956
This commit is contained in:
parent
fb2ba32d19
commit
f89a861f79
@ -61,6 +61,9 @@ void NVPTXInstPrinter::printRegName(raw_ostream &OS, unsigned RegNo) const {
|
||||
case 6:
|
||||
OS << "%fd";
|
||||
break;
|
||||
case 7:
|
||||
OS << "%h";
|
||||
break;
|
||||
}
|
||||
|
||||
unsigned VReg = RegNo & 0x0FFFFFFF;
|
||||
@ -247,8 +250,12 @@ void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum,
|
||||
O << "s";
|
||||
else if (Imm == NVPTX::PTXLdStInstCode::Unsigned)
|
||||
O << "u";
|
||||
else
|
||||
else if (Imm == NVPTX::PTXLdStInstCode::Untyped)
|
||||
O << "b";
|
||||
else if (Imm == NVPTX::PTXLdStInstCode::Float)
|
||||
O << "f";
|
||||
else
|
||||
llvm_unreachable("Unknown register type");
|
||||
} else if (!strcmp(Modifier, "vec")) {
|
||||
if (Imm == NVPTX::PTXLdStInstCode::V2)
|
||||
O << ".v2";
|
||||
|
@ -108,7 +108,8 @@ enum AddressSpace {
|
||||
enum FromType {
|
||||
Unsigned = 0,
|
||||
Signed,
|
||||
Float
|
||||
Float,
|
||||
Untyped
|
||||
};
|
||||
enum VecType {
|
||||
Scalar = 1,
|
||||
|
@ -320,6 +320,10 @@ bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
|
||||
|
||||
switch (Cnt->getType()->getTypeID()) {
|
||||
default: report_fatal_error("Unsupported FP type"); break;
|
||||
case Type::HalfTyID:
|
||||
MCOp = MCOperand::createExpr(
|
||||
NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
|
||||
break;
|
||||
case Type::FloatTyID:
|
||||
MCOp = MCOperand::createExpr(
|
||||
NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
|
||||
@ -357,6 +361,8 @@ unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
|
||||
Ret = (5 << 28);
|
||||
} else if (RC == &NVPTX::Float64RegsRegClass) {
|
||||
Ret = (6 << 28);
|
||||
} else if (RC == &NVPTX::Float16RegsRegClass) {
|
||||
Ret = (7 << 28);
|
||||
} else {
|
||||
report_fatal_error("Bad register class");
|
||||
}
|
||||
@ -396,12 +402,15 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
|
||||
unsigned size = 0;
|
||||
if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
|
||||
size = ITy->getBitWidth();
|
||||
if (size < 32)
|
||||
size = 32;
|
||||
} else {
|
||||
assert(Ty->isFloatingPointTy() && "Floating point type expected here");
|
||||
size = Ty->getPrimitiveSizeInBits();
|
||||
}
|
||||
// PTX ABI requires all scalar return values to be at least 32
|
||||
// bits in size. fp16 normally uses .b16 as its storage type in
|
||||
// PTX, so its size must be adjusted here, too.
|
||||
if (size < 32)
|
||||
size = 32;
|
||||
|
||||
O << ".param .b" << size << " func_retval0";
|
||||
} else if (isa<PointerType>(Ty)) {
|
||||
@ -1376,6 +1385,9 @@ NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Type::HalfTyID:
|
||||
// fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
|
||||
return "b16";
|
||||
case Type::FloatTyID:
|
||||
return "f32";
|
||||
case Type::DoubleTyID:
|
||||
@ -1601,6 +1613,11 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
|
||||
sz = 32;
|
||||
} else if (isa<PointerType>(Ty))
|
||||
sz = thePointerTy.getSizeInBits();
|
||||
else if (Ty->isHalfTy())
|
||||
// PTX ABI requires all scalar parameters to be at least 32
|
||||
// bits in size. fp16 normally uses .b16 as its storage type
|
||||
// in PTX, so its size must be adjusted here, too.
|
||||
sz = 32;
|
||||
else
|
||||
sz = Ty->getPrimitiveSizeInBits();
|
||||
if (isABI)
|
||||
|
@ -42,7 +42,6 @@ FtzEnabled("nvptx-f32ftz", cl::ZeroOrMore, cl::Hidden,
|
||||
cl::desc("NVPTX Specific: Flush f32 subnormals to sign-preserving zero."),
|
||||
cl::init(false));
|
||||
|
||||
|
||||
/// createNVPTXISelDag - This pass converts a legalized DAG into a
|
||||
/// NVPTX-specific DAG, ready for instruction scheduling.
|
||||
FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
|
||||
@ -520,6 +519,10 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
|
||||
case ISD::ADDRSPACECAST:
|
||||
SelectAddrSpaceCast(N);
|
||||
return;
|
||||
case ISD::ConstantFP:
|
||||
if (tryConstantFP16(N))
|
||||
return;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -541,6 +544,19 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
|
||||
}
|
||||
}
|
||||
|
||||
// There's no way to specify FP16 immediates in .f16 ops, so we have to
|
||||
// load them into an .f16 register first.
|
||||
bool NVPTXDAGToDAGISel::tryConstantFP16(SDNode *N) {
|
||||
if (N->getValueType(0) != MVT::f16)
|
||||
return false;
|
||||
SDValue Val = CurDAG->getTargetConstantFP(
|
||||
cast<ConstantFPSDNode>(N)->getValueAPF(), SDLoc(N), MVT::f16);
|
||||
SDNode *LoadConstF16 =
|
||||
CurDAG->getMachineNode(NVPTX::LOAD_CONST_F16, SDLoc(N), MVT::f16, Val);
|
||||
ReplaceNode(N, LoadConstF16);
|
||||
return true;
|
||||
}
|
||||
|
||||
static unsigned int getCodeAddrSpace(MemSDNode *N) {
|
||||
const Value *Src = N->getMemOperand()->getValue();
|
||||
|
||||
@ -740,7 +756,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
|
||||
if ((LD->getExtensionType() == ISD::SEXTLOAD))
|
||||
fromType = NVPTX::PTXLdStInstCode::Signed;
|
||||
else if (ScalarVT.isFloatingPoint())
|
||||
fromType = NVPTX::PTXLdStInstCode::Float;
|
||||
// f16 uses .b16 as its storage type.
|
||||
fromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
|
||||
: NVPTX::PTXLdStInstCode::Float;
|
||||
else
|
||||
fromType = NVPTX::PTXLdStInstCode::Unsigned;
|
||||
|
||||
@ -766,6 +784,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
|
||||
case MVT::i64:
|
||||
Opcode = NVPTX::LD_i64_avar;
|
||||
break;
|
||||
case MVT::f16:
|
||||
Opcode = NVPTX::LD_f16_avar;
|
||||
break;
|
||||
case MVT::f32:
|
||||
Opcode = NVPTX::LD_f32_avar;
|
||||
break;
|
||||
@ -794,6 +815,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
|
||||
case MVT::i64:
|
||||
Opcode = NVPTX::LD_i64_asi;
|
||||
break;
|
||||
case MVT::f16:
|
||||
Opcode = NVPTX::LD_f16_asi;
|
||||
break;
|
||||
case MVT::f32:
|
||||
Opcode = NVPTX::LD_f32_asi;
|
||||
break;
|
||||
@ -823,6 +847,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
|
||||
case MVT::i64:
|
||||
Opcode = NVPTX::LD_i64_ari_64;
|
||||
break;
|
||||
case MVT::f16:
|
||||
Opcode = NVPTX::LD_f16_ari_64;
|
||||
break;
|
||||
case MVT::f32:
|
||||
Opcode = NVPTX::LD_f32_ari_64;
|
||||
break;
|
||||
@ -846,6 +873,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
|
||||
case MVT::i64:
|
||||
Opcode = NVPTX::LD_i64_ari;
|
||||
break;
|
||||
case MVT::f16:
|
||||
Opcode = NVPTX::LD_f16_ari;
|
||||
break;
|
||||
case MVT::f32:
|
||||
Opcode = NVPTX::LD_f32_ari;
|
||||
break;
|
||||
@ -875,6 +905,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
|
||||
case MVT::i64:
|
||||
Opcode = NVPTX::LD_i64_areg_64;
|
||||
break;
|
||||
case MVT::f16:
|
||||
Opcode = NVPTX::LD_f16_areg_64;
|
||||
break;
|
||||
case MVT::f32:
|
||||
Opcode = NVPTX::LD_f32_areg_64;
|
||||
break;
|
||||
@ -898,6 +931,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
|
||||
case MVT::i64:
|
||||
Opcode = NVPTX::LD_i64_areg;
|
||||
break;
|
||||
case MVT::f16:
|
||||
Opcode = NVPTX::LD_f16_areg;
|
||||
break;
|
||||
case MVT::f32:
|
||||
Opcode = NVPTX::LD_f32_areg;
|
||||
break;
|
||||
@ -2173,7 +2209,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
|
||||
unsigned toTypeWidth = ScalarVT.getSizeInBits();
|
||||
unsigned int toType;
|
||||
if (ScalarVT.isFloatingPoint())
|
||||
toType = NVPTX::PTXLdStInstCode::Float;
|
||||
// f16 uses .b16 as its storage type.
|
||||
toType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
|
||||
: NVPTX::PTXLdStInstCode::Float;
|
||||
else
|
||||
toType = NVPTX::PTXLdStInstCode::Unsigned;
|
||||
|
||||
@ -2200,6 +2238,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
|
||||
case MVT::i64:
|
||||
Opcode = NVPTX::ST_i64_avar;
|
||||
break;
|
||||
case MVT::f16:
|
||||
Opcode = NVPTX::ST_f16_avar;
|
||||
break;
|
||||
case MVT::f32:
|
||||
Opcode = NVPTX::ST_f32_avar;
|
||||
break;
|
||||
@ -2229,6 +2270,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
|
||||
case MVT::i64:
|
||||
Opcode = NVPTX::ST_i64_asi;
|
||||
break;
|
||||
case MVT::f16:
|
||||
Opcode = NVPTX::ST_f16_asi;
|
||||
break;
|
||||
case MVT::f32:
|
||||
Opcode = NVPTX::ST_f32_asi;
|
||||
break;
|
||||
@ -2259,6 +2303,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
|
||||
case MVT::i64:
|
||||
Opcode = NVPTX::ST_i64_ari_64;
|
||||
break;
|
||||
case MVT::f16:
|
||||
Opcode = NVPTX::ST_f16_ari_64;
|
||||
break;
|
||||
case MVT::f32:
|
||||
Opcode = NVPTX::ST_f32_ari_64;
|
||||
break;
|
||||
@ -2282,6 +2329,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
|
||||
case MVT::i64:
|
||||
Opcode = NVPTX::ST_i64_ari;
|
||||
break;
|
||||
case MVT::f16:
|
||||
Opcode = NVPTX::ST_f16_ari;
|
||||
break;
|
||||
case MVT::f32:
|
||||
Opcode = NVPTX::ST_f32_ari;
|
||||
break;
|
||||
@ -2312,6 +2362,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
|
||||
case MVT::i64:
|
||||
Opcode = NVPTX::ST_i64_areg_64;
|
||||
break;
|
||||
case MVT::f16:
|
||||
Opcode = NVPTX::ST_f16_areg_64;
|
||||
break;
|
||||
case MVT::f32:
|
||||
Opcode = NVPTX::ST_f32_areg_64;
|
||||
break;
|
||||
@ -2335,6 +2388,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
|
||||
case MVT::i64:
|
||||
Opcode = NVPTX::ST_i64_areg;
|
||||
break;
|
||||
case MVT::f16:
|
||||
Opcode = NVPTX::ST_f16_areg;
|
||||
break;
|
||||
case MVT::f32:
|
||||
Opcode = NVPTX::ST_f32_areg;
|
||||
break;
|
||||
@ -2786,6 +2842,9 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
|
||||
case MVT::i64:
|
||||
Opc = NVPTX::LoadParamMemI64;
|
||||
break;
|
||||
case MVT::f16:
|
||||
Opc = NVPTX::LoadParamMemF16;
|
||||
break;
|
||||
case MVT::f32:
|
||||
Opc = NVPTX::LoadParamMemF32;
|
||||
break;
|
||||
@ -2921,6 +2980,9 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
|
||||
case MVT::i64:
|
||||
Opcode = NVPTX::StoreRetvalI64;
|
||||
break;
|
||||
case MVT::f16:
|
||||
Opcode = NVPTX::StoreRetvalF16;
|
||||
break;
|
||||
case MVT::f32:
|
||||
Opcode = NVPTX::StoreRetvalF32;
|
||||
break;
|
||||
@ -3054,6 +3116,9 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
|
||||
case MVT::i64:
|
||||
Opcode = NVPTX::StoreParamI64;
|
||||
break;
|
||||
case MVT::f16:
|
||||
Opcode = NVPTX::StoreParamF16;
|
||||
break;
|
||||
case MVT::f32:
|
||||
Opcode = NVPTX::StoreParamF32;
|
||||
break;
|
||||
|
@ -70,6 +70,7 @@ private:
|
||||
bool tryTextureIntrinsic(SDNode *N);
|
||||
bool trySurfaceIntrinsic(SDNode *N);
|
||||
bool tryBFE(SDNode *N);
|
||||
bool tryConstantFP16(SDNode *N);
|
||||
|
||||
inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
|
||||
return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
|
||||
|
@ -164,8 +164,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
|
||||
addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
|
||||
addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
|
||||
addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
|
||||
addRegisterClass(MVT::f16, &NVPTX::Float16RegsRegClass);
|
||||
|
||||
setOperationAction(ISD::SETCC, MVT::f16,
|
||||
STI.allowFP16Math() ? Legal : Promote);
|
||||
|
||||
// Operations not directly supported by NVPTX.
|
||||
setOperationAction(ISD::SELECT_CC, MVT::f16,
|
||||
STI.allowFP16Math() ? Expand : Promote);
|
||||
setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
|
||||
setOperationAction(ISD::SELECT_CC, MVT::f64, Expand);
|
||||
setOperationAction(ISD::SELECT_CC, MVT::i1, Expand);
|
||||
@ -173,6 +179,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
|
||||
setOperationAction(ISD::SELECT_CC, MVT::i16, Expand);
|
||||
setOperationAction(ISD::SELECT_CC, MVT::i32, Expand);
|
||||
setOperationAction(ISD::SELECT_CC, MVT::i64, Expand);
|
||||
setOperationAction(ISD::BR_CC, MVT::f16,
|
||||
STI.allowFP16Math() ? Expand : Promote);
|
||||
setOperationAction(ISD::BR_CC, MVT::f32, Expand);
|
||||
setOperationAction(ISD::BR_CC, MVT::f64, Expand);
|
||||
setOperationAction(ISD::BR_CC, MVT::i1, Expand);
|
||||
@ -259,6 +267,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
|
||||
// This is legal in NVPTX
|
||||
setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
|
||||
setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
|
||||
setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
|
||||
|
||||
// TRAP can be lowered to PTX trap
|
||||
setOperationAction(ISD::TRAP, MVT::Other, Legal);
|
||||
@ -305,18 +314,36 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
|
||||
setTargetDAGCombine(ISD::SREM);
|
||||
setTargetDAGCombine(ISD::UREM);
|
||||
|
||||
if (!STI.allowFP16Math()) {
|
||||
// Promote fp16 arithmetic if fp16 hardware isn't available or the
|
||||
// user passed --nvptx-no-fp16-math. The flag is useful because,
|
||||
// although sm_53+ GPUs have some sort of FP16 support in
|
||||
// hardware, only sm_53 and sm_60 have full implementation. Others
|
||||
// only have token amount of hardware and are likely to run faster
|
||||
// by using fp32 units instead.
|
||||
setOperationAction(ISD::FADD, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FMUL, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FSUB, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FMA, MVT::f16, Promote);
|
||||
}
|
||||
|
||||
// Library functions. These default to Expand, but we have instructions
|
||||
// for them.
|
||||
setOperationAction(ISD::FCEIL, MVT::f16, Legal);
|
||||
setOperationAction(ISD::FCEIL, MVT::f32, Legal);
|
||||
setOperationAction(ISD::FCEIL, MVT::f64, Legal);
|
||||
setOperationAction(ISD::FFLOOR, MVT::f16, Legal);
|
||||
setOperationAction(ISD::FFLOOR, MVT::f32, Legal);
|
||||
setOperationAction(ISD::FFLOOR, MVT::f64, Legal);
|
||||
setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal);
|
||||
setOperationAction(ISD::FNEARBYINT, MVT::f64, Legal);
|
||||
setOperationAction(ISD::FRINT, MVT::f16, Legal);
|
||||
setOperationAction(ISD::FRINT, MVT::f32, Legal);
|
||||
setOperationAction(ISD::FRINT, MVT::f64, Legal);
|
||||
setOperationAction(ISD::FROUND, MVT::f16, Legal);
|
||||
setOperationAction(ISD::FROUND, MVT::f32, Legal);
|
||||
setOperationAction(ISD::FROUND, MVT::f64, Legal);
|
||||
setOperationAction(ISD::FTRUNC, MVT::f16, Legal);
|
||||
setOperationAction(ISD::FTRUNC, MVT::f32, Legal);
|
||||
setOperationAction(ISD::FTRUNC, MVT::f64, Legal);
|
||||
setOperationAction(ISD::FMINNUM, MVT::f32, Legal);
|
||||
@ -324,6 +351,24 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
|
||||
setOperationAction(ISD::FMAXNUM, MVT::f32, Legal);
|
||||
setOperationAction(ISD::FMAXNUM, MVT::f64, Legal);
|
||||
|
||||
// 'Expand' implements FCOPYSIGN without calling an external library.
|
||||
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
|
||||
setOperationAction(ISD::FCOPYSIGN, MVT::f32, Expand);
|
||||
setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand);
|
||||
|
||||
// FP16 does not support these nodes in hardware, but we can perform
|
||||
// these ops using single-precision hardware.
|
||||
setOperationAction(ISD::FDIV, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FREM, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FSQRT, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FSIN, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FCOS, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FABS, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FMINNUM, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FMAXNUM, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FMINNAN, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FMAXNAN, MVT::f16, Promote);
|
||||
|
||||
// No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
|
||||
// No FPOW or FREM in PTX.
|
||||
|
||||
@ -967,19 +1012,21 @@ std::string NVPTXTargetLowering::getPrototype(
|
||||
unsigned size = 0;
|
||||
if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
|
||||
size = ITy->getBitWidth();
|
||||
if (size < 32)
|
||||
size = 32;
|
||||
} else {
|
||||
assert(retTy->isFloatingPointTy() &&
|
||||
"Floating point type expected here");
|
||||
size = retTy->getPrimitiveSizeInBits();
|
||||
}
|
||||
// PTX ABI requires all scalar return values to be at least 32
|
||||
// bits in size. fp16 normally uses .b16 as its storage type in
|
||||
// PTX, so its size must be adjusted here, too.
|
||||
if (size < 32)
|
||||
size = 32;
|
||||
|
||||
O << ".param .b" << size << " _";
|
||||
} else if (isa<PointerType>(retTy)) {
|
||||
O << ".param .b" << PtrVT.getSizeInBits() << " _";
|
||||
} else if ((retTy->getTypeID() == Type::StructTyID) ||
|
||||
isa<VectorType>(retTy)) {
|
||||
} else if (retTy->isAggregateType() || retTy->isVectorTy()) {
|
||||
auto &DL = CS->getCalledFunction()->getParent()->getDataLayout();
|
||||
O << ".param .align " << retAlignment << " .b8 _["
|
||||
<< DL.getTypeAllocSize(retTy) << "]";
|
||||
@ -1018,7 +1065,7 @@ std::string NVPTXTargetLowering::getPrototype(
|
||||
OIdx += len - 1;
|
||||
continue;
|
||||
}
|
||||
// i8 types in IR will be i16 types in SDAG
|
||||
// i8 types in IR will be i16 types in SDAG
|
||||
assert((getValueType(DL, Ty) == Outs[OIdx].VT ||
|
||||
(getValueType(DL, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
|
||||
"type mismatch between callee prototype and arguments");
|
||||
@ -1028,8 +1075,13 @@ std::string NVPTXTargetLowering::getPrototype(
|
||||
sz = cast<IntegerType>(Ty)->getBitWidth();
|
||||
if (sz < 32)
|
||||
sz = 32;
|
||||
} else if (isa<PointerType>(Ty))
|
||||
} else if (isa<PointerType>(Ty)) {
|
||||
sz = PtrVT.getSizeInBits();
|
||||
} else if (Ty->isHalfTy())
|
||||
// PTX ABI requires all scalar parameters to be at least 32
|
||||
// bits in size. fp16 normally uses .b16 as its storage type
|
||||
// in PTX, so its size must be adjusted here, too.
|
||||
sz = 32;
|
||||
else
|
||||
sz = Ty->getPrimitiveSizeInBits();
|
||||
O << ".param .b" << sz << " ";
|
||||
@ -1340,7 +1392,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
|
||||
needExtend = true;
|
||||
if (sz < 32)
|
||||
sz = 32;
|
||||
}
|
||||
} else if (VT.isFloatingPoint() && sz < 32)
|
||||
// PTX ABI requires all scalar parameters to be at least 32
|
||||
// bits in size. fp16 normally uses .b16 as its storage type
|
||||
// in PTX, so its size must be adjusted here, too.
|
||||
sz = 32;
|
||||
SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
|
||||
SDValue DeclareParamOps[] = { Chain,
|
||||
DAG.getConstant(paramCount, dl, MVT::i32),
|
||||
@ -1952,12 +2008,15 @@ SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
|
||||
|
||||
SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
|
||||
EVT ValVT = Op.getOperand(1).getValueType();
|
||||
if (ValVT == MVT::i1)
|
||||
switch (ValVT.getSimpleVT().SimpleTy) {
|
||||
case MVT::i1:
|
||||
return LowerSTOREi1(Op, DAG);
|
||||
else if (ValVT.isVector())
|
||||
return LowerSTOREVector(Op, DAG);
|
||||
else
|
||||
return SDValue();
|
||||
default:
|
||||
if (ValVT.isVector())
|
||||
return LowerSTOREVector(Op, DAG);
|
||||
else
|
||||
return SDValue();
|
||||
}
|
||||
}
|
||||
|
||||
SDValue
|
||||
@ -2557,8 +2616,9 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
|
||||
// specifically not for aggregates.
|
||||
TmpVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, TmpVal);
|
||||
TheStoreType = MVT::i32;
|
||||
}
|
||||
else if (TmpVal.getValueSizeInBits() < 16)
|
||||
} else if (RetTy->isHalfTy()) {
|
||||
TheStoreType = MVT::f16;
|
||||
} else if (TmpVal.getValueSizeInBits() < 16)
|
||||
TmpVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, TmpVal);
|
||||
|
||||
SDValue Ops[] = {
|
||||
|
@ -528,6 +528,7 @@ private:
|
||||
|
||||
SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue LowerSTOREf16(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const;
|
||||
|
||||
SDValue LowerShiftRightParts(SDValue Op, SelectionDAG &DAG) const;
|
||||
|
@ -52,6 +52,9 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
|
||||
} else if (DestRC == &NVPTX::Int64RegsRegClass) {
|
||||
Op = (SrcRC == &NVPTX::Int64RegsRegClass ? NVPTX::IMOV64rr
|
||||
: NVPTX::BITCONVERT_64_F2I);
|
||||
} else if (DestRC == &NVPTX::Float16RegsRegClass) {
|
||||
Op = (SrcRC == &NVPTX::Float16RegsRegClass ? NVPTX::FMOV16rr
|
||||
: NVPTX::BITCONVERT_16_I2F);
|
||||
} else if (DestRC == &NVPTX::Float32RegsRegClass) {
|
||||
Op = (SrcRC == &NVPTX::Float32RegsRegClass ? NVPTX::FMOV32rr
|
||||
: NVPTX::BITCONVERT_32_I2F);
|
||||
|
@ -18,6 +18,10 @@ let hasSideEffects = 0 in {
|
||||
def NOP : NVPTXInst<(outs), (ins), "", []>;
|
||||
}
|
||||
|
||||
let OperandType = "OPERAND_IMMEDIATE" in {
|
||||
def f16imm : Operand<f16>;
|
||||
}
|
||||
|
||||
// List of vector specific properties
|
||||
def isVecLD : VecInstTypeEnum<1>;
|
||||
def isVecST : VecInstTypeEnum<2>;
|
||||
@ -149,6 +153,7 @@ def true : Predicate<"true">;
|
||||
|
||||
def hasPTX31 : Predicate<"Subtarget->getPTXVersion() >= 31">;
|
||||
|
||||
def useFP16Math: Predicate<"Subtarget->allowFP16Math()">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Some Common Instruction Class Templates
|
||||
@ -240,11 +245,11 @@ multiclass F3<string OpcStr, SDNode OpNode> {
|
||||
[(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>;
|
||||
}
|
||||
|
||||
// Template for instructions which take three fp64 or fp32 args. The
|
||||
// Template for instructions which take three FP args. The
|
||||
// instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
|
||||
//
|
||||
// Also defines ftz (flush subnormal inputs and results to sign-preserving
|
||||
// zero) variants for fp32 functions.
|
||||
// zero) variants for fp32/fp16 functions.
|
||||
//
|
||||
// This multiclass should be used for nodes that can be folded to make fma ops.
|
||||
// In this case, we use the ".rn" variant when FMA is disabled, as this behaves
|
||||
@ -287,6 +292,19 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
|
||||
[(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>,
|
||||
Requires<[allowFMA]>;
|
||||
|
||||
def f16rr_ftz :
|
||||
NVPTXInst<(outs Float16Regs:$dst),
|
||||
(ins Float16Regs:$a, Float16Regs:$b),
|
||||
!strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"),
|
||||
[(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
|
||||
Requires<[useFP16Math, allowFMA, doF32FTZ]>;
|
||||
def f16rr :
|
||||
NVPTXInst<(outs Float16Regs:$dst),
|
||||
(ins Float16Regs:$a, Float16Regs:$b),
|
||||
!strconcat(OpcStr, ".f16 \t$dst, $a, $b;"),
|
||||
[(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
|
||||
Requires<[useFP16Math, allowFMA]>;
|
||||
|
||||
// These have strange names so we don't perturb existing mir tests.
|
||||
def _rnf64rr :
|
||||
NVPTXInst<(outs Float64Regs:$dst),
|
||||
@ -324,6 +342,18 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
|
||||
!strconcat(OpcStr, ".rn.f32 \t$dst, $a, $b;"),
|
||||
[(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>,
|
||||
Requires<[noFMA]>;
|
||||
def _rnf16rr_ftz :
|
||||
NVPTXInst<(outs Float16Regs:$dst),
|
||||
(ins Float16Regs:$a, Float16Regs:$b),
|
||||
!strconcat(OpcStr, ".rn.ftz.f16 \t$dst, $a, $b;"),
|
||||
[(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
|
||||
Requires<[useFP16Math, noFMA, doF32FTZ]>;
|
||||
def _rnf16rr :
|
||||
NVPTXInst<(outs Float16Regs:$dst),
|
||||
(ins Float16Regs:$a, Float16Regs:$b),
|
||||
!strconcat(OpcStr, ".rn.f16 \t$dst, $a, $b;"),
|
||||
[(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
|
||||
Requires<[useFP16Math, noFMA]>;
|
||||
}
|
||||
|
||||
// Template for operations which take two f32 or f64 operands. Provides three
|
||||
@ -375,11 +405,6 @@ let hasSideEffects = 0 in {
|
||||
(ins Int16Regs:$src, CvtMode:$mode),
|
||||
!strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
|
||||
FromName, ".u16\t$dst, $src;"), []>;
|
||||
def _f16 :
|
||||
NVPTXInst<(outs RC:$dst),
|
||||
(ins Int16Regs:$src, CvtMode:$mode),
|
||||
!strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
|
||||
FromName, ".f16\t$dst, $src;"), []>;
|
||||
def _s32 :
|
||||
NVPTXInst<(outs RC:$dst),
|
||||
(ins Int32Regs:$src, CvtMode:$mode),
|
||||
@ -400,6 +425,11 @@ let hasSideEffects = 0 in {
|
||||
(ins Int64Regs:$src, CvtMode:$mode),
|
||||
!strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
|
||||
FromName, ".u64\t$dst, $src;"), []>;
|
||||
def _f16 :
|
||||
NVPTXInst<(outs RC:$dst),
|
||||
(ins Float16Regs:$src, CvtMode:$mode),
|
||||
!strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
|
||||
FromName, ".f16\t$dst, $src;"), []>;
|
||||
def _f32 :
|
||||
NVPTXInst<(outs RC:$dst),
|
||||
(ins Float32Regs:$src, CvtMode:$mode),
|
||||
@ -417,11 +447,11 @@ let hasSideEffects = 0 in {
|
||||
defm CVT_u8 : CVT_FROM_ALL<"u8", Int16Regs>;
|
||||
defm CVT_s16 : CVT_FROM_ALL<"s16", Int16Regs>;
|
||||
defm CVT_u16 : CVT_FROM_ALL<"u16", Int16Regs>;
|
||||
defm CVT_f16 : CVT_FROM_ALL<"f16", Int16Regs>;
|
||||
defm CVT_s32 : CVT_FROM_ALL<"s32", Int32Regs>;
|
||||
defm CVT_u32 : CVT_FROM_ALL<"u32", Int32Regs>;
|
||||
defm CVT_s64 : CVT_FROM_ALL<"s64", Int64Regs>;
|
||||
defm CVT_u64 : CVT_FROM_ALL<"u64", Int64Regs>;
|
||||
defm CVT_f16 : CVT_FROM_ALL<"f16", Float16Regs>;
|
||||
defm CVT_f32 : CVT_FROM_ALL<"f32", Float32Regs>;
|
||||
defm CVT_f64 : CVT_FROM_ALL<"f64", Float64Regs>;
|
||||
|
||||
@ -749,6 +779,15 @@ def DoubleConst1 : PatLeaf<(fpimm), [{
|
||||
N->getValueAPF().convertToDouble() == 1.0;
|
||||
}]>;
|
||||
|
||||
// Loads FP16 constant into a register.
|
||||
//
|
||||
// ptxas does not have hex representation for fp16, so we can't use
|
||||
// fp16 immediate values in .f16 instructions. Instead we have to load
|
||||
// the constant into a register using mov.b16.
|
||||
def LOAD_CONST_F16 :
|
||||
NVPTXInst<(outs Float16Regs:$dst), (ins f16imm:$a),
|
||||
"mov.b16 \t$dst, $a;", []>;
|
||||
|
||||
defm FADD : F3_fma_component<"add", fadd>;
|
||||
defm FSUB : F3_fma_component<"sub", fsub>;
|
||||
defm FMUL : F3_fma_component<"mul", fmul>;
|
||||
@ -943,6 +982,15 @@ multiclass FMA<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred>
|
||||
Requires<[Pred]>;
|
||||
}
|
||||
|
||||
multiclass FMA_F16<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred> {
|
||||
def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
|
||||
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
|
||||
[(set RC:$dst, (fma RC:$a, RC:$b, RC:$c))]>,
|
||||
Requires<[useFP16Math, Pred]>;
|
||||
}
|
||||
|
||||
defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", Float16Regs, f16imm, doF32FTZ>;
|
||||
defm FMA16 : FMA_F16<"fma.rn.f16", Float16Regs, f16imm, true>;
|
||||
defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
|
||||
defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, true>;
|
||||
defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, true>;
|
||||
@ -1320,6 +1368,11 @@ defm SETP_s64 : SETP<"s64", Int64Regs, i64imm>;
|
||||
defm SETP_u64 : SETP<"u64", Int64Regs, i64imm>;
|
||||
defm SETP_f32 : SETP<"f32", Float32Regs, f32imm>;
|
||||
defm SETP_f64 : SETP<"f64", Float64Regs, f64imm>;
|
||||
def SETP_f16rr :
|
||||
NVPTXInst<(outs Int1Regs:$dst),
|
||||
(ins Float16Regs:$a, Float16Regs:$b, CmpMode:$cmp),
|
||||
"setp${cmp:base}${cmp:ftz}.f16 $dst, $a, $b;",
|
||||
[]>, Requires<[useFP16Math]>;
|
||||
|
||||
// FIXME: This doesn't appear to be correct. The "set" mnemonic has the form
|
||||
// "set.CmpOp{.ftz}.dtype.stype", where dtype is the type of the destination
|
||||
@ -1348,6 +1401,7 @@ defm SET_u32 : SET<"u32", Int32Regs, i32imm>;
|
||||
defm SET_b64 : SET<"b64", Int64Regs, i64imm>;
|
||||
defm SET_s64 : SET<"s64", Int64Regs, i64imm>;
|
||||
defm SET_u64 : SET<"u64", Int64Regs, i64imm>;
|
||||
defm SET_f16 : SET<"f16", Float16Regs, f16imm>;
|
||||
defm SET_f32 : SET<"f32", Float32Regs, f32imm>;
|
||||
defm SET_f64 : SET<"f64", Float64Regs, f64imm>;
|
||||
|
||||
@ -1411,6 +1465,7 @@ defm SELP_u32 : SELP<"u32", Int32Regs, i32imm>;
|
||||
defm SELP_b64 : SELP_PATTERN<"b64", Int64Regs, i64imm, imm>;
|
||||
defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>;
|
||||
defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>;
|
||||
defm SELP_f16 : SELP_PATTERN<"b16", Float16Regs, f16imm, fpimm>;
|
||||
defm SELP_f32 : SELP_PATTERN<"f32", Float32Regs, f32imm, fpimm>;
|
||||
defm SELP_f64 : SELP_PATTERN<"f64", Float64Regs, f64imm, fpimm>;
|
||||
|
||||
@ -1475,6 +1530,9 @@ let IsSimpleMove=1, hasSideEffects=0 in {
|
||||
def IMOV64rr : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$sss),
|
||||
"mov.u64 \t$dst, $sss;", []>;
|
||||
|
||||
def FMOV16rr : NVPTXInst<(outs Float16Regs:$dst), (ins Float16Regs:$src),
|
||||
// We have to use .b16 here as there's no mov.f16.
|
||||
"mov.b16 \t$dst, $src;", []>;
|
||||
def FMOV32rr : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
|
||||
"mov.f32 \t$dst, $src;", []>;
|
||||
def FMOV64rr : NVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$src),
|
||||
@ -1636,6 +1694,26 @@ def : Pat<(i32 (setne Int1Regs:$a, Int1Regs:$b)),
|
||||
|
||||
|
||||
multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
|
||||
// f16 -> pred
|
||||
def : Pat<(i1 (OpNode Float16Regs:$a, Float16Regs:$b)),
|
||||
(SETP_f16rr Float16Regs:$a, Float16Regs:$b, ModeFTZ)>,
|
||||
Requires<[useFP16Math,doF32FTZ]>;
|
||||
def : Pat<(i1 (OpNode Float16Regs:$a, Float16Regs:$b)),
|
||||
(SETP_f16rr Float16Regs:$a, Float16Regs:$b, Mode)>,
|
||||
Requires<[useFP16Math]>;
|
||||
def : Pat<(i1 (OpNode Float16Regs:$a, fpimm:$b)),
|
||||
(SETP_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), ModeFTZ)>,
|
||||
Requires<[useFP16Math,doF32FTZ]>;
|
||||
def : Pat<(i1 (OpNode Float16Regs:$a, fpimm:$b)),
|
||||
(SETP_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), Mode)>,
|
||||
Requires<[useFP16Math]>;
|
||||
def : Pat<(i1 (OpNode fpimm:$a, Float16Regs:$b)),
|
||||
(SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, ModeFTZ)>,
|
||||
Requires<[useFP16Math,doF32FTZ]>;
|
||||
def : Pat<(i1 (OpNode fpimm:$a, Float16Regs:$b)),
|
||||
(SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>,
|
||||
Requires<[useFP16Math]>;
|
||||
|
||||
// f32 -> pred
|
||||
def : Pat<(i1 (OpNode Float32Regs:$a, Float32Regs:$b)),
|
||||
(SETP_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>,
|
||||
@ -1661,6 +1739,26 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
|
||||
def : Pat<(i1 (OpNode fpimm:$a, Float64Regs:$b)),
|
||||
(SETP_f64ir fpimm:$a, Float64Regs:$b, Mode)>;
|
||||
|
||||
// f16 -> i32
|
||||
def : Pat<(i32 (OpNode Float16Regs:$a, Float16Regs:$b)),
|
||||
(SET_f16rr Float16Regs:$a, Float16Regs:$b, ModeFTZ)>,
|
||||
Requires<[useFP16Math, doF32FTZ]>;
|
||||
def : Pat<(i32 (OpNode Float16Regs:$a, Float16Regs:$b)),
|
||||
(SET_f16rr Float16Regs:$a, Float16Regs:$b, Mode)>,
|
||||
Requires<[useFP16Math]>;
|
||||
def : Pat<(i32 (OpNode Float16Regs:$a, fpimm:$b)),
|
||||
(SET_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), ModeFTZ)>,
|
||||
Requires<[useFP16Math, doF32FTZ]>;
|
||||
def : Pat<(i32 (OpNode Float16Regs:$a, fpimm:$b)),
|
||||
(SET_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), Mode)>,
|
||||
Requires<[useFP16Math]>;
|
||||
def : Pat<(i32 (OpNode fpimm:$a, Float16Regs:$b)),
|
||||
(SET_f16ir (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, ModeFTZ)>,
|
||||
Requires<[useFP16Math, doF32FTZ]>;
|
||||
def : Pat<(i32 (OpNode fpimm:$a, Float16Regs:$b)),
|
||||
(SET_f16ir (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>,
|
||||
Requires<[useFP16Math]>;
|
||||
|
||||
// f32 -> i32
|
||||
def : Pat<(i32 (OpNode Float32Regs:$a, Float32Regs:$b)),
|
||||
(SET_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>,
|
||||
@ -1944,6 +2042,7 @@ def LoadParamMemV2I8 : LoadParamV2MemInst<Int16Regs, ".b8">;
|
||||
def LoadParamMemV4I32 : LoadParamV4MemInst<Int32Regs, ".b32">;
|
||||
def LoadParamMemV4I16 : LoadParamV4MemInst<Int16Regs, ".b16">;
|
||||
def LoadParamMemV4I8 : LoadParamV4MemInst<Int16Regs, ".b8">;
|
||||
def LoadParamMemF16 : LoadParamMemInst<Float16Regs, ".b16">;
|
||||
def LoadParamMemF32 : LoadParamMemInst<Float32Regs, ".f32">;
|
||||
def LoadParamMemF64 : LoadParamMemInst<Float64Regs, ".f64">;
|
||||
def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".f32">;
|
||||
@ -1964,6 +2063,7 @@ def StoreParamV4I32 : StoreParamV4Inst<Int32Regs, ".b32">;
|
||||
def StoreParamV4I16 : StoreParamV4Inst<Int16Regs, ".b16">;
|
||||
def StoreParamV4I8 : StoreParamV4Inst<Int16Regs, ".b8">;
|
||||
|
||||
def StoreParamF16 : StoreParamInst<Float16Regs, ".b16">;
|
||||
def StoreParamF32 : StoreParamInst<Float32Regs, ".f32">;
|
||||
def StoreParamF64 : StoreParamInst<Float64Regs, ".f64">;
|
||||
def StoreParamV2F32 : StoreParamV2Inst<Float32Regs, ".f32">;
|
||||
@ -1984,6 +2084,7 @@ def StoreRetvalV4I8 : StoreRetvalV4Inst<Int16Regs, ".b8">;
|
||||
|
||||
def StoreRetvalF64 : StoreRetvalInst<Float64Regs, ".f64">;
|
||||
def StoreRetvalF32 : StoreRetvalInst<Float32Regs, ".f32">;
|
||||
def StoreRetvalF16 : StoreRetvalInst<Float16Regs, ".b16">;
|
||||
def StoreRetvalV2F64 : StoreRetvalV2Inst<Float64Regs, ".f64">;
|
||||
def StoreRetvalV2F32 : StoreRetvalV2Inst<Float32Regs, ".f32">;
|
||||
def StoreRetvalV4F32 : StoreRetvalV4Inst<Float32Regs, ".f32">;
|
||||
@ -2071,6 +2172,7 @@ def MoveParamI16 :
|
||||
[(set Int16Regs:$dst, (MoveParam Int16Regs:$src))]>;
|
||||
def MoveParamF64 : MoveParamInst<Float64Regs, ".f64">;
|
||||
def MoveParamF32 : MoveParamInst<Float32Regs, ".f32">;
|
||||
def MoveParamF16 : MoveParamInst<Float16Regs, ".f16">;
|
||||
|
||||
class PseudoUseParamInst<NVPTXRegClass regclass> :
|
||||
NVPTXInst<(outs), (ins regclass:$src),
|
||||
@ -2131,6 +2233,7 @@ let mayLoad=1, hasSideEffects=0 in {
|
||||
defm LD_i16 : LD<Int16Regs>;
|
||||
defm LD_i32 : LD<Int32Regs>;
|
||||
defm LD_i64 : LD<Int64Regs>;
|
||||
defm LD_f16 : LD<Float16Regs>;
|
||||
defm LD_f32 : LD<Float32Regs>;
|
||||
defm LD_f64 : LD<Float64Regs>;
|
||||
}
|
||||
@ -2179,6 +2282,7 @@ let mayStore=1, hasSideEffects=0 in {
|
||||
defm ST_i16 : ST<Int16Regs>;
|
||||
defm ST_i32 : ST<Int32Regs>;
|
||||
defm ST_i64 : ST<Int64Regs>;
|
||||
defm ST_f16 : ST<Float16Regs>;
|
||||
defm ST_f32 : ST<Float32Regs>;
|
||||
defm ST_f64 : ST<Float64Regs>;
|
||||
}
|
||||
@ -2371,6 +2475,8 @@ class F_BITCONVERT<string SzStr, NVPTXRegClass regclassIn,
|
||||
!strconcat("mov.b", !strconcat(SzStr, " \t $d, $a;")),
|
||||
[(set regclassOut:$d, (bitconvert regclassIn:$a))]>;
|
||||
|
||||
def BITCONVERT_16_I2F : F_BITCONVERT<"16", Int16Regs, Float16Regs>;
|
||||
def BITCONVERT_16_F2I : F_BITCONVERT<"16", Float16Regs, Int16Regs>;
|
||||
def BITCONVERT_32_I2F : F_BITCONVERT<"32", Int32Regs, Float32Regs>;
|
||||
def BITCONVERT_32_F2I : F_BITCONVERT<"32", Float32Regs, Int32Regs>;
|
||||
def BITCONVERT_64_I2F : F_BITCONVERT<"64", Int64Regs, Float64Regs>;
|
||||
@ -2380,6 +2486,26 @@ def BITCONVERT_64_F2I : F_BITCONVERT<"64", Float64Regs, Int64Regs>;
|
||||
// we cannot specify floating-point literals in isel patterns. Therefore, we
|
||||
// use an integer selp to select either 1 or 0 and then cvt to floating-point.
|
||||
|
||||
// sint -> f16
|
||||
def : Pat<(f16 (sint_to_fp Int1Regs:$a)),
|
||||
(CVT_f16_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
|
||||
def : Pat<(f16 (sint_to_fp Int16Regs:$a)),
|
||||
(CVT_f16_s16 Int16Regs:$a, CvtRN)>;
|
||||
def : Pat<(f16 (sint_to_fp Int32Regs:$a)),
|
||||
(CVT_f16_s32 Int32Regs:$a, CvtRN)>;
|
||||
def : Pat<(f16 (sint_to_fp Int64Regs:$a)),
|
||||
(CVT_f16_s64 Int64Regs:$a, CvtRN)>;
|
||||
|
||||
// uint -> f16
|
||||
def : Pat<(f16 (uint_to_fp Int1Regs:$a)),
|
||||
(CVT_f16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
|
||||
def : Pat<(f16 (uint_to_fp Int16Regs:$a)),
|
||||
(CVT_f16_u16 Int16Regs:$a, CvtRN)>;
|
||||
def : Pat<(f16 (uint_to_fp Int32Regs:$a)),
|
||||
(CVT_f16_u32 Int32Regs:$a, CvtRN)>;
|
||||
def : Pat<(f16 (uint_to_fp Int64Regs:$a)),
|
||||
(CVT_f16_u64 Int64Regs:$a, CvtRN)>;
|
||||
|
||||
// sint -> f32
|
||||
def : Pat<(f32 (sint_to_fp Int1Regs:$a)),
|
||||
(CVT_f32_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
|
||||
@ -2421,6 +2547,38 @@ def : Pat<(f64 (uint_to_fp Int64Regs:$a)),
|
||||
(CVT_f64_u64 Int64Regs:$a, CvtRN)>;
|
||||
|
||||
|
||||
// f16 -> sint
|
||||
def : Pat<(i1 (fp_to_sint Float16Regs:$a)),
|
||||
(SETP_b16ri (BITCONVERT_16_F2I Float16Regs:$a), 0, CmpEQ)>;
|
||||
def : Pat<(i16 (fp_to_sint Float16Regs:$a)),
|
||||
(CVT_s16_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(i16 (fp_to_sint Float16Regs:$a)),
|
||||
(CVT_s16_f16 Float16Regs:$a, CvtRZI)>;
|
||||
def : Pat<(i32 (fp_to_sint Float16Regs:$a)),
|
||||
(CVT_s32_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(i32 (fp_to_sint Float16Regs:$a)),
|
||||
(CVT_s32_f16 Float16Regs:$a, CvtRZI)>;
|
||||
def : Pat<(i64 (fp_to_sint Float16Regs:$a)),
|
||||
(CVT_s64_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(i64 (fp_to_sint Float16Regs:$a)),
|
||||
(CVT_s64_f16 Float16Regs:$a, CvtRZI)>;
|
||||
|
||||
// f16 -> uint
|
||||
def : Pat<(i1 (fp_to_uint Float16Regs:$a)),
|
||||
(SETP_b16ri (BITCONVERT_16_F2I Float16Regs:$a), 0, CmpEQ)>;
|
||||
def : Pat<(i16 (fp_to_uint Float16Regs:$a)),
|
||||
(CVT_u16_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(i16 (fp_to_uint Float16Regs:$a)),
|
||||
(CVT_u16_f16 Float16Regs:$a, CvtRZI)>;
|
||||
def : Pat<(i32 (fp_to_uint Float16Regs:$a)),
|
||||
(CVT_u32_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(i32 (fp_to_uint Float16Regs:$a)),
|
||||
(CVT_u32_f16 Float16Regs:$a, CvtRZI)>;
|
||||
def : Pat<(i64 (fp_to_uint Float16Regs:$a)),
|
||||
(CVT_u64_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(i64 (fp_to_uint Float16Regs:$a)),
|
||||
(CVT_u64_f16 Float16Regs:$a, CvtRZI)>;
|
||||
|
||||
// f32 -> sint
|
||||
def : Pat<(i1 (fp_to_sint Float32Regs:$a)),
|
||||
(SETP_b32ri (BITCONVERT_32_F2I Float32Regs:$a), 0, CmpEQ)>;
|
||||
@ -2650,12 +2808,36 @@ def : Pat<(ctpop Int64Regs:$a), (CVT_u64_u32 (POPCr64 Int64Regs:$a), CvtNONE)>;
|
||||
def : Pat<(ctpop Int16Regs:$a),
|
||||
(CVT_u16_u32 (POPCr32 (CVT_u32_u16 Int16Regs:$a, CvtNONE)), CvtNONE)>;
|
||||
|
||||
// fpround f32 -> f16
|
||||
def : Pat<(f16 (fpround Float32Regs:$a)),
|
||||
(CVT_f16_f32 Float32Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(f16 (fpround Float32Regs:$a)),
|
||||
(CVT_f16_f32 Float32Regs:$a, CvtRN)>;
|
||||
|
||||
// fpround f64 -> f16
|
||||
def : Pat<(f16 (fpround Float64Regs:$a)),
|
||||
(CVT_f16_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(f16 (fpround Float64Regs:$a)),
|
||||
(CVT_f16_f64 Float64Regs:$a, CvtRN)>;
|
||||
|
||||
// fpround f64 -> f32
|
||||
def : Pat<(f32 (fpround Float64Regs:$a)),
|
||||
(CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(f32 (fpround Float64Regs:$a)),
|
||||
(CVT_f32_f64 Float64Regs:$a, CvtRN)>;
|
||||
|
||||
// fpextend f16 -> f32
|
||||
def : Pat<(f32 (fpextend Float16Regs:$a)),
|
||||
(CVT_f32_f16 Float16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(f32 (fpextend Float16Regs:$a)),
|
||||
(CVT_f32_f16 Float16Regs:$a, CvtNONE)>;
|
||||
|
||||
// fpextend f16 -> f64
|
||||
def : Pat<(f64 (fpextend Float16Regs:$a)),
|
||||
(CVT_f64_f16 Float16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(f64 (fpextend Float16Regs:$a)),
|
||||
(CVT_f64_f16 Float16Regs:$a, CvtNONE)>;
|
||||
|
||||
// fpextend f32 -> f64
|
||||
def : Pat<(f64 (fpextend Float32Regs:$a)),
|
||||
(CVT_f64_f32 Float32Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
|
||||
@ -2667,6 +2849,10 @@ def retflag : SDNode<"NVPTXISD::RET_FLAG", SDTNone,
|
||||
|
||||
// fceil, ffloor, fround, ftrunc.
|
||||
|
||||
def : Pat<(fceil Float16Regs:$a),
|
||||
(CVT_f16_f16 Float16Regs:$a, CvtRPI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(fceil Float16Regs:$a),
|
||||
(CVT_f16_f16 Float16Regs:$a, CvtRPI)>, Requires<[doNoF32FTZ]>;
|
||||
def : Pat<(fceil Float32Regs:$a),
|
||||
(CVT_f32_f32 Float32Regs:$a, CvtRPI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(fceil Float32Regs:$a),
|
||||
@ -2674,6 +2860,10 @@ def : Pat<(fceil Float32Regs:$a),
|
||||
def : Pat<(fceil Float64Regs:$a),
|
||||
(CVT_f64_f64 Float64Regs:$a, CvtRPI)>;
|
||||
|
||||
def : Pat<(ffloor Float16Regs:$a),
|
||||
(CVT_f16_f16 Float16Regs:$a, CvtRMI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(ffloor Float16Regs:$a),
|
||||
(CVT_f16_f16 Float16Regs:$a, CvtRMI)>, Requires<[doNoF32FTZ]>;
|
||||
def : Pat<(ffloor Float32Regs:$a),
|
||||
(CVT_f32_f32 Float32Regs:$a, CvtRMI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(ffloor Float32Regs:$a),
|
||||
@ -2681,6 +2871,10 @@ def : Pat<(ffloor Float32Regs:$a),
|
||||
def : Pat<(ffloor Float64Regs:$a),
|
||||
(CVT_f64_f64 Float64Regs:$a, CvtRMI)>;
|
||||
|
||||
def : Pat<(fround Float16Regs:$a),
|
||||
(CVT_f16_f16 Float16Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(f16 (fround Float16Regs:$a)),
|
||||
(CVT_f16_f16 Float16Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
|
||||
def : Pat<(fround Float32Regs:$a),
|
||||
(CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(f32 (fround Float32Regs:$a)),
|
||||
@ -2688,6 +2882,10 @@ def : Pat<(f32 (fround Float32Regs:$a)),
|
||||
def : Pat<(f64 (fround Float64Regs:$a)),
|
||||
(CVT_f64_f64 Float64Regs:$a, CvtRNI)>;
|
||||
|
||||
def : Pat<(ftrunc Float16Regs:$a),
|
||||
(CVT_f16_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(ftrunc Float16Regs:$a),
|
||||
(CVT_f16_f16 Float16Regs:$a, CvtRZI)>, Requires<[doNoF32FTZ]>;
|
||||
def : Pat<(ftrunc Float32Regs:$a),
|
||||
(CVT_f32_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(ftrunc Float32Regs:$a),
|
||||
@ -2699,6 +2897,10 @@ def : Pat<(ftrunc Float64Regs:$a),
|
||||
// strictly correct, because it causes us to ignore the rounding mode. But it
|
||||
// matches what CUDA's "libm" does.
|
||||
|
||||
def : Pat<(fnearbyint Float16Regs:$a),
|
||||
(CVT_f16_f16 Float16Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(fnearbyint Float16Regs:$a),
|
||||
(CVT_f16_f16 Float16Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
|
||||
def : Pat<(fnearbyint Float32Regs:$a),
|
||||
(CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(fnearbyint Float32Regs:$a),
|
||||
@ -2706,6 +2908,10 @@ def : Pat<(fnearbyint Float32Regs:$a),
|
||||
def : Pat<(fnearbyint Float64Regs:$a),
|
||||
(CVT_f64_f64 Float64Regs:$a, CvtRNI)>;
|
||||
|
||||
def : Pat<(frint Float16Regs:$a),
|
||||
(CVT_f16_f16 Float16Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(frint Float16Regs:$a),
|
||||
(CVT_f16_f16 Float16Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
|
||||
def : Pat<(frint Float32Regs:$a),
|
||||
(CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(frint Float32Regs:$a),
|
||||
|
@ -803,49 +803,13 @@ def : Pat<(int_nvvm_ull2d_rp Int64Regs:$a),
|
||||
(CVT_f64_u64 Int64Regs:$a, CvtRP)>;
|
||||
|
||||
|
||||
// FIXME: Ideally, we could use these patterns instead of the scope-creating
|
||||
// patterns, but ptxas does not like these since .s16 is not compatible with
|
||||
// .f16. The solution is to use .bXX for all integer register types, but we
|
||||
// are not there yet.
|
||||
//def : Pat<(int_nvvm_f2h_rn_ftz Float32Regs:$a),
|
||||
// (CVT_f16_f32 Float32Regs:$a, CvtRN_FTZ)>;
|
||||
//def : Pat<(int_nvvm_f2h_rn Float32Regs:$a),
|
||||
// (CVT_f16_f32 Float32Regs:$a, CvtRN)>;
|
||||
//
|
||||
//def : Pat<(int_nvvm_h2f Int16Regs:$a),
|
||||
// (CVT_f32_f16 Int16Regs:$a, CvtNONE)>;
|
||||
def : Pat<(int_nvvm_f2h_rn_ftz Float32Regs:$a),
|
||||
(BITCONVERT_16_F2I (CVT_f16_f32 Float32Regs:$a, CvtRN_FTZ))>;
|
||||
def : Pat<(int_nvvm_f2h_rn Float32Regs:$a),
|
||||
(BITCONVERT_16_F2I (CVT_f16_f32 Float32Regs:$a, CvtRN))>;
|
||||
|
||||
def INT_NVVM_F2H_RN_FTZ : F_MATH_1<!strconcat("{{\n\t",
|
||||
!strconcat(".reg .b16 %temp;\n\t",
|
||||
!strconcat("cvt.rn.ftz.f16.f32 \t%temp, $src0;\n\t",
|
||||
!strconcat("mov.b16 \t$dst, %temp;\n",
|
||||
"}}")))),
|
||||
Int16Regs, Float32Regs, int_nvvm_f2h_rn_ftz>;
|
||||
def INT_NVVM_F2H_RN : F_MATH_1<!strconcat("{{\n\t",
|
||||
!strconcat(".reg .b16 %temp;\n\t",
|
||||
!strconcat("cvt.rn.f16.f32 \t%temp, $src0;\n\t",
|
||||
!strconcat("mov.b16 \t$dst, %temp;\n",
|
||||
"}}")))),
|
||||
Int16Regs, Float32Regs, int_nvvm_f2h_rn>;
|
||||
|
||||
def INT_NVVM_H2F : F_MATH_1<!strconcat("{{\n\t",
|
||||
!strconcat(".reg .b16 %temp;\n\t",
|
||||
!strconcat("mov.b16 \t%temp, $src0;\n\t",
|
||||
!strconcat("cvt.f32.f16 \t$dst, %temp;\n\t",
|
||||
"}}")))),
|
||||
Float32Regs, Int16Regs, int_nvvm_h2f>;
|
||||
|
||||
def : Pat<(f32 (f16_to_fp Int16Regs:$a)),
|
||||
(CVT_f32_f16 Int16Regs:$a, CvtNONE)>;
|
||||
def : Pat<(i16 (fp_to_f16 Float32Regs:$a)),
|
||||
(CVT_f16_f32 Float32Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(i16 (fp_to_f16 Float32Regs:$a)),
|
||||
(CVT_f16_f32 Float32Regs:$a, CvtRN)>;
|
||||
|
||||
def : Pat<(f64 (f16_to_fp Int16Regs:$a)),
|
||||
(CVT_f64_f16 Int16Regs:$a, CvtNONE)>;
|
||||
def : Pat<(i16 (fp_to_f16 Float64Regs:$a)),
|
||||
(CVT_f16_f64 Float64Regs:$a, CvtRN)>;
|
||||
def : Pat<(int_nvvm_h2f Int16Regs:$a),
|
||||
(CVT_f32_f16 (BITCONVERT_16_I2F Int16Regs:$a), CvtNONE)>;
|
||||
|
||||
//
|
||||
// Bitcast
|
||||
|
@ -27,6 +27,13 @@ void NVPTXFloatMCExpr::printImpl(raw_ostream &OS, const MCAsmInfo *MAI) const {
|
||||
|
||||
switch (Kind) {
|
||||
default: llvm_unreachable("Invalid kind!");
|
||||
case VK_NVPTX_HALF_PREC_FLOAT:
|
||||
// ptxas does not have a way to specify half-precision floats.
|
||||
// Instead we have to print and load fp16 constants as .b16
|
||||
OS << "0x";
|
||||
NumHex = 4;
|
||||
APF.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &Ignored);
|
||||
break;
|
||||
case VK_NVPTX_SINGLE_PREC_FLOAT:
|
||||
OS << "0f";
|
||||
NumHex = 8;
|
||||
|
@ -22,8 +22,9 @@ class NVPTXFloatMCExpr : public MCTargetExpr {
|
||||
public:
|
||||
enum VariantKind {
|
||||
VK_NVPTX_None,
|
||||
VK_NVPTX_SINGLE_PREC_FLOAT, // FP constant in single-precision
|
||||
VK_NVPTX_DOUBLE_PREC_FLOAT // FP constant in double-precision
|
||||
VK_NVPTX_HALF_PREC_FLOAT, // FP constant in half-precision
|
||||
VK_NVPTX_SINGLE_PREC_FLOAT, // FP constant in single-precision
|
||||
VK_NVPTX_DOUBLE_PREC_FLOAT // FP constant in double-precision
|
||||
};
|
||||
|
||||
private:
|
||||
@ -40,6 +41,11 @@ public:
|
||||
static const NVPTXFloatMCExpr *create(VariantKind Kind, const APFloat &Flt,
|
||||
MCContext &Ctx);
|
||||
|
||||
static const NVPTXFloatMCExpr *createConstantFPHalf(const APFloat &Flt,
|
||||
MCContext &Ctx) {
|
||||
return create(VK_NVPTX_HALF_PREC_FLOAT, Flt, Ctx);
|
||||
}
|
||||
|
||||
static const NVPTXFloatMCExpr *createConstantFPSingle(const APFloat &Flt,
|
||||
MCContext &Ctx) {
|
||||
return create(VK_NVPTX_SINGLE_PREC_FLOAT, Flt, Ctx);
|
||||
|
@ -27,12 +27,17 @@ using namespace llvm;
|
||||
|
||||
namespace llvm {
|
||||
std::string getNVPTXRegClassName(TargetRegisterClass const *RC) {
|
||||
if (RC == &NVPTX::Float32RegsRegClass) {
|
||||
if (RC == &NVPTX::Float32RegsRegClass)
|
||||
return ".f32";
|
||||
}
|
||||
if (RC == &NVPTX::Float64RegsRegClass) {
|
||||
if (RC == &NVPTX::Float16RegsRegClass)
|
||||
// Ideally fp16 registers should be .f16, but this syntax is only
|
||||
// supported on sm_53+. On the other hand, .b16 registers are
|
||||
// accepted for all supported fp16 instructions on all GPU
|
||||
// variants, so we can use them instead.
|
||||
return ".b16";
|
||||
if (RC == &NVPTX::Float64RegsRegClass)
|
||||
return ".f64";
|
||||
} else if (RC == &NVPTX::Int64RegsRegClass) {
|
||||
if (RC == &NVPTX::Int64RegsRegClass)
|
||||
// We use untyped (.b) integer registers here as NVCC does.
|
||||
// Correctness of generated code does not depend on register type,
|
||||
// but using .s/.u registers runs into ptxas bug that prevents
|
||||
@ -52,40 +57,35 @@ std::string getNVPTXRegClassName(TargetRegisterClass const *RC) {
|
||||
// add.f16v2 rb32,rb32,rb32; // OK
|
||||
// add.f16v2 rs32,rs32,rs32; // OK
|
||||
return ".b64";
|
||||
} else if (RC == &NVPTX::Int32RegsRegClass) {
|
||||
if (RC == &NVPTX::Int32RegsRegClass)
|
||||
return ".b32";
|
||||
} else if (RC == &NVPTX::Int16RegsRegClass) {
|
||||
if (RC == &NVPTX::Int16RegsRegClass)
|
||||
return ".b16";
|
||||
} else if (RC == &NVPTX::Int1RegsRegClass) {
|
||||
if (RC == &NVPTX::Int1RegsRegClass)
|
||||
return ".pred";
|
||||
} else if (RC == &NVPTX::SpecialRegsRegClass) {
|
||||
if (RC == &NVPTX::SpecialRegsRegClass)
|
||||
return "!Special!";
|
||||
} else {
|
||||
return "INTERNAL";
|
||||
}
|
||||
return "";
|
||||
return "INTERNAL";
|
||||
}
|
||||
|
||||
std::string getNVPTXRegClassStr(TargetRegisterClass const *RC) {
|
||||
if (RC == &NVPTX::Float32RegsRegClass) {
|
||||
if (RC == &NVPTX::Float32RegsRegClass)
|
||||
return "%f";
|
||||
}
|
||||
if (RC == &NVPTX::Float64RegsRegClass) {
|
||||
if (RC == &NVPTX::Float16RegsRegClass)
|
||||
return "%h";
|
||||
if (RC == &NVPTX::Float64RegsRegClass)
|
||||
return "%fd";
|
||||
} else if (RC == &NVPTX::Int64RegsRegClass) {
|
||||
if (RC == &NVPTX::Int64RegsRegClass)
|
||||
return "%rd";
|
||||
} else if (RC == &NVPTX::Int32RegsRegClass) {
|
||||
if (RC == &NVPTX::Int32RegsRegClass)
|
||||
return "%r";
|
||||
} else if (RC == &NVPTX::Int16RegsRegClass) {
|
||||
if (RC == &NVPTX::Int16RegsRegClass)
|
||||
return "%rs";
|
||||
} else if (RC == &NVPTX::Int1RegsRegClass) {
|
||||
if (RC == &NVPTX::Int1RegsRegClass)
|
||||
return "%p";
|
||||
} else if (RC == &NVPTX::SpecialRegsRegClass) {
|
||||
if (RC == &NVPTX::SpecialRegsRegClass)
|
||||
return "!Special!";
|
||||
} else {
|
||||
return "INTERNAL";
|
||||
}
|
||||
return "";
|
||||
return "INTERNAL";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -36,6 +36,7 @@ foreach i = 0-4 in {
|
||||
def RS#i : NVPTXReg<"%rs"#i>; // 16-bit
|
||||
def R#i : NVPTXReg<"%r"#i>; // 32-bit
|
||||
def RL#i : NVPTXReg<"%rd"#i>; // 64-bit
|
||||
def H#i : NVPTXReg<"%h"#i>; // 16-bit float
|
||||
def F#i : NVPTXReg<"%f"#i>; // 32-bit float
|
||||
def FL#i : NVPTXReg<"%fd"#i>; // 64-bit float
|
||||
|
||||
@ -57,6 +58,7 @@ def Int1Regs : NVPTXRegClass<[i1], 8, (add (sequence "P%u", 0, 4))>;
|
||||
def Int16Regs : NVPTXRegClass<[i16], 16, (add (sequence "RS%u", 0, 4))>;
|
||||
def Int32Regs : NVPTXRegClass<[i32], 32, (add (sequence "R%u", 0, 4))>;
|
||||
def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4))>;
|
||||
def Float16Regs : NVPTXRegClass<[f16], 16, (add (sequence "H%u", 0, 4))>;
|
||||
def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;
|
||||
def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 4))>;
|
||||
def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 4))>;
|
||||
|
@ -23,6 +23,11 @@ using namespace llvm;
|
||||
#define GET_SUBTARGETINFO_CTOR
|
||||
#include "NVPTXGenSubtargetInfo.inc"
|
||||
|
||||
static cl::opt<bool>
|
||||
NoF16Math("nvptx-no-f16-math", cl::ZeroOrMore, cl::Hidden,
|
||||
cl::desc("NVPTX Specific: Disable generation of f16 math ops."),
|
||||
cl::init(false));
|
||||
|
||||
// Pin the vtable to this file.
|
||||
void NVPTXSubtarget::anchor() {}
|
||||
|
||||
@ -57,3 +62,7 @@ bool NVPTXSubtarget::hasImageHandles() const {
|
||||
// Disabled, otherwise
|
||||
return false;
|
||||
}
|
||||
|
||||
bool NVPTXSubtarget::allowFP16Math() const {
|
||||
return hasFP16Math() && NoF16Math == false;
|
||||
}
|
||||
|
@ -101,6 +101,8 @@ public:
|
||||
inline bool hasROT32() const { return hasHWROT32() || hasSWROT32(); }
|
||||
inline bool hasROT64() const { return SmVersion >= 20; }
|
||||
bool hasImageHandles() const;
|
||||
bool hasFP16Math() const { return SmVersion >= 53; }
|
||||
bool allowFP16Math() const;
|
||||
|
||||
unsigned int getSmVersion() const { return SmVersion; }
|
||||
std::string getTargetName() const { return TargetName; }
|
||||
|
1034
test/CodeGen/NVPTX/f16-instructions.ll
Normal file
1034
test/CodeGen/NVPTX/f16-instructions.ll
Normal file
File diff suppressed because it is too large
Load Diff
@ -2,8 +2,8 @@
|
||||
|
||||
define void @test_load_store(half addrspace(1)* %in, half addrspace(1)* %out) {
|
||||
; CHECK-LABEL: @test_load_store
|
||||
; CHECK: ld.global.u16 [[TMP:%rs[0-9]+]], [{{%r[0-9]+}}]
|
||||
; CHECK: st.global.u16 [{{%r[0-9]+}}], [[TMP]]
|
||||
; CHECK: ld.global.b16 [[TMP:%h[0-9]+]], [{{%r[0-9]+}}]
|
||||
; CHECK: st.global.b16 [{{%r[0-9]+}}], [[TMP]]
|
||||
%val = load half, half addrspace(1)* %in
|
||||
store half %val, half addrspace(1) * %out
|
||||
ret void
|
||||
@ -11,8 +11,8 @@ define void @test_load_store(half addrspace(1)* %in, half addrspace(1)* %out) {
|
||||
|
||||
define void @test_bitcast_from_half(half addrspace(1)* %in, i16 addrspace(1)* %out) {
|
||||
; CHECK-LABEL: @test_bitcast_from_half
|
||||
; CHECK: ld.global.u16 [[TMP:%rs[0-9]+]], [{{%r[0-9]+}}]
|
||||
; CHECK: st.global.u16 [{{%r[0-9]+}}], [[TMP]]
|
||||
; CHECK: ld.global.b16 [[TMP:%h[0-9]+]], [{{%r[0-9]+}}]
|
||||
; CHECK: st.global.b16 [{{%r[0-9]+}}], [[TMP]]
|
||||
%val = load half, half addrspace(1) * %in
|
||||
%val_int = bitcast half %val to i16
|
||||
store i16 %val_int, i16 addrspace(1)* %out
|
||||
|
Loading…
x
Reference in New Issue
Block a user