1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-10-19 11:02:59 +02:00

[NVPTX] Implemented wmma intrinsics and instructions.

WMMA = "Warp Level Matrix Multiply-Accumulate".
These are the new instructions introduced in PTX6.0 and available
on sm_70 GPUs.

Differential Revision: https://reviews.llvm.org/D38645

llvm-svn: 315601
This commit is contained in:
Artem Belevich 2017-10-12 18:27:55 +00:00
parent 9f8f153184
commit 848056d8ad
6 changed files with 1192 additions and 0 deletions

View File

@ -3869,4 +3869,150 @@ def int_nvvm_match_all_sync_i64p :
Intrinsic<[llvm_i64_ty, llvm_i1_ty], [llvm_i32_ty, llvm_i64_ty],
[IntrNoMem, IntrConvergent], "llvm.nvvm.match.all.sync.i64p">;
//
// WMMA instructions
//
// WMMA.LOAD
class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Space,
string Type, LLVMType regty, int WithStride>
: Intrinsic<!if(!eq(Abc#Type,"cf16"),
[regty, regty, regty, regty],
[regty, regty, regty, regty,
regty, regty, regty, regty]),
!if(WithStride, [llvm_ptr_ty, llvm_i32_ty], [llvm_ptr_ty]),
[], // Properties must be set during instantiation.
"llvm.nvvm.wmma.load."#Abc#".sync."#Layout#".m16n16k16"
#Space
#!if(WithStride,".stride","")
#"."#Type>;
multiclass NVVM_WMMA_LD_ALST<string Abc, string Layout, string Space,
string Type, LLVMType regty> {
def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 1>;
def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 0>;
}
multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout,
string Type, LLVMType regty> {
defm _global: NVVM_WMMA_LD_ALST<Abc, Layout, ".global", Type, regty>;
defm _shared: NVVM_WMMA_LD_ALST<Abc, Layout, ".shared", Type, regty>;
defm NAME: NVVM_WMMA_LD_ALST<Abc, Layout, "", Type, regty>;
}
multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
defm _row: NVVM_WMMA_LD_ALT<Abc, "row", Type, regty>;
defm _col: NVVM_WMMA_LD_ALT<Abc, "col", Type, regty>;
}
// For some reason ReadOnly<N> and NoCapture<N> confuses tblgen if they are
// passed to Intrinsic<> form inside of a multiclass. Setting them globally
// outside of the multiclass works.
let IntrProperties = [IntrReadMem, IntrArgMemOnly,
ReadOnly<0>, NoCapture<0>] in {
defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
}
// WMMA.STORE.D
class NVVM_WMMA_STD_LSTS<string Layout, string Space,
string Type, LLVMType regty, int WithStride,
// This is only used to create a typed empty array we
// need to pass to !if below.
list<LLVMType>Empty=[]>
: Intrinsic<[],
!listconcat(
[llvm_ptr_ty],
!if(!eq(Type,"f16"),
[regty, regty, regty, regty],
[regty, regty, regty, regty,
regty, regty, regty, regty]),
!if(WithStride, [llvm_i32_ty], Empty)),
[], // Properties must be set during instantiation.
"llvm.nvvm.wmma.store.d.sync."#Layout
#".m16n16k16"#Space
#!if(WithStride,".stride","")
#"."#Type>;
multiclass NVVM_WMMA_STD_LST<string Layout, string Space,
string Type, LLVMType regty> {
def _stride: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 1>;
def NAME: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 0>;
}
multiclass NVVM_WMMA_STD_LT<string Layout, string Type, LLVMType regty> {
defm _global: NVVM_WMMA_STD_LST<Layout, ".global", Type, regty>;
defm _shared: NVVM_WMMA_STD_LST<Layout, ".shared", Type, regty>;
defm NAME: NVVM_WMMA_STD_LST<Layout, "", Type, regty>;
}
multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
defm _row: NVVM_WMMA_STD_LT<"row", Type, regty>;
defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>;
}
let IntrProperties = [IntrWriteMem, IntrArgMemOnly,
WriteOnly<0>, NoCapture<0>] in {
defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
}
// WMMA.MMA
class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,
string DType, LLVMType d_regty,
string CType, LLVMType c_regty,
string Satfinite = "">
: Intrinsic<!if(!eq(DType,"f16"),
[d_regty, d_regty, d_regty, d_regty],
[d_regty, d_regty, d_regty, d_regty,
d_regty, d_regty, d_regty, d_regty]),
!listconcat(
[// A
llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
// B
llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty],
!if(!eq(CType,"f16"),
[c_regty, c_regty, c_regty, c_regty],
[c_regty, c_regty, c_regty, c_regty,
c_regty, c_regty, c_regty, c_regty])),
[IntrNoMem],
"llvm.nvvm.wmma.mma.sync."#ALayout#"."#BLayout
#".m16n16k16."#DType#"."#CType#Satfinite>;
multiclass NVVM_WMMA_MMA_ABDC<string ALayout, string BLayout,
string DType, LLVMType d_regty,
string CType, LLVMType c_regty> {
def NAME : NVVM_WMMA_MMA_ABDCS<ALayout, BLayout,
DType, d_regty,
CType, c_regty>;
def _satfinite: NVVM_WMMA_MMA_ABDCS<ALayout, BLayout,
DType, d_regty,
CType, c_regty,".satfinite">;
}
multiclass NVVM_WMMA_MMA_ABD<string ALayout, string BLayout,
string DType, LLVMType d_regty> {
defm _f16: NVVM_WMMA_MMA_ABDC<ALayout, BLayout, DType, d_regty,
"f16", llvm_v2f16_ty>;
defm _f32: NVVM_WMMA_MMA_ABDC<ALayout, BLayout, DType, d_regty,
"f32", llvm_float_ty>;
}
multiclass NVVM_WMMA_MMA_AB<string ALayout, string BLayout> {
defm _f16: NVVM_WMMA_MMA_ABD<ALayout, BLayout, "f16", llvm_v2f16_ty>;
defm _f32: NVVM_WMMA_MMA_ABD<ALayout, BLayout, "f32", llvm_float_ty>;
}
multiclass NVVM_WMMA_MMA_A<string ALayout> {
defm _col: NVVM_WMMA_MMA_AB<ALayout, "col">;
defm _row: NVVM_WMMA_MMA_AB<ALayout, "row">;
}
defm int_nvvm_wmma_mma_sync_col: NVVM_WMMA_MMA_A<"col">;
defm int_nvvm_wmma_mma_sync_row: NVVM_WMMA_MMA_A<"row">;
} // let TargetPrefix = "nvvm"

View File

@ -496,8 +496,318 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
SelectCode(N);
}
// Each instruction has four addressing variants. WMMA_VARIANTS() macro below
// constructs an array indexed by WmmaVariant which getWmmaLdVariant() uses to
// look up the intrinsic ID of particular variant.
enum WmmaVariant {
WMMA_VARIANT_ARI64,
WMMA_VARIANT_ARI64_STRIDE,
WMMA_VARIANT_AVAR,
WMMA_VARIANT_AVAR_STRIDE,
};
// clang-format off
#define WMMA_VARIANTS(base) \
{{ base##_ari64, base##_ari64_stride, base##_avar, base##_avar_stride }}
// clang-format on
static unsigned getWmmaLdVariant(WmmaVariant Variant, bool Stride,
const std::array<unsigned, 4> Variants) {
if (Stride) {
if (Variant == WMMA_VARIANT_ARI64)
Variant = WMMA_VARIANT_ARI64_STRIDE;
else if (Variant == WMMA_VARIANT_AVAR)
Variant = WMMA_VARIANT_AVAR_STRIDE;
}
return Variants[Variant];
}
static Optional<unsigned>
getWmmaLdStOpcode(unsigned IntrinsicID,
WmmaVariant Variant = WMMA_VARIANT_ARI64) {
switch (IntrinsicID) {
default:
return None;
//
// WMMA_LOAD_A f16
//
case Intrinsic::nvvm_wmma_load_a_f16_col:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col));
case Intrinsic::nvvm_wmma_load_a_f16_row:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row));
case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col));
case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row));
case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared));
case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared));
case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared));
case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared));
case Intrinsic::nvvm_wmma_load_a_f16_col_global:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global));
case Intrinsic::nvvm_wmma_load_a_f16_row_global:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global));
case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global));
case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global));
//
// WMMA_LOAD_B f16
//
case Intrinsic::nvvm_wmma_load_b_f16_col:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col));
case Intrinsic::nvvm_wmma_load_b_f16_row:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row));
case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col));
case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row));
case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared));
case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared));
case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared));
case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared));
case Intrinsic::nvvm_wmma_load_b_f16_col_global:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global));
case Intrinsic::nvvm_wmma_load_b_f16_row_global:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global));
case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global));
case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global));
//
// WMMA_LOAD_C f16
//
case Intrinsic::nvvm_wmma_load_c_f16_col:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col));
case Intrinsic::nvvm_wmma_load_c_f16_row:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row));
case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col));
case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row));
case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
return getWmmaLdVariant(
Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared));
case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
return getWmmaLdVariant(
Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared));
case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
return getWmmaLdVariant(
Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared));
case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
return getWmmaLdVariant(
Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared));
case Intrinsic::nvvm_wmma_load_c_f16_col_global:
return getWmmaLdVariant(
Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global));
case Intrinsic::nvvm_wmma_load_c_f16_row_global:
return getWmmaLdVariant(
Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global));
case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
return getWmmaLdVariant(
Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global));
case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride:
return getWmmaLdVariant(
Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global));
//
// WMMA_LOAD_C f32
//
case Intrinsic::nvvm_wmma_load_c_f32_col:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col));
case Intrinsic::nvvm_wmma_load_c_f32_row:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row));
case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col));
case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row));
case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
return getWmmaLdVariant(
Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared));
case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
return getWmmaLdVariant(
Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared));
case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
return getWmmaLdVariant(
Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared));
case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
return getWmmaLdVariant(
Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared));
case Intrinsic::nvvm_wmma_load_c_f32_col_global:
return getWmmaLdVariant(
Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global));
case Intrinsic::nvvm_wmma_load_c_f32_row_global:
return getWmmaLdVariant(
Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global));
case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
return getWmmaLdVariant(
Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global));
case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride:
return getWmmaLdVariant(
Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global));
//
// WMMA_STORE_D f16
//
case Intrinsic::nvvm_wmma_store_d_f16_col:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col));
case Intrinsic::nvvm_wmma_store_d_f16_row:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row));
case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col));
case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row));
case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
return getWmmaLdVariant(
Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared));
case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
return getWmmaLdVariant(
Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared));
case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
return getWmmaLdVariant(
Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared));
case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
return getWmmaLdVariant(
Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared));
case Intrinsic::nvvm_wmma_store_d_f16_col_global:
return getWmmaLdVariant(
Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global));
case Intrinsic::nvvm_wmma_store_d_f16_row_global:
return getWmmaLdVariant(
Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global));
case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
return getWmmaLdVariant(
Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global));
case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride:
return getWmmaLdVariant(
Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global));
//
// WMMA_STORE_D f32
//
case Intrinsic::nvvm_wmma_store_d_f32_col:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col));
case Intrinsic::nvvm_wmma_store_d_f32_row:
return getWmmaLdVariant(Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row));
case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col));
case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
return getWmmaLdVariant(Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row));
case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
return getWmmaLdVariant(
Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared));
case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
return getWmmaLdVariant(
Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared));
case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
return getWmmaLdVariant(
Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared));
case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
return getWmmaLdVariant(
Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared));
case Intrinsic::nvvm_wmma_store_d_f32_col_global:
return getWmmaLdVariant(
Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global));
case Intrinsic::nvvm_wmma_store_d_f32_row_global:
return getWmmaLdVariant(
Variant, /*Stride=*/false,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global));
case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
return getWmmaLdVariant(
Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global));
case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride:
return getWmmaLdVariant(
Variant, /*Stride=*/true,
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global));
}
}
#undef WMMA_VARIANTS
bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
if (getWmmaLdStOpcode(IID))
return tryWMMA_LDST(N);
switch (IID) {
default:
return false;
@ -719,6 +1029,39 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) {
case Intrinsic::nvvm_match_all_sync_i64p:
SelectMatchAll(N);
return true;
case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16:
case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite:
case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32:
case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite:
case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16:
case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite:
case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32:
case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite:
case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16:
case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite:
case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32:
case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite:
case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16:
case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite:
case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32:
case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite:
case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16:
case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite:
case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32:
case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite:
case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16:
case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite:
case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32:
case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite:
case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16:
case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite:
case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32:
case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite:
case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16:
case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite:
case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32:
case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite:
return tryWMMA_MMA(N);
}
}
@ -3725,3 +4068,172 @@ unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
}
}
}
bool NVPTXDAGToDAGISel::tryWMMA_LDST(SDNode *N) {
SDValue Chain = N->getOperand(0);
unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
SDValue Op1 = N->getOperand(2);
SDValue Addr, Offset, Base;
Optional<unsigned> Opcode;
SDLoc DL(N);
MemSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
WmmaVariant Variant;
SmallVector<SDValue, 12> Ops;
bool isStore = N->getNumValues() == 1; // Store ops only return a chain.
if (SelectDirectAddr(Op1, Addr)) {
Variant = WMMA_VARIANT_AVAR;
Ops.push_back(Addr);
} else if (SelectADDRsi64(Op1.getNode(), Op1, Base, Offset) ||
SelectADDRri64(Op1.getNode(), Op1, Base, Offset)) {
Variant = WMMA_VARIANT_ARI64;
Ops.push_back(Base);
Ops.push_back(Offset);
} else {
Variant = WMMA_VARIANT_AVAR;
Ops.push_back(Op1);
}
unsigned NumOps = N->getNumOperands();
// Pass through the rest of the operands to the machine node.
for (unsigned i = 3; i < NumOps; ++i)
Ops.push_back(N->getOperand(i));
Ops.push_back(Chain);
Opcode = getWmmaLdStOpcode(IID, Variant);
if (!Opcode) {
llvm::errs() << "tryWMMALD - no Opcode.\n";
return false;
}
EVT MemVT = MemSD->getMemoryVT();
assert(MemVT.isVector() && "Expected vector return type.");
SDNode *MN;
if (isStore) {
MN = CurDAG->getMachineNode(Opcode.getValue(), DL, MVT::Other, Ops);
} else {
SmallVector<EVT, 9> InstVTs(MemVT.getVectorNumElements(),
MemSD->getValueType(0));
InstVTs.push_back(MVT::Other);
MN = CurDAG->getMachineNode(Opcode.getValue(), DL, InstVTs, Ops);
}
ReplaceNode(N, MN);
return true;
}
bool NVPTXDAGToDAGISel::tryWMMA_MMA(SDNode *N) {
unsigned IID = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
SDLoc DL(N);
unsigned Opc;
switch (IID) {
default:
return false;
case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16:
Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16;
break;
case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite:
Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16_satfinite;
break;
case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32:
Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32;
break;
case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite:
Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32_satfinite;
break;
case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16:
Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16;
break;
case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite:
Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16_satfinite;
break;
case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32:
Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32;
break;
case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite:
Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32_satfinite;
break;
case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16:
Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16;
break;
case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite:
Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16_satfinite;
break;
case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32:
Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32;
break;
case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite:
Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32_satfinite;
break;
case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16:
Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16;
break;
case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite:
Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16_satfinite;
break;
case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32:
Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32;
break;
case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite:
Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32_satfinite;
break;
case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16:
Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16;
break;
case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite:
Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16_satfinite;
break;
case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32:
Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32;
break;
case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite:
Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32_satfinite;
break;
case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16:
Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16;
break;
case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite:
Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16_satfinite;
break;
case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32:
Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32;
break;
case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite:
Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32_satfinite;
break;
case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16:
Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16;
break;
case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite:
Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16_satfinite;
break;
case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32:
Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32;
break;
case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite:
Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32_satfinite;
break;
case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16:
Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16;
break;
case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite:
Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16_satfinite;
break;
case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32:
Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32;
break;
case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite:
Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32_satfinite;
break;
}
SmallVector<SDValue, 24> Ops;
// Pass through operands and return value types to the machine node.
for (unsigned i = 1; i < N->getNumOperands(); ++i)
Ops.push_back(N->getOperand(i));
SmallVector<EVT, 8> InstVTs(N->getNumValues(), N->getValueType(0));
SDNode *MN = CurDAG->getMachineNode(Opc, DL, InstVTs, Ops);
ReplaceNode(N, MN);
return true;
}

View File

@ -74,6 +74,8 @@ private:
bool tryConstantFP16(SDNode *N);
bool SelectSETP_F16X2(SDNode *N);
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
bool tryWMMA_LDST(SDNode *N);
bool tryWMMA_MMA(SDNode *N);
inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
return CurDAG->getTargetConstant(Imm, DL, MVT::i32);

View File

@ -3321,6 +3321,132 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
switch (Intrinsic) {
default:
return false;
case Intrinsic::nvvm_wmma_load_a_f16_col:
case Intrinsic::nvvm_wmma_load_a_f16_row:
case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
case Intrinsic::nvvm_wmma_load_a_f16_col_global:
case Intrinsic::nvvm_wmma_load_a_f16_row_global:
case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
case Intrinsic::nvvm_wmma_load_b_f16_col:
case Intrinsic::nvvm_wmma_load_b_f16_row:
case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
case Intrinsic::nvvm_wmma_load_b_f16_col_global:
case Intrinsic::nvvm_wmma_load_b_f16_row_global:
case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v8f16;
Info.ptrVal = I.getArgOperand(0);
Info.offset = 0;
Info.vol = false;
Info.readMem = true;
Info.writeMem = false;
Info.align = 16;
return true;
}
case Intrinsic::nvvm_wmma_load_c_f16_col:
case Intrinsic::nvvm_wmma_load_c_f16_row:
case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
case Intrinsic::nvvm_wmma_load_c_f16_col_global:
case Intrinsic::nvvm_wmma_load_c_f16_row_global:
case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v4f16;
Info.ptrVal = I.getArgOperand(0);
Info.offset = 0;
Info.vol = false;
Info.readMem = true;
Info.writeMem = false;
Info.align = 16;
return true;
}
case Intrinsic::nvvm_wmma_load_c_f32_col:
case Intrinsic::nvvm_wmma_load_c_f32_row:
case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
case Intrinsic::nvvm_wmma_load_c_f32_col_global:
case Intrinsic::nvvm_wmma_load_c_f32_row_global:
case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);
Info.offset = 0;
Info.vol = false;
Info.readMem = true;
Info.writeMem = false;
Info.align = 16;
return true;
}
case Intrinsic::nvvm_wmma_store_d_f16_col:
case Intrinsic::nvvm_wmma_store_d_f16_row:
case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
case Intrinsic::nvvm_wmma_store_d_f16_col_global:
case Intrinsic::nvvm_wmma_store_d_f16_row_global:
case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v4f16;
Info.ptrVal = I.getArgOperand(0);
Info.offset = 0;
Info.vol = false;
Info.readMem = false;
Info.writeMem = true;
Info.align = 16;
return true;
}
case Intrinsic::nvvm_wmma_store_d_f32_col:
case Intrinsic::nvvm_wmma_store_d_f32_row:
case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
case Intrinsic::nvvm_wmma_store_d_f32_col_global:
case Intrinsic::nvvm_wmma_store_d_f32_row_global:
case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);
Info.offset = 0;
Info.vol = false;
Info.readMem = false;
Info.writeMem = true;
Info.align = 16;
return true;
}
case Intrinsic::nvvm_atomic_load_add_f32:
case Intrinsic::nvvm_atomic_load_inc_32:

View File

@ -7368,3 +7368,208 @@ def INT_PTX_SREG_PM3 : PTX_READ_SREG_R32<"pm3", int_nvvm_read_ptx_sreg_pm3>;
def INT_PTX_SREG_WARPSIZE :
NVPTXInst<(outs Int32Regs:$dst), (ins), "mov.u32 \t$dst, WARP_SZ;",
[(set Int32Regs:$dst, (int_nvvm_read_ptx_sreg_warpsize))]>;
//
// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
//
class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
string Type, NVPTXRegClass regclass,
Operand SrcOp, int WithOffset, int WithStride>
: NVPTXInst<!if(!eq(Abc#Type,"cf16"),
(outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3),
(outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7)),
!if(WithStride,
!if(WithOffset,
(ins SrcOp:$src, i32imm:$offset, Int32Regs:$ldm),
(ins SrcOp:$src, Int32Regs:$ldm)),
!if(WithOffset,
(ins SrcOp:$src, i32imm:$offset),
(ins SrcOp:$src))),
"wmma.load."#Abc#".sync."#Layout#".m16n16k16"#Space#"." #Type# " \t"
#!if(!eq(Abc#Type,"cf16"),
"{{$r0, $r1, $r2, $r3}}",
"{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
#", "
#!if(WithOffset,"[$src+$offset]", "[$src]")
#!if(WithStride, ", $ldm", "")
#";",
[]>,
Requires<[hasPTX60, hasSM70]>;
multiclass WMMA_LOAD_ALSTO<string Abc, string Layout, string Space,
string Type, NVPTXRegClass regclass,
Operand SrcOp, int WithOffset = 0> {
def _stride: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp,
WithOffset, 1>;
def NAME: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp,
WithOffset, 0>;
}
multiclass WMMA_LOAD_ALST<string Abc, string Layout, string Space,
string Type, NVPTXRegClass regclass> {
defm _avar: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imemAny, 0>;
defm _ari64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imemAny, 1>;
}
multiclass WMMA_LOAD_ALT<string Abc, string Layout,
string Type, NVPTXRegClass regclass> {
defm _global: WMMA_LOAD_ALST<Abc, Layout, ".global", Type, regclass>;
defm _shared: WMMA_LOAD_ALST<Abc, Layout, ".shared", Type, regclass>;
defm NAME: WMMA_LOAD_ALST<Abc, Layout, "", Type, regclass>;
}
multiclass WMMA_LOAD_AT<string Abc, string Type, NVPTXRegClass regclass> {
defm _row: WMMA_LOAD_ALT<Abc, "row", Type, regclass>;
defm _col: WMMA_LOAD_ALT<Abc, "col", Type, regclass>;
}
defm INT_WMMA_LOAD_A: WMMA_LOAD_AT<"a", "f16", Float16x2Regs>;
defm INT_WMMA_LOAD_B: WMMA_LOAD_AT<"b", "f16", Float16x2Regs>;
defm INT_WMMA_LOAD_C_f16: WMMA_LOAD_AT<"c", "f16", Float16x2Regs>;
defm INT_WMMA_LOAD_C_f32: WMMA_LOAD_AT<"c", "f32", Float32Regs>;
//
// wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
//
class WMMA_STORE_D_LSTOS<string Layout, string Space,
string Type, NVPTXRegClass regclass,
Operand DstOp, int WithOffset, int WithStride>
: NVPTXInst<(outs),
!if(!eq(Type,"f16"),
!if(WithStride,
!if(WithOffset,
(ins DstOp:$src, i32imm:$offset,
regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
Int32Regs:$ldm),
(ins DstOp:$src,
regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
Int32Regs:$ldm)),
!if(WithOffset,
(ins DstOp:$src, i32imm:$offset,
regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3),
(ins DstOp:$src,
regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3))),
!if(WithStride,
!if(WithOffset,
(ins DstOp:$src, i32imm:$offset,
regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7,
Int32Regs:$ldm),
(ins DstOp:$src,
regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7,
Int32Regs:$ldm)),
!if(WithOffset,
(ins DstOp:$src, i32imm:$offset,
regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7),
(ins DstOp:$src,
regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7)))),
"wmma.store.d.sync."#Layout#".m16n16k16"#Space#"." #Type# " \t"
#!if(WithOffset,"[$src+$offset], ", "[$src], ")
#!if(!eq(Type,"f16"),
"{{$r0, $r1, $r2, $r3}}",
"{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
#!if(WithStride, ", $ldm", "")
#";",
[]>,
Requires<[hasPTX60, hasSM70]>;
multiclass WMMA_STORE_D_LSTO<string Layout, string Space,
string Type, NVPTXRegClass regclass,
Operand DstOp, int WithOffset = 0> {
def _stride: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp,
WithOffset, 1>;
def NAME: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp,
WithOffset, 0>;
}
multiclass WMMA_STORE_D_LST<string Layout, string Space,
string Type, NVPTXRegClass regclass> {
defm _avar: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imemAny, 0>;
defm _ari64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imemAny, 1>;
}
multiclass WMMA_STORE_D_LT<string Layout,
string Type, NVPTXRegClass regclass> {
defm _global: WMMA_STORE_D_LST<Layout, ".global", Type, regclass>;
defm _shared: WMMA_STORE_D_LST<Layout, ".shared", Type, regclass>;
defm NAME: WMMA_STORE_D_LST<Layout, "", Type, regclass>;
}
multiclass WMMA_STORE_D_T<string Type, NVPTXRegClass regclass> {
defm _row: WMMA_STORE_D_LT<"row", Type, regclass>;
defm _col: WMMA_STORE_D_LT<"col", Type, regclass>;
}
defm INT_WMMA_STORE_D_f16: WMMA_STORE_D_T<"f16", Float16x2Regs>;
defm INT_WMMA_STORE_D_f32: WMMA_STORE_D_T<"f32", Float32Regs>;
// WMMA.MMA
class WMMA_MMA_ABDCS<string ALayout, string BLayout,
string DType, NVPTXRegClass d_reg,
string CType, NVPTXRegClass c_reg,
NVPTXRegClass ab_reg,
string Satfinite = "">
: NVPTXInst<!if(!eq(DType,"f16"),
(outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3),
(outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3,
d_reg:$d4, d_reg:$d5, d_reg:$d6, d_reg:$d7)),
!if(!eq(CType,"f16"),
(ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3),
(ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3,
c_reg:$c4, c_reg:$c5, c_reg:$c6, c_reg:$c7)),
"wmma.mma.sync."#ALayout#"."#BLayout#".m16n16k16."#
#DType#"."#CType#Satfinite
#"\n\t\t"
#!if(!eq(DType,"f16"),
"{{$d0, $d1, $d2, $d3}}, \n\t\t",
"{{$d0, $d1, $d2, $d3, $d4, $d5, $d6, $d7}},\n\t\t")
#"{{$a0, $a1, $a2, $a3, $a4, $a5, $a6, $a7}},\n\t\t"
#"{{$b0, $b1, $b2, $b3, $b4, $b5, $b6, $b7}},\n\t\t"
#!if(!eq(CType,"f16"),
"{{$c0, $c1, $c2, $c3}};",
"{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};"),
[]>,
Requires<[hasPTX60, hasSM70]>;
multiclass WMMA_MMA_ABDC<string ALayout, string BLayout,
string DType, NVPTXRegClass d_reg,
string CType, NVPTXRegClass c_reg> {
def _satfinite: WMMA_MMA_ABDCS<ALayout, BLayout,
DType, d_reg, CType, c_reg,
Float16x2Regs, ".satfinite">;
def NAME: WMMA_MMA_ABDCS<ALayout, BLayout,
DType, d_reg, CType, c_reg,
Float16x2Regs>;
}
multiclass WMMA_MMA_ABD<string ALayout, string BLayout,
string DType, NVPTXRegClass d_reg> {
defm _f16: WMMA_MMA_ABDC<ALayout, BLayout, DType, d_reg, "f16", Float16x2Regs>;
defm _f32: WMMA_MMA_ABDC<ALayout, BLayout, DType, d_reg, "f32", Float32Regs>;
}
multiclass WMMA_MMA_AB<string ALayout, string BLayout> {
defm _f16: WMMA_MMA_ABD<ALayout, BLayout, "f16", Float16x2Regs>;
defm _f32: WMMA_MMA_ABD<ALayout, BLayout, "f32", Float32Regs>;
}
multiclass WMMA_MMA_A<string ALayout> {
defm _col: WMMA_MMA_AB<ALayout, "col">;
defm _row: WMMA_MMA_AB<ALayout, "row">;
}
defm INT_WMMA_MMA_col: WMMA_MMA_A<"col">;
defm INT_WMMA_MMA_row: WMMA_MMA_A<"row">;

201
test/CodeGen/NVPTX/wmma.py Normal file
View File

@ -0,0 +1,201 @@
# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them.
# RUN: python %s > %t.ll
# RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 | FileCheck %t.ll
from itertools import product
from string import Template
def make_wmma_slice_ty(abcd, itype):
elt_ty = "<2 x half>" if itype == "f16" else "float"
num_elts = 4 if abcd in "cd" and itype == "f16" else 8;
return [elt_ty] * num_elts
def make_wmma_ld_ret_ty(abc, itype):
return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype))
# Convenient test patterns.
check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8)
check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)
def gen_wmma_load_tests():
load_template = """
declare ${ret_ty} @llvm.nvvm.wmma.load.$intrinsic_suffix(i8* %src ${extra_args});
; CHECK-LABEL: .func {{.*}}test_wmma_load_${function_suffix}(
define ${ret_ty} @test_wmma_load_${function_suffix}(i8* %src ${extra_args}) {
; CHECK wmma.load.${intrinsic_suffix}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
%v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src ${extra_args});
ret ${ret_ty} %v0;
}
; CHECK-LABEL: .func{{.*}}test_wmma_load_${function_suffix}_o(
define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8* %src ${extra_args}) {
; CHECK wmma.load.${intrinsic_suffix}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
%src1 = getelementptr i8, i8* %src, i32 128;
%v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src1 ${extra_args});
ret ${ret_ty} %v0;
}
"""
suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
for abc, layout, space, stride, itype in product(
"abc",
["row","col"],
["",".shared",".global"],
["", ".stride"],
["f16", "f32"]):
params = {
"abc" : abc,
"layout" : layout,
"space" : space,
"stride" : stride,
"itype" : itype
}
if itype == "f32" and abc != "c":
continue
test_params = params
test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".","_")
test_params["instruction_suffix"] = Template(instruction_template).substitute(params)
test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
if abc == "c" :
test_params["check_result"] = check_f16_4 if itype == "f16" else check_f32_8
else:
test_params["check_result"] = check_f16_8
if stride:
test_params["extra_args"] = ", i32 %stride";
test_params["stride_pattern"] = ", %r{{[0-9]+}}"
else:
test_params["extra_args"] = ""
test_params["stride_pattern"] = ""
print(Template(load_template).substitute(test_params))
def make_wmma_slice_args(itype, abcd, prefix="v"):
return ", ".join(["%s %%%s%d" % (t, prefix, i) for i,t
in enumerate(make_wmma_slice_ty(abcd, itype))])
def gen_wmma_store_tests():
store_template = """
declare void @llvm.nvvm.wmma.store.$intrinsic_suffix(i8* %src, ${args}${extra_args});
; CHECK-LABEL: .func {{.*}}test_wmma_store_${function_suffix}(
define void @test_wmma_store_${function_suffix}(i8* %src, ${args}${extra_args}) {
; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src, ${args} ${extra_args});
ret void
}
; CHECK-LABEL: .func{{.*}}test_wmma_store_${function_suffix}_o(
define void @test_wmma_store_${function_suffix}_o(i8* %src, ${args}${extra_args}) {
; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
%src1 = getelementptr i8, i8* %src, i32 128;
call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src1, ${args}${extra_args});
ret void
}
"""
suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
for abc, layout, space, stride, itype in product(
"d",
["row","col"],
["",".shared",".global"],
["", ".stride"],
["f16", "f32"]):
params = {
"abc" : abc,
"layout" : layout,
"space" : space,
"stride" : stride,
"itype" : itype
}
test_params = params
test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".","_")
test_params["instruction_suffix"] = Template(instruction_template).substitute(params)
test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8
if stride:
test_params["extra_args"] = ", i32 %stride";
test_params["stride_pattern"] = ", %r{{[0-9]+}};"
else:
test_params["extra_args"] = ""
test_params["stride_pattern"] = ";"
test_params["args"] = make_wmma_slice_args(itype, "d");
print(Template(store_template).substitute(test_params))
def gen_wmma_mma_tests():
mma_template = """
declare ${ret_ty} @llvm.nvvm.wmma.mma.sync.$intrinsic_suffix(
${args});
; CHECK-LABEL: .func {{.*}}test_wmma_mma_${function_suffix}(
define ${ret_ty} @test_wmma_mma_${function_suffix}(
${args}) {
; CHECK wmma.mma.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}
; CHECK ${check_d}
; CHECK ${check_ab}
; CHECK ${check_ab}
; CHECK ${check_c}
%r = call ${ret_ty} @llvm.nvvm.wmma.mma.sync.${intrinsic_suffix}(
${args});
ret ${ret_ty} %r;
}
"""
suffix_template = "${alayout}.${blayout}.m16n16k16.${dtype}.${ctype}${satf}"
for alayout, blayout, ctype, dtype, satf in product(
["row","col"],
["row","col"],
["f16", "f32"],
["f16", "f32"],
[".satfinite", ""]):
params = {
"alayout" : alayout,
"blayout" : blayout,
"ctype" : ctype,
"dtype" : dtype,
"satf" : satf
}
test_params = params
test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".", "_")
test_params["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)
test_params["check_ab"] = check_f16_8
test_params["check_c"] = check_f16_4 if ctype == "f16" else check_f32_8
test_params["check_d"] = check_f16_4 if dtype == "f16" else check_f32_8
args = ",\n ".join(make_wmma_slice_args(t, abcd, prefix=abcd)
for abcd, t in (("a", "f16"),
("b", "f16"),
("c", ctype)))
test_params["args"] = args
print(Template(mma_template).substitute(test_params))
def main():
gen_wmma_load_tests()
gen_wmma_store_tests()
gen_wmma_mma_tests()
main()