[NVPTX] Implemented wmma intrinsics and instructions.
WMMA = "Warp Level Matrix Multiply-Accumulate". These are the new
instructions introduced in PTX 6.0 and available on sm_70 GPUs.

Differential Revision: https://reviews.llvm.org/D38645

llvm-svn: 315601
commit 848056d8ad (parent 9f8f153184)
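For orientation, here is a minimal, illustrative IR sketch of how the new intrinsics are used from LLVM IR. The function name is made up, the row-major f16 accumulator variant is an arbitrary choice, and the intrinsic names and signatures simply follow the naming pattern defined by the TableGen changes below (and exercised by the generated tests at the end of this diff): it loads a 16x16 accumulator tile with an explicit leading dimension and stores it back unchanged.

; Sketch only: a hypothetical kernel-side helper, not part of this commit.
declare {<2 x half>, <2 x half>, <2 x half>, <2 x half>}
    @llvm.nvvm.wmma.load.c.sync.row.m16n16k16.stride.f16(i8*, i32)
declare void
    @llvm.nvvm.wmma.store.d.sync.row.m16n16k16.stride.f16(i8*, <2 x half>, <2 x half>, <2 x half>, <2 x half>, i32)

define void @copy_c_fragment(i8* %src, i8* %dst, i32 %ldm) {
  ; Load a row-major f16 accumulator fragment (4 x <2 x half> per thread).
  %c = call {<2 x half>, <2 x half>, <2 x half>, <2 x half>}
      @llvm.nvvm.wmma.load.c.sync.row.m16n16k16.stride.f16(i8* %src, i32 %ldm)
  %c0 = extractvalue {<2 x half>, <2 x half>, <2 x half>, <2 x half>} %c, 0
  %c1 = extractvalue {<2 x half>, <2 x half>, <2 x half>, <2 x half>} %c, 1
  %c2 = extractvalue {<2 x half>, <2 x half>, <2 x half>, <2 x half>} %c, 2
  %c3 = extractvalue {<2 x half>, <2 x half>, <2 x half>, <2 x half>} %c, 3
  ; Store the fragment as the D operand with the same layout and stride.
  call void @llvm.nvvm.wmma.store.d.sync.row.m16n16k16.stride.f16(
      i8* %dst, <2 x half> %c0, <2 x half> %c1, <2 x half> %c2, <2 x half> %c3, i32 %ldm)
  ret void
}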
@@ -3869,4 +3869,150 @@ def int_nvvm_match_all_sync_i64p :
  Intrinsic<[llvm_i64_ty, llvm_i1_ty], [llvm_i32_ty, llvm_i64_ty],
            [IntrNoMem, IntrConvergent], "llvm.nvvm.match.all.sync.i64p">;

//
// WMMA instructions
//

// WMMA.LOAD
class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Space,
                         string Type, LLVMType regty, int WithStride>
  : Intrinsic<!if(!eq(Abc#Type,"cf16"),
                  [regty, regty, regty, regty],
                  [regty, regty, regty, regty,
                   regty, regty, regty, regty]),
              !if(WithStride, [llvm_ptr_ty, llvm_i32_ty], [llvm_ptr_ty]),
              [], // Properties must be set during instantiation.
              "llvm.nvvm.wmma.load."#Abc#".sync."#Layout#".m16n16k16"
                #Space
                #!if(WithStride,".stride","")
                #"."#Type>;

multiclass NVVM_WMMA_LD_ALST<string Abc, string Layout, string Space,
                             string Type, LLVMType regty> {
  def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 1>;
  def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 0>;
}

multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout,
                            string Type, LLVMType regty> {
  defm _global: NVVM_WMMA_LD_ALST<Abc, Layout, ".global", Type, regty>;
  defm _shared: NVVM_WMMA_LD_ALST<Abc, Layout, ".shared", Type, regty>;
  defm NAME: NVVM_WMMA_LD_ALST<Abc, Layout, "", Type, regty>;
}

multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
  defm _row: NVVM_WMMA_LD_ALT<Abc, "row", Type, regty>;
  defm _col: NVVM_WMMA_LD_ALT<Abc, "col", Type, regty>;
}

// For some reason ReadOnly<N> and NoCapture<N> confuses tblgen if they are
// passed to Intrinsic<> form inside of a multiclass. Setting them globally
// outside of the multiclass works.
let IntrProperties = [IntrReadMem, IntrArgMemOnly,
                      ReadOnly<0>, NoCapture<0>] in {
  defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
  defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
  defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
  defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
}

// WMMA.STORE.D
class NVVM_WMMA_STD_LSTS<string Layout, string Space,
                         string Type, LLVMType regty, int WithStride,
                         // This is only used to create a typed empty array we
                         // need to pass to !if below.
                         list<LLVMType>Empty=[]>
  : Intrinsic<[],
              !listconcat(
                [llvm_ptr_ty],
                !if(!eq(Type,"f16"),
                    [regty, regty, regty, regty],
                    [regty, regty, regty, regty,
                     regty, regty, regty, regty]),
                !if(WithStride, [llvm_i32_ty], Empty)),
              [], // Properties must be set during instantiation.
              "llvm.nvvm.wmma.store.d.sync."#Layout
                #".m16n16k16"#Space
                #!if(WithStride,".stride","")
                #"."#Type>;

multiclass NVVM_WMMA_STD_LST<string Layout, string Space,
                             string Type, LLVMType regty> {
  def _stride: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 1>;
  def NAME: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 0>;
}

multiclass NVVM_WMMA_STD_LT<string Layout, string Type, LLVMType regty> {
  defm _global: NVVM_WMMA_STD_LST<Layout, ".global", Type, regty>;
  defm _shared: NVVM_WMMA_STD_LST<Layout, ".shared", Type, regty>;
  defm NAME: NVVM_WMMA_STD_LST<Layout, "", Type, regty>;
}

multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
  defm _row: NVVM_WMMA_STD_LT<"row", Type, regty>;
  defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>;
}

let IntrProperties = [IntrWriteMem, IntrArgMemOnly,
                      WriteOnly<0>, NoCapture<0>] in {
  defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
  defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
}

// WMMA.MMA
class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,
                          string DType, LLVMType d_regty,
                          string CType, LLVMType c_regty,
                          string Satfinite = "">
  : Intrinsic<!if(!eq(DType,"f16"),
                  [d_regty, d_regty, d_regty, d_regty],
                  [d_regty, d_regty, d_regty, d_regty,
                   d_regty, d_regty, d_regty, d_regty]),
              !listconcat(
                [// A
                 llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
                 llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
                 // B
                 llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
                 llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty],
                !if(!eq(CType,"f16"),
                    [c_regty, c_regty, c_regty, c_regty],
                    [c_regty, c_regty, c_regty, c_regty,
                     c_regty, c_regty, c_regty, c_regty])),
              [IntrNoMem],
              "llvm.nvvm.wmma.mma.sync."#ALayout#"."#BLayout
                #".m16n16k16."#DType#"."#CType#Satfinite>;

multiclass NVVM_WMMA_MMA_ABDC<string ALayout, string BLayout,
                              string DType, LLVMType d_regty,
                              string CType, LLVMType c_regty> {
  def NAME : NVVM_WMMA_MMA_ABDCS<ALayout, BLayout,
                                 DType, d_regty,
                                 CType, c_regty>;
  def _satfinite: NVVM_WMMA_MMA_ABDCS<ALayout, BLayout,
                                      DType, d_regty,
                                      CType, c_regty,".satfinite">;
}

multiclass NVVM_WMMA_MMA_ABD<string ALayout, string BLayout,
                             string DType, LLVMType d_regty> {
  defm _f16: NVVM_WMMA_MMA_ABDC<ALayout, BLayout, DType, d_regty,
                                "f16", llvm_v2f16_ty>;
  defm _f32: NVVM_WMMA_MMA_ABDC<ALayout, BLayout, DType, d_regty,
                                "f32", llvm_float_ty>;
}

multiclass NVVM_WMMA_MMA_AB<string ALayout, string BLayout> {
  defm _f16: NVVM_WMMA_MMA_ABD<ALayout, BLayout, "f16", llvm_v2f16_ty>;
  defm _f32: NVVM_WMMA_MMA_ABD<ALayout, BLayout, "f32", llvm_float_ty>;
}

multiclass NVVM_WMMA_MMA_A<string ALayout> {
  defm _col: NVVM_WMMA_MMA_AB<ALayout, "col">;
  defm _row: NVVM_WMMA_MMA_AB<ALayout, "row">;
}

defm int_nvvm_wmma_mma_sync_col: NVVM_WMMA_MMA_A<"col">;
defm int_nvvm_wmma_mma_sync_row: NVVM_WMMA_MMA_A<"row">;

} // let TargetPrefix = "nvvm"
@@ -496,8 +496,318 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
    SelectCode(N);
}

// Each instruction has four addressing variants. WMMA_VARIANTS() macro below
// constructs an array indexed by WmmaVariant which getWmmaLdVariant() uses to
// look up the intrinsic ID of particular variant.
enum WmmaVariant {
  WMMA_VARIANT_ARI64,
  WMMA_VARIANT_ARI64_STRIDE,
  WMMA_VARIANT_AVAR,
  WMMA_VARIANT_AVAR_STRIDE,
};

// clang-format off
#define WMMA_VARIANTS(base) \
  {{ base##_ari64, base##_ari64_stride, base##_avar, base##_avar_stride }}
// clang-format on

static unsigned getWmmaLdVariant(WmmaVariant Variant, bool Stride,
                                 const std::array<unsigned, 4> Variants) {
  if (Stride) {
    if (Variant == WMMA_VARIANT_ARI64)
      Variant = WMMA_VARIANT_ARI64_STRIDE;
    else if (Variant == WMMA_VARIANT_AVAR)
      Variant = WMMA_VARIANT_AVAR_STRIDE;
  }
  return Variants[Variant];
}

static Optional<unsigned>
getWmmaLdStOpcode(unsigned IntrinsicID,
                  WmmaVariant Variant = WMMA_VARIANT_ARI64) {
  switch (IntrinsicID) {
  default:
    return None;
  //
  // WMMA_LOAD_A f16
  //
  case Intrinsic::nvvm_wmma_load_a_f16_col:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col));
  case Intrinsic::nvvm_wmma_load_a_f16_row:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row));
  case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col));
  case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row));
  case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared));
  case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared));
  case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared));
  case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared));
  case Intrinsic::nvvm_wmma_load_a_f16_col_global:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global));
  case Intrinsic::nvvm_wmma_load_a_f16_row_global:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global));
  case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global));
  case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global));

  //
  // WMMA_LOAD_B f16
  //
  case Intrinsic::nvvm_wmma_load_b_f16_col:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col));
  case Intrinsic::nvvm_wmma_load_b_f16_row:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row));
  case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col));
  case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row));
  case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared));
  case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared));
  case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared));
  case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared));
  case Intrinsic::nvvm_wmma_load_b_f16_col_global:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global));
  case Intrinsic::nvvm_wmma_load_b_f16_row_global:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global));
  case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global));
  case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global));

  //
  // WMMA_LOAD_C f16
  //
  case Intrinsic::nvvm_wmma_load_c_f16_col:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col));
  case Intrinsic::nvvm_wmma_load_c_f16_row:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row));
  case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col));
  case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row));
  case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared));
  case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared));
  case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared));
  case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared));
  case Intrinsic::nvvm_wmma_load_c_f16_col_global:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global));
  case Intrinsic::nvvm_wmma_load_c_f16_row_global:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global));
  case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global));
  case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global));

  //
  // WMMA_LOAD_C f32
  //
  case Intrinsic::nvvm_wmma_load_c_f32_col:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col));
  case Intrinsic::nvvm_wmma_load_c_f32_row:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row));
  case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col));
  case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row));
  case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared));
  case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared));
  case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared));
  case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared));
  case Intrinsic::nvvm_wmma_load_c_f32_col_global:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global));
  case Intrinsic::nvvm_wmma_load_c_f32_row_global:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global));
  case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global));
  case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global));

  //
  // WMMA_STORE_D f16
  //
  case Intrinsic::nvvm_wmma_store_d_f16_col:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col));
  case Intrinsic::nvvm_wmma_store_d_f16_row:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row));
  case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col));
  case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row));
  case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared));
  case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared));
  case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared));
  case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared));
  case Intrinsic::nvvm_wmma_store_d_f16_col_global:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global));
  case Intrinsic::nvvm_wmma_store_d_f16_row_global:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global));
  case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global));
  case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global));

  //
  // WMMA_STORE_D f32
  //
  case Intrinsic::nvvm_wmma_store_d_f32_col:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col));
  case Intrinsic::nvvm_wmma_store_d_f32_row:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row));
  case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col));
  case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row));
  case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared));
  case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared));
  case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared));
  case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared));
  case Intrinsic::nvvm_wmma_store_d_f32_col_global:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global));
  case Intrinsic::nvvm_wmma_store_d_f32_row_global:
    return getWmmaLdVariant(Variant, /*Stride=*/false,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global));
  case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global));
  case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride:
    return getWmmaLdVariant(Variant, /*Stride=*/true,
                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global));
  }
}
#undef WMMA_VARIANTS

bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
  unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
  if (getWmmaLdStOpcode(IID))
    return tryWMMA_LDST(N);

  switch (IID) {
  default:
    return false;
@@ -719,6 +1029,39 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) {
  case Intrinsic::nvvm_match_all_sync_i64p:
    SelectMatchAll(N);
    return true;
  case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16:
  case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite:
  case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32:
  case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite:
  case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16:
  case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite:
  case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32:
  case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite:
  case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16:
  case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite:
  case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32:
  case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite:
  case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16:
  case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite:
  case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32:
  case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite:
  case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16:
  case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite:
  case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32:
  case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite:
  case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16:
  case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite:
  case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32:
  case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite:
  case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16:
  case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite:
  case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32:
  case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite:
  case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16:
  case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite:
  case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32:
  case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite:
    return tryWMMA_MMA(N);
  }
}

@@ -3725,3 +4068,172 @@ unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
  }
  }
}

bool NVPTXDAGToDAGISel::tryWMMA_LDST(SDNode *N) {
  SDValue Chain = N->getOperand(0);
  unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
  SDValue Op1 = N->getOperand(2);
  SDValue Addr, Offset, Base;
  Optional<unsigned> Opcode;
  SDLoc DL(N);
  MemSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
  WmmaVariant Variant;
  SmallVector<SDValue, 12> Ops;
  bool isStore = N->getNumValues() == 1; // Store ops only return a chain.

  if (SelectDirectAddr(Op1, Addr)) {
    Variant = WMMA_VARIANT_AVAR;
    Ops.push_back(Addr);
  } else if (SelectADDRsi64(Op1.getNode(), Op1, Base, Offset) ||
             SelectADDRri64(Op1.getNode(), Op1, Base, Offset)) {
    Variant = WMMA_VARIANT_ARI64;
    Ops.push_back(Base);
    Ops.push_back(Offset);
  } else {
    Variant = WMMA_VARIANT_AVAR;
    Ops.push_back(Op1);
  }
  unsigned NumOps = N->getNumOperands();
  // Pass through the rest of the operands to the machine node.
  for (unsigned i = 3; i < NumOps; ++i)
    Ops.push_back(N->getOperand(i));
  Ops.push_back(Chain);

  Opcode = getWmmaLdStOpcode(IID, Variant);
  if (!Opcode) {
    llvm::errs() << "tryWMMALD - no Opcode.\n";
    return false;
  }

  EVT MemVT = MemSD->getMemoryVT();
  assert(MemVT.isVector() && "Expected vector return type.");

  SDNode *MN;
  if (isStore) {
    MN = CurDAG->getMachineNode(Opcode.getValue(), DL, MVT::Other, Ops);
  } else {
    SmallVector<EVT, 9> InstVTs(MemVT.getVectorNumElements(),
                                MemSD->getValueType(0));
    InstVTs.push_back(MVT::Other);
    MN = CurDAG->getMachineNode(Opcode.getValue(), DL, InstVTs, Ops);
  }

  ReplaceNode(N, MN);
  return true;
}

bool NVPTXDAGToDAGISel::tryWMMA_MMA(SDNode *N) {
  unsigned IID = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
  SDLoc DL(N);
  unsigned Opc;

  switch (IID) {
  default:
    return false;
  case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16:
    Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite:
    Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16_satfinite;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32:
    Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite:
    Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32_satfinite;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16:
    Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite:
    Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16_satfinite;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32:
    Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite:
    Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32_satfinite;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16:
    Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite:
    Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16_satfinite;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32:
    Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite:
    Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32_satfinite;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16:
    Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite:
    Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16_satfinite;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32:
    Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite:
    Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32_satfinite;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16:
    Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite:
    Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16_satfinite;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32:
    Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite:
    Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32_satfinite;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16:
    Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite:
    Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16_satfinite;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32:
    Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite:
    Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32_satfinite;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16:
    Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite:
    Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16_satfinite;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32:
    Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite:
    Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32_satfinite;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16:
    Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite:
    Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16_satfinite;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32:
    Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32;
    break;
  case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite:
    Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32_satfinite;
    break;
  }

  SmallVector<SDValue, 24> Ops;
  // Pass through operands and return value types to the machine node.
  for (unsigned i = 1; i < N->getNumOperands(); ++i)
    Ops.push_back(N->getOperand(i));
  SmallVector<EVT, 8> InstVTs(N->getNumValues(), N->getValueType(0));
  SDNode *MN = CurDAG->getMachineNode(Opc, DL, InstVTs, Ops);
  ReplaceNode(N, MN);
  return true;
}

@@ -74,6 +74,8 @@ private:
  bool tryConstantFP16(SDNode *N);
  bool SelectSETP_F16X2(SDNode *N);
  bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
  bool tryWMMA_LDST(SDNode *N);
  bool tryWMMA_MMA(SDNode *N);

  inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
    return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
@@ -3321,6 +3321,132 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
  switch (Intrinsic) {
  default:
    return false;
  case Intrinsic::nvvm_wmma_load_a_f16_col:
  case Intrinsic::nvvm_wmma_load_a_f16_row:
  case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
  case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
  case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
  case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
  case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
  case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
  case Intrinsic::nvvm_wmma_load_a_f16_col_global:
  case Intrinsic::nvvm_wmma_load_a_f16_row_global:
  case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
  case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
  case Intrinsic::nvvm_wmma_load_b_f16_col:
  case Intrinsic::nvvm_wmma_load_b_f16_row:
  case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
  case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
  case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
  case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
  case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
  case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
  case Intrinsic::nvvm_wmma_load_b_f16_col_global:
  case Intrinsic::nvvm_wmma_load_b_f16_row_global:
  case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
  case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v8f16;
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = false;
    Info.readMem = true;
    Info.writeMem = false;
    Info.align = 16;
    return true;
  }

  case Intrinsic::nvvm_wmma_load_c_f16_col:
  case Intrinsic::nvvm_wmma_load_c_f16_row:
  case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
  case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
  case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
  case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
  case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
  case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
  case Intrinsic::nvvm_wmma_load_c_f16_col_global:
  case Intrinsic::nvvm_wmma_load_c_f16_row_global:
  case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
  case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v4f16;
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = false;
    Info.readMem = true;
    Info.writeMem = false;
    Info.align = 16;
    return true;
  }

  case Intrinsic::nvvm_wmma_load_c_f32_col:
  case Intrinsic::nvvm_wmma_load_c_f32_row:
  case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
  case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
  case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
  case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
  case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
  case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
  case Intrinsic::nvvm_wmma_load_c_f32_col_global:
  case Intrinsic::nvvm_wmma_load_c_f32_row_global:
  case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
  case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v8f32;
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = false;
    Info.readMem = true;
    Info.writeMem = false;
    Info.align = 16;
    return true;
  }

  case Intrinsic::nvvm_wmma_store_d_f16_col:
  case Intrinsic::nvvm_wmma_store_d_f16_row:
  case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
  case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
  case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
  case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
  case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
  case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
  case Intrinsic::nvvm_wmma_store_d_f16_col_global:
  case Intrinsic::nvvm_wmma_store_d_f16_row_global:
  case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
  case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v4f16;
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = false;
    Info.readMem = false;
    Info.writeMem = true;
    Info.align = 16;
    return true;
  }

  case Intrinsic::nvvm_wmma_store_d_f32_col:
  case Intrinsic::nvvm_wmma_store_d_f32_row:
  case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
  case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
  case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
  case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
  case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
  case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
  case Intrinsic::nvvm_wmma_store_d_f32_col_global:
  case Intrinsic::nvvm_wmma_store_d_f32_row_global:
  case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
  case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v8f32;
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = false;
    Info.readMem = false;
    Info.writeMem = true;
    Info.align = 16;
    return true;
  }

  case Intrinsic::nvvm_atomic_load_add_f32:
  case Intrinsic::nvvm_atomic_load_inc_32:
@@ -7368,3 +7368,208 @@ def INT_PTX_SREG_PM3 : PTX_READ_SREG_R32<"pm3", int_nvvm_read_ptx_sreg_pm3>;
def INT_PTX_SREG_WARPSIZE :
    NVPTXInst<(outs Int32Regs:$dst), (ins), "mov.u32 \t$dst, WARP_SZ;",
              [(set Int32Regs:$dst, (int_nvvm_read_ptx_sreg_warpsize))]>;

//
// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
//
class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
                       string Type, NVPTXRegClass regclass,
                       Operand SrcOp, int WithOffset, int WithStride>
  : NVPTXInst<!if(!eq(Abc#Type,"cf16"),
                  (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3),
                  (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
                        regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7)),
              !if(WithStride,
                  !if(WithOffset,
                      (ins SrcOp:$src, i32imm:$offset, Int32Regs:$ldm),
                      (ins SrcOp:$src, Int32Regs:$ldm)),
                  !if(WithOffset,
                      (ins SrcOp:$src, i32imm:$offset),
                      (ins SrcOp:$src))),
              "wmma.load."#Abc#".sync."#Layout#".m16n16k16"#Space#"." #Type# " \t"
                #!if(!eq(Abc#Type,"cf16"),
                     "{{$r0, $r1, $r2, $r3}}",
                     "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
                #", "
                #!if(WithOffset,"[$src+$offset]", "[$src]")
                #!if(WithStride, ", $ldm", "")
                #";",
              []>,
    Requires<[hasPTX60, hasSM70]>;

multiclass WMMA_LOAD_ALSTO<string Abc, string Layout, string Space,
                           string Type, NVPTXRegClass regclass,
                           Operand SrcOp, int WithOffset = 0> {
  def _stride: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp,
                                WithOffset, 1>;
  def NAME: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp,
                             WithOffset, 0>;
}

multiclass WMMA_LOAD_ALST<string Abc, string Layout, string Space,
                          string Type, NVPTXRegClass regclass> {
  defm _avar: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imemAny, 0>;
  defm _ari64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imemAny, 1>;
}

multiclass WMMA_LOAD_ALT<string Abc, string Layout,
                         string Type, NVPTXRegClass regclass> {
  defm _global: WMMA_LOAD_ALST<Abc, Layout, ".global", Type, regclass>;
  defm _shared: WMMA_LOAD_ALST<Abc, Layout, ".shared", Type, regclass>;
  defm NAME: WMMA_LOAD_ALST<Abc, Layout, "", Type, regclass>;
}

multiclass WMMA_LOAD_AT<string Abc, string Type, NVPTXRegClass regclass> {
  defm _row: WMMA_LOAD_ALT<Abc, "row", Type, regclass>;
  defm _col: WMMA_LOAD_ALT<Abc, "col", Type, regclass>;
}

defm INT_WMMA_LOAD_A: WMMA_LOAD_AT<"a", "f16", Float16x2Regs>;
defm INT_WMMA_LOAD_B: WMMA_LOAD_AT<"b", "f16", Float16x2Regs>;
defm INT_WMMA_LOAD_C_f16: WMMA_LOAD_AT<"c", "f16", Float16x2Regs>;
defm INT_WMMA_LOAD_C_f32: WMMA_LOAD_AT<"c", "f32", Float32Regs>;

//
// wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
//
class WMMA_STORE_D_LSTOS<string Layout, string Space,
                         string Type, NVPTXRegClass regclass,
                         Operand DstOp, int WithOffset, int WithStride>
  : NVPTXInst<(outs),
              !if(!eq(Type,"f16"),
                  !if(WithStride,
                      !if(WithOffset,
                          (ins DstOp:$src, i32imm:$offset,
                               regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
                               Int32Regs:$ldm),
                          (ins DstOp:$src,
                               regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
                               Int32Regs:$ldm)),
                      !if(WithOffset,
                          (ins DstOp:$src, i32imm:$offset,
                               regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3),
                          (ins DstOp:$src,
                               regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3))),
                  !if(WithStride,
                      !if(WithOffset,
                          (ins DstOp:$src, i32imm:$offset,
                               regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
                               regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7,
                               Int32Regs:$ldm),
                          (ins DstOp:$src,
                               regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
                               regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7,
                               Int32Regs:$ldm)),
                      !if(WithOffset,
                          (ins DstOp:$src, i32imm:$offset,
                               regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
                               regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7),
                          (ins DstOp:$src,
                               regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
                               regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7)))),
              "wmma.store.d.sync."#Layout#".m16n16k16"#Space#"." #Type# " \t"
                #!if(WithOffset,"[$src+$offset], ", "[$src], ")
                #!if(!eq(Type,"f16"),
                     "{{$r0, $r1, $r2, $r3}}",
                     "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
                #!if(WithStride, ", $ldm", "")
                #";",
              []>,
    Requires<[hasPTX60, hasSM70]>;

multiclass WMMA_STORE_D_LSTO<string Layout, string Space,
                             string Type, NVPTXRegClass regclass,
                             Operand DstOp, int WithOffset = 0> {
  def _stride: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp,
                                  WithOffset, 1>;
  def NAME: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp,
                               WithOffset, 0>;
}

multiclass WMMA_STORE_D_LST<string Layout, string Space,
                            string Type, NVPTXRegClass regclass> {
  defm _avar: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imemAny, 0>;
  defm _ari64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imemAny, 1>;
}

multiclass WMMA_STORE_D_LT<string Layout,
                           string Type, NVPTXRegClass regclass> {
  defm _global: WMMA_STORE_D_LST<Layout, ".global", Type, regclass>;
  defm _shared: WMMA_STORE_D_LST<Layout, ".shared", Type, regclass>;
  defm NAME: WMMA_STORE_D_LST<Layout, "", Type, regclass>;
}

multiclass WMMA_STORE_D_T<string Type, NVPTXRegClass regclass> {
  defm _row: WMMA_STORE_D_LT<"row", Type, regclass>;
  defm _col: WMMA_STORE_D_LT<"col", Type, regclass>;
}

defm INT_WMMA_STORE_D_f16: WMMA_STORE_D_T<"f16", Float16x2Regs>;
defm INT_WMMA_STORE_D_f32: WMMA_STORE_D_T<"f32", Float32Regs>;

// WMMA.MMA
class WMMA_MMA_ABDCS<string ALayout, string BLayout,
                     string DType, NVPTXRegClass d_reg,
                     string CType, NVPTXRegClass c_reg,
                     NVPTXRegClass ab_reg,
                     string Satfinite = "">
  : NVPTXInst<!if(!eq(DType,"f16"),
                  (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3),
                  (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3,
                        d_reg:$d4, d_reg:$d5, d_reg:$d6, d_reg:$d7)),
              !if(!eq(CType,"f16"),
                  (ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
                       ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
                       ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
                       ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
                       c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3),
                  (ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
                       ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
                       ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
                       ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
                       c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3,
                       c_reg:$c4, c_reg:$c5, c_reg:$c6, c_reg:$c7)),
              "wmma.mma.sync."#ALayout#"."#BLayout#".m16n16k16."#
                #DType#"."#CType#Satfinite
                #"\n\t\t"
                #!if(!eq(DType,"f16"),
                     "{{$d0, $d1, $d2, $d3}}, \n\t\t",
                     "{{$d0, $d1, $d2, $d3, $d4, $d5, $d6, $d7}},\n\t\t")
                #"{{$a0, $a1, $a2, $a3, $a4, $a5, $a6, $a7}},\n\t\t"
                #"{{$b0, $b1, $b2, $b3, $b4, $b5, $b6, $b7}},\n\t\t"
                #!if(!eq(CType,"f16"),
                     "{{$c0, $c1, $c2, $c3}};",
                     "{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};"),
              []>,
    Requires<[hasPTX60, hasSM70]>;

multiclass WMMA_MMA_ABDC<string ALayout, string BLayout,
                         string DType, NVPTXRegClass d_reg,
                         string CType, NVPTXRegClass c_reg> {
  def _satfinite: WMMA_MMA_ABDCS<ALayout, BLayout,
                                 DType, d_reg, CType, c_reg,
                                 Float16x2Regs, ".satfinite">;
  def NAME: WMMA_MMA_ABDCS<ALayout, BLayout,
                           DType, d_reg, CType, c_reg,
                           Float16x2Regs>;
}

multiclass WMMA_MMA_ABD<string ALayout, string BLayout,
                        string DType, NVPTXRegClass d_reg> {
  defm _f16: WMMA_MMA_ABDC<ALayout, BLayout, DType, d_reg, "f16", Float16x2Regs>;
  defm _f32: WMMA_MMA_ABDC<ALayout, BLayout, DType, d_reg, "f32", Float32Regs>;
}

multiclass WMMA_MMA_AB<string ALayout, string BLayout> {
  defm _f16: WMMA_MMA_ABD<ALayout, BLayout, "f16", Float16x2Regs>;
  defm _f32: WMMA_MMA_ABD<ALayout, BLayout, "f32", Float32Regs>;
}

multiclass WMMA_MMA_A<string ALayout> {
  defm _col: WMMA_MMA_AB<ALayout, "col">;
  defm _row: WMMA_MMA_AB<ALayout, "row">;
}

defm INT_WMMA_MMA_col: WMMA_MMA_A<"col">;
defm INT_WMMA_MMA_row: WMMA_MMA_A<"row">;

test/CodeGen/NVPTX/wmma.py (new file, 201 lines)
@@ -0,0 +1,201 @@
# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them.

# RUN: python %s > %t.ll
# RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 | FileCheck %t.ll

from itertools import product
from string import Template

def make_wmma_slice_ty(abcd, itype):
  elt_ty = "<2 x half>" if itype == "f16" else "float"
  num_elts = 4 if abcd in "cd" and itype == "f16" else 8;
  return [elt_ty] * num_elts

def make_wmma_ld_ret_ty(abc, itype):
  return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype))

# Convenient test patterns.
check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8)
check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)

def gen_wmma_load_tests():
  load_template = """
declare ${ret_ty} @llvm.nvvm.wmma.load.$intrinsic_suffix(i8* %src ${extra_args});

; CHECK-LABEL: .func {{.*}}test_wmma_load_${function_suffix}(
define ${ret_ty} @test_wmma_load_${function_suffix}(i8* %src ${extra_args}) {
; CHECK wmma.load.${intrinsic_suffix}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
  %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src ${extra_args});
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_wmma_load_${function_suffix}_o(
define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8* %src ${extra_args}) {
; CHECK wmma.load.${intrinsic_suffix}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
  %src1 = getelementptr i8, i8* %src, i32 128;
  %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src1 ${extra_args});
  ret ${ret_ty} %v0;
}
"""
  suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
  instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"

  for abc, layout, space, stride, itype in product(
      "abc",
      ["row","col"],
      ["",".shared",".global"],
      ["", ".stride"],
      ["f16", "f32"]):

    params = {
        "abc" : abc,
        "layout" : layout,
        "space" : space,
        "stride" : stride,
        "itype" : itype
    }

    if itype == "f32" and abc != "c":
      continue

    test_params = params
    test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
    test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".","_")
    test_params["instruction_suffix"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
    if abc == "c" :
      test_params["check_result"] = check_f16_4 if itype == "f16" else check_f32_8
    else:
      test_params["check_result"] = check_f16_8

    if stride:
      test_params["extra_args"] = ", i32 %stride";
      test_params["stride_pattern"] = ", %r{{[0-9]+}}"
    else:
      test_params["extra_args"] = ""
      test_params["stride_pattern"] = ""

    print(Template(load_template).substitute(test_params))

def make_wmma_slice_args(itype, abcd, prefix="v"):
  return ", ".join(["%s %%%s%d" % (t, prefix, i) for i,t
                    in enumerate(make_wmma_slice_ty(abcd, itype))])

def gen_wmma_store_tests():
  store_template = """
declare void @llvm.nvvm.wmma.store.$intrinsic_suffix(i8* %src, ${args}${extra_args});

; CHECK-LABEL: .func {{.*}}test_wmma_store_${function_suffix}(
define void @test_wmma_store_${function_suffix}(i8* %src, ${args}${extra_args}) {
; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src, ${args} ${extra_args});
  ret void
}

; CHECK-LABEL: .func{{.*}}test_wmma_store_${function_suffix}_o(
define void @test_wmma_store_${function_suffix}_o(i8* %src, ${args}${extra_args}) {
; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
  %src1 = getelementptr i8, i8* %src, i32 128;
  call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src1, ${args}${extra_args});
  ret void
}
"""
  suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
  instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"

  for abc, layout, space, stride, itype in product(
      "d",
      ["row","col"],
      ["",".shared",".global"],
      ["", ".stride"],
      ["f16", "f32"]):

    params = {
        "abc" : abc,
        "layout" : layout,
        "space" : space,
        "stride" : stride,
        "itype" : itype
    }

    test_params = params
    test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
    test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".","_")
    test_params["instruction_suffix"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
    test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8
    if stride:
      test_params["extra_args"] = ", i32 %stride";
      test_params["stride_pattern"] = ", %r{{[0-9]+}};"
    else:
      test_params["extra_args"] = ""
      test_params["stride_pattern"] = ";"
    test_params["args"] = make_wmma_slice_args(itype, "d");

    print(Template(store_template).substitute(test_params))

def gen_wmma_mma_tests():
  mma_template = """
declare ${ret_ty} @llvm.nvvm.wmma.mma.sync.$intrinsic_suffix(
        ${args});

; CHECK-LABEL: .func {{.*}}test_wmma_mma_${function_suffix}(
define ${ret_ty} @test_wmma_mma_${function_suffix}(
        ${args}) {
; CHECK wmma.mma.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}
; CHECK ${check_d}
; CHECK ${check_ab}
; CHECK ${check_ab}
; CHECK ${check_c}
  %r = call ${ret_ty} @llvm.nvvm.wmma.mma.sync.${intrinsic_suffix}(
        ${args});
  ret ${ret_ty} %r;
}
"""
  suffix_template = "${alayout}.${blayout}.m16n16k16.${dtype}.${ctype}${satf}"

  for alayout, blayout, ctype, dtype, satf in product(
      ["row","col"],
      ["row","col"],
      ["f16", "f32"],
      ["f16", "f32"],
      [".satfinite", ""]):

    params = {
        "alayout" : alayout,
        "blayout" : blayout,
        "ctype" : ctype,
        "dtype" : dtype,
        "satf" : satf
    }

    test_params = params
    test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
    test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".", "_")
    test_params["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)
    test_params["check_ab"] = check_f16_8
    test_params["check_c"] = check_f16_4 if ctype == "f16" else check_f32_8
    test_params["check_d"] = check_f16_4 if dtype == "f16" else check_f32_8
    args = ",\n ".join(make_wmma_slice_args(t, abcd, prefix=abcd)
                       for abcd, t in (("a", "f16"),
                                       ("b", "f16"),
                                       ("c", ctype)))
    test_params["args"] = args
    print(Template(mma_template).substitute(test_params))

def main():
  gen_wmma_load_tests()
  gen_wmma_store_tests()
  gen_wmma_mma_tests()

main()
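For reference, this is roughly what one instantiation of store_template above expands to (the d/row, no address-space suffix, no-stride, f16 variant), with the FileCheck annotations omitted; it is a sketch of the generated output, not an extra test:

declare void @llvm.nvvm.wmma.store.d.sync.row.m16n16k16.f16(
    i8* %src, <2 x half> %v0, <2 x half> %v1, <2 x half> %v2, <2 x half> %v3);

define void @test_wmma_store_d_sync_row_m16n16k16_f16(
    i8* %src, <2 x half> %v0, <2 x half> %v1, <2 x half> %v2, <2 x half> %v3) {
  call void @llvm.nvvm.wmma.store.d.sync.row.m16n16k16.f16(
      i8* %src, <2 x half> %v0, <2 x half> %v1, <2 x half> %v2, <2 x half> %v3);
  ret void
}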