
[Clang][NVPTX] Add NVPTX intrinsics and builtins for CUDA PTX 6.5 and 7.0 WMMA and MMA instructions

Adds NVPTX builtins and intrinsics for the CUDA PTX `wmma.load`, `wmma.store`, `wmma.mma`, and `mma` instruction variants introduced in PTX 6.5 and 7.0.

PTX ISA description of

  - `wmma.load`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-wmma-ld
  - `wmma.store`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-wmma-st
  - `wmma.mma`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-wmma-mma
  - `mma`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma

Overview of `wmma.mma` and `mma` matrix shape/type combinations added with specific PTX versions: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-shape
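
For illustration only (not part of this patch), a minimal Python sketch of how the new `llvm.nvvm.mma.*` intrinsic names are put together — geometry, A/B layouts, an optional `.satfinite`, and a type signature — mirroring the `MMA_NAME`/`MMA_SIGNATURE` TableGen classes and the `mma_signature()` helper in the test generator below; the function name here is made up for the example:

```python
def mma_intrinsic_name(geom, alayout, blayout, satf, a, b, c, d):
    """Assemble an llvm.nvvm.mma.* name from fragment element types."""
    if a == "f16":
        sig = "%s.%s" % (d, c)   # FP16 ops are named by D and C types
    elif a != b:
        sig = "%s.%s" % (a, b)   # mixed-signedness int ops: A and B types
    else:
        sig = a                  # otherwise a single input type suffices
    return "llvm.nvvm.mma.%s.%s.%s%s.%s" % (
        geom, alayout, blayout, ".satfinite" if satf else "", sig)

print(mma_intrinsic_name("m16n8k16", "row", "col", True, "s8", "u8", "s32", "s32"))
# llvm.nvvm.mma.m16n8k16.row.col.satfinite.s8.u8
print(mma_intrinsic_name("m16n8k8", "row", "col", False, "tf32", "tf32", "f32", "f32"))
# llvm.nvvm.mma.m16n8k8.row.col.tf32
```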

Authored-by: Steffen Larsen <steffen.larsen@codeplay.com>
Co-Authored-by: Stuart Adams <stuart.adams@codeplay.com>

Reviewed By: tra

Differential Revision: https://reviews.llvm.org/D104847
Steffen Larsen 2021-06-28 15:43:10 -07:00 committed by Artem Belevich
parent 22a7347912
commit e40bb79a12
6 changed files with 913 additions and 255 deletions

View File

@ -52,13 +52,27 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
string gft = Geom#":"#Frag#":"#ptx_elt_type;
string ft = frag#":"#ptx_elt_type;
list<LLVMType> regs = !cond(
// mma.sync.m8n8k4 uses smaller a/b fragments than wmma fp ops
// mma fp ops use smaller fragments than wmma fp ops
!eq(gft,"m8n8k4:a:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m8n8k4:b:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k8:a:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k8:b:f16") : [llvm_v2f16_ty],
!eq(gft,"m16n8k8:c:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k8:d:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k8:c:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k8:d:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k16:a:f16") : !listsplat(llvm_v2f16_ty, 4),
!eq(gft,"m16n8k16:b:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k16:c:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k16:d:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k16:c:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k16:d:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k4:c:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k4:d:f32") : !listsplat(llvm_float_ty, 4),
// fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
// All currently supported geometries use the same fragment format,
// so we only need to consider {fragment, type}.
// wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
// All other supported geometries use the same fragment format for f32 and
// f16, so we only need to consider {fragment, type}.
!eq(ft,"a:f16") : !listsplat(llvm_v2f16_ty, 8),
!eq(ft,"b:f16") : !listsplat(llvm_v2f16_ty, 8),
!eq(ft,"c:f16") : !listsplat(llvm_v2f16_ty, 4),
@ -66,7 +80,36 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
!eq(ft,"c:f32") : !listsplat(llvm_float_ty, 8),
!eq(ft,"d:f32") : !listsplat(llvm_float_ty, 8),
// u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
// wmma tf32 -> s32 @ m16n16k8
!eq(gft,"m16n16k8:a:tf32") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n16k8:b:tf32") : !listsplat(llvm_i32_ty, 4),
// mma tf32 -> s32 @ m16n16k8/m16n8k8
!eq(gft,"m16n8k4:a:tf32") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k4:b:tf32") : [llvm_i32_ty],
!eq(gft,"m16n8k8:a:tf32") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k8:b:tf32") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m8n8k4:a:f64") : [llvm_double_ty],
!eq(gft,"m8n8k4:b:f64") : [llvm_double_ty],
!eq(gft,"m8n8k4:c:f64") : !listsplat(llvm_double_ty, 2),
!eq(gft,"m8n8k4:d:f64") : !listsplat(llvm_double_ty, 2),
// wmma bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
!eq(gft,"m16n16k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n16k16:b:bf16") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m8n32k16:a:bf16") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m8n32k16:b:bf16") : !listsplat(llvm_i32_ty, 8),
!eq(gft,"m32n8k16:a:bf16") : !listsplat(llvm_i32_ty, 8),
!eq(gft,"m32n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
// mma bf16 -> s32 @ m16n8k16/m16n8k8
!eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty],
// wmma u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
!eq(gft,"m16n16k16:a:u8") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n16k16:a:s8") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n16k16:b:u8") : !listsplat(llvm_i32_ty, 2),
@ -88,17 +131,65 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
!eq(gft,"m32n8k16:c:s32") : !listsplat(llvm_i32_ty, 8),
!eq(gft,"m32n8k16:d:s32") : !listsplat(llvm_i32_ty, 8),
// u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
!eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
// mma u8/s8 -> s32 @ m8n8k16/m16n8k16/m16n8k32
!eq(gft,"m8n8k16:a:u8") : [llvm_i32_ty],
!eq(gft,"m8n8k16:a:s8") : [llvm_i32_ty],
!eq(gft,"m8n8k16:b:u8") : [llvm_i32_ty],
!eq(gft,"m8n8k16:b:s8") : [llvm_i32_ty],
!eq(gft,"m8n8k16:c:s32") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m8n8k16:d:s32") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k16:a:u8") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k16:a:s8") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k16:b:u8") : [llvm_i32_ty],
!eq(gft,"m16n8k16:b:s8") : [llvm_i32_ty],
!eq(gft,"m16n8k16:c:s32") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k16:d:s32") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k32:a:u8") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k32:a:s8") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k32:b:u8") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k32:b:s8") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k32:c:s32") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k32:d:s32") : !listsplat(llvm_i32_ty, 4),
// wmma/mma u4/s4 -> s32 @ m8n8k32 (u4/s4)
!eq(gft,"m8n8k32:a:u4") : [llvm_i32_ty],
!eq(gft,"m8n8k32:a:s4") : [llvm_i32_ty],
!eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
!eq(gft,"m8n8k32:b:u4") : [llvm_i32_ty],
!eq(gft,"m8n8k32:b:s4") : [llvm_i32_ty],
!eq(gft,"m8n8k128:c:s32") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m8n8k128:d:s32") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m8n8k32:c:s32") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m8n8k32:d:s32") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k32:a:u4") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k32:a:s4") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k32:b:u4") : [llvm_i32_ty],
!eq(gft,"m16n8k32:b:s4") : [llvm_i32_ty],
!eq(gft,"m16n8k32:c:s32") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k32:d:s32") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k64:a:u4") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k64:a:s4") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k64:b:u4") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k64:b:s4") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k64:c:s32") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k64:d:s32") : !listsplat(llvm_i32_ty, 4),
// wmma/mma b1 -> s32 @ m8n8k128(b1)
!eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
!eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
!eq(gft,"m8n8k128:c:s32") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m8n8k128:d:s32") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k128:a:b1") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k128:b:b1") : [llvm_i32_ty],
!eq(gft,"m16n8k128:c:s32") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k128:d:s32") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k256:a:b1") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k256:b:b1") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k256:c:s32") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
);
}
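
To make the fragment table above concrete, here is a small illustrative Python sketch (not part of the patch) of what two of the newly added entries mean per thread; the register counts are copied from the `WMMA_REGS` entries above, and the dictionary/helper names are invented for the example:

```python
# Per-fragment register counts for two of the new shapes.
frag_regs = {
    # mma.m16n8k16 with f16 A/B/C/D: number of <2 x half> registers
    ("m16n8k16", "a", "f16"): 4,
    ("m16n8k16", "b", "f16"): 2,
    ("m16n8k16", "c", "f16"): 2,
    ("m16n8k16", "d", "f16"): 2,
    # mma.m8n8k4 with f64: number of double registers
    ("m8n8k4", "a", "f64"): 1,
    ("m8n8k4", "b", "f64"): 1,
    ("m8n8k4", "c", "f64"): 2,
    ("m8n8k4", "d", "f64"): 2,
}

def num_inputs(geom, elt):
    """The intrinsic's inputs are the concatenated A, B and C fragments."""
    return sum(frag_regs[(geom, f, elt)] for f in "abc")

assert num_inputs("m16n8k16", "f16") == 8  # 4 + 2 + 2 registers
assert num_inputs("m8n8k4", "f64") == 4    # 1 + 1 + 2 registers
```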
@ -125,39 +216,44 @@ class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
list<WMMA_REGS> id_frags = !cond(
// int and sub-int ops are identified by input type.
!eq(A.ptx_elt_type, "s8") : [A],
!eq(A.ptx_elt_type, "u8") : [A],
!eq(A.ptx_elt_type, "s4") : [A],
!eq(A.ptx_elt_type, "u4") : [A],
!eq(A.ptx_elt_type, "b1") : [A],
// the rest are FP ops identified by accumulator & result type.
true: [D, C]
// FP16 ops are identified by accumulator & result type.
!eq(A.ptx_elt_type, "f16") : [D, C],
// other ops are identified by input types.
!ne(A.ptx_elt_type, B.ptx_elt_type): [A, B],
true: [A]
);
string ret = !foldl("", id_frags, a, b, !strconcat(a, ".", b.ptx_elt_type));
}
class WMMA_NAME_MMA<string ALayout, string BLayout, int Satfinite,
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd,
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
string llvm = !if(
!eq(A.geom, "m8n8k4"),
"llvm.nvvm.mma.m8n8k4"
# "." # ALayout
# "." # BLayout
# signature,
"llvm.nvvm.wmma."
# A.geom
# ".mma"
# "." # ALayout
# "." # BLayout
# signature
# !if(Satfinite, ".satfinite", ""));
string llvm = "llvm.nvvm.wmma."
# A.geom
# ".mma"
# "." # ALayout
# "." # BLayout
# !if(!ne(Rnd, ""), !strconcat(".", Rnd), "")
# signature
# !if(Satfinite, ".satfinite", "");
string record = !subst(".", "_",
!subst("llvm.", "int_", llvm));
}
class MMA_NAME<string ALayout, string BLayout, int Satfinite,
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
string llvm = "llvm.nvvm.mma."
# A.geom
# "." # ALayout
# "." # BLayout
# !if(Satfinite, ".satfinite", "")
# signature;
string record = !subst(".", "_",
!subst("llvm.", "int_", llvm));
}
// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
// Geom: list of supported geometries.
// TypeN: PTX type of the corresponding fragment's element.
@ -188,14 +284,18 @@ class MMA_LDST_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
list<string> ops = !foreach(x, ret, x.gft);
}
// Creates list of valid combinations of fragments. This is the master list that
// drives generation of corresponding intrinsics and instructions.
class NVVM_MMA_OPS<int _ = 0> {
list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
list<list<WMMA_REGS>> tf32_wmma_ops = MMA_OPS<
["m16n16k8"],
["tf32"], [], ["f32"], []>.ret;
list<list<WMMA_REGS>> bf16_wmma_ops = MMA_OPS<
["m16n16k16", "m32n8k16", "m8n32k16"],
["bf16"], [], ["f32"], []>.ret;
list<list<WMMA_REGS>> f64_wmma_ops = MMA_OPS<
["m8n8k4"],
["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
["f64"], [], ["f64"], []>.ret;
list<list<WMMA_REGS>> fp_wmma_ops = MMA_OPS<
["m16n16k16", "m32n8k16", "m8n32k16"],
["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
@ -208,16 +308,50 @@ class NVVM_MMA_OPS<int _ = 0> {
list<list<WMMA_REGS>> bit_wmma_ops = MMA_OPS<
["m8n8k128"],
["b1"], [], ["s32"], []>.ret;
list<list<WMMA_REGS>> all_wmma_ops = !listconcat(
tf32_wmma_ops, bf16_wmma_ops, f64_wmma_ops,
fp_wmma_ops, int_wmma_ops, subint_wmma_ops, bit_wmma_ops);
list<list<WMMA_REGS>> tf32_mma_ops = MMA_OPS<
["m16n8k4", "m16n8k8"],
["tf32"], [], ["f32"], []>.ret;
list<list<WMMA_REGS>> bf16_mma_ops = MMA_OPS<
["m16n8k16", "m16n8k8"],
["bf16"], [], ["f32"], []>.ret;
list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
["m8n8k4"],
["f64"], [], ["f64"], []>.ret;
list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
["m8n8k4", "m16n8k8", "m16n8k16"],
["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
["m8n8k16", "m16n8k16", "m16n8k32"],
["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
["m8n8k32", "m16n8k32", "m16n8k64"],
["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
list<list<WMMA_REGS>> bit_mma_ops = MMA_OPS<
["m8n8k128", "m16n8k128", "m16n8k256"],
["b1"], [], ["s32"], []>.ret;
list<list<WMMA_REGS>> all_mma_ops = !listconcat(
fp_mma_ops, fp_wmma_ops, int_wmma_ops,
subint_wmma_ops, bit_wmma_ops);
tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
["m16n16k16", "m32n8k16", "m8n32k16"],
["a", "b"], ["f16", "u8", "s8"]>.ret;
["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
list<WMMA_REGS> ldst_cd_ops = MMA_LDST_OPS<
["m16n16k16", "m32n8k16", "m8n32k16"],
["c", "d"], ["f16", "f32", "s32"]>.ret;
list<WMMA_REGS> ldst_tf32_ab_ops = MMA_LDST_OPS<
["m16n16k8"],
["a", "b"], ["tf32"]>.ret;
list<WMMA_REGS> ldst_tf32_cd_ops = MMA_LDST_OPS<
["m16n16k8"],
["c", "d"], ["f32"]>.ret;
list<WMMA_REGS> ldst_f64_abcd_ops = MMA_LDST_OPS<
["m8n8k4"],
["a", "b", "c", "d"], ["f64"]>.ret;
list<WMMA_REGS> ldst_subint_ab_ops = MMA_LDST_OPS<
["m8n8k32"], ["a", "b"], ["s4","u4"]>.ret;
list<WMMA_REGS> ldst_bit_ab_ops = MMA_LDST_OPS<
@ -225,6 +359,9 @@ class NVVM_MMA_OPS<int _ = 0> {
list<WMMA_REGS> ldst_subint_cd_ops = MMA_LDST_OPS<
["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]>.ret;
list<WMMA_REGS> all_ldst_ops = !listconcat(ldst_ab_ops, ldst_cd_ops,
ldst_tf32_ab_ops,
ldst_tf32_cd_ops,
ldst_f64_abcd_ops,
ldst_subint_ab_ops,
ldst_bit_ab_ops,
ldst_subint_cd_ops);
@ -235,69 +372,110 @@ class NVVM_MMA_OPS<int _ = 0> {
def NVVM_MMA_OPS : NVVM_MMA_OPS;
// Returns true if this combination of layout/satf is supported; false otherwise.
// MMA ops must provide all parameters. Loads and stores -- only frags and layout_a.
// The class is used to prevent generation of records for the unsupported variants.
// Returns true if this combination of fragment and layout for WMMA load/store
// ops is supported; false otherwise.
// E.g.
// if NVVM_WMMA_LDST_SUPPORTED<...>.ret then
// def : FOO<>; // The record will only be defined for supported ops.
//
class NVVM_WMMA_LDST_SUPPORTED<WMMA_REGS frag, string layout> {
string f = frag.frag;
string t = frag.ptx_elt_type;
bit ret = !cond(
// Sub-int load and store requires A fragment to be of row layout and B
// fragments to be of column layout.
!and(!or(!eq(t, "b1"),
!eq(t, "u4"),
!eq(t, "s4")),
!or(!and(!eq(f, "a"),
!ne(layout, "row")),
!and(!eq(f, "b"),
!ne(layout, "col")))) : false,
true: true
);
}
// Returns true if this combination of layout/satf/rnd for WMMA ops is
// supported; false otherwise.
// E.g.
// if NVVM_WMMA_SUPPORTED<...>.ret then
// def : FOO<>; // The record will only be defined for supported ops.
//
class NVVM_WMMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf, string rnd> {
// WMMA ops check both layouts.
string layout = layout_a # ":" # layout_b;
string t = frags[0].ptx_elt_type;
bit ret = !cond(
// only f64 wmma functions support rnd options
// any non f64 type that uses a rnd value is invalid
!and(!ne(t, "f64"), !ne(rnd, "")) : false,
// satf is only valid for select types
!and(!eq(satf, 1),
!ne(t, "s8"),
!ne(t, "u8"),
!ne(t, "s4"),
!ne(t, "u4"),
!ne(t, "f16")): false,
// Sub-int wmma requires row/column layout
!and(!or(!eq(t, "s4"),
!eq(t, "u4"),
!eq(t, "b1")),
!ne(layout, "row:col")) : false,
true: true
);
}
// Returns true if this combination of layout/satf for MMA ops is supported;
// false otherwise.
// E.g.
// if NVVM_MMA_SUPPORTED<...>.ret then
// def : FOO<>; // The record will only be defined for supported ops.
//
class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b="-", int satf=-1> {
class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
// MMA ops check both layouts.
string mma = frags[0].ptx_elt_type
# ":" # layout_a
# ":" # layout_b;
// Load ops only need type/fragment/layout.
string ld = frags[0].ptx_elt_type
# ":" # frags[0].frag
# ":" # layout_a
;
string ldf = frags[0].ptx_elt_type
# ":" # frags[0].frag
;
string t = frags[0].ptx_elt_type;
string layout = layout_a # ":" # layout_b;
string a_type = frags[0].ptx_elt_type;
string b_type = frags[1].ptx_elt_type;
string c_type = frags[2].ptx_elt_type;
string d_type = frags[3].ptx_elt_type;
string geom = frags[0].geom;
// gcd is a shortcut used to identify instructions that depend on
// geom+frag_c+frag_d. Not all instances of this class have all fragments
// specified. If there are not enough fragments, the tail evaluates to '?'.
string gcd = frags[0].geom
# ":"
# !if(!eq(!size(frags), 4),
frags[2].ptx_elt_type # frags[3].ptx_elt_type,
"?");
// geom+frag_c+frag_d.
string gcd = geom # ":" # c_type # d_type;
bit ret = !cond(
// Sub-int MMA only supports fixed A/B layout.
// b1 does not support .satf.
!eq(mma#":"#satf, "b1:row:col:0") : true,
// mma.m8n8k4 has no .satf modifier.
!and(!eq(frags[0].geom, "m8n8k4"),
!ne(satf, 0)): false,
// mma.m8n8k4 has no C=f32 D=f16 variant.
// Limit satf to valid types
!and(!eq(satf, 1),
!ne(a_type, "s8"),
!ne(a_type, "u8"),
!ne(a_type, "s4"),
!ne(a_type, "u4")): false,
// m8n8k4 has no C=f32 D=f16 variant.
!eq(gcd, "m8n8k4:f32f16"): false,
!eq(mma, "s4:row:col") : true,
!eq(mma, "u4:row:col") : true,
!eq(mma, "s4:row:col") : true,
!eq(mma, "u4:row:col") : true,
// Sub-int load/stores have fixed layout for A and B.
!and(!eq(layout_b, "-"), // It's a Load or Store op
!or(!eq(ld, "b1:a:row"),
!eq(ld, "b1:b:col"),
!eq(ldf, "b1:c"),
!eq(ldf, "b1:d"),
!eq(ld, "s4:a:row"),
!eq(ld, "s4:b:col"),
!eq(ldf, "s4:c"),
!eq(ldf, "s4:d"),
!eq(ld, "u4:a:row"),
!eq(ld, "u4:b:col"),
!eq(ldf, "u4:c"),
!eq(ldf, "u4:d"))) : true,
// All other sub-int ops are not supported.
!eq(t, "b1") : false,
!eq(t, "s4") : false,
!eq(t, "u4") : false,
// All other (non sub-int) are OK.
// only m8n8k4 for f16 does not require row:col layout
!and(!ne(layout, "row:col"),
!or(!ne(geom, "m8n8k4"),
!ne(a_type, "f16"))) : false,
// m16n8k8 requires A and B to be the same type and C and D to be the same
// type.
!and(!eq(geom, "m16n8k8"),
!or(!ne(a_type, b_type),
!ne(c_type, d_type))): false,
// m16n8k8 requires C and D to be the same type.
!and(!eq(geom, "m16n8k8"),
!ne(c_type, d_type)): false,
// All other are OK.
true: true
);
}
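
As a quick sanity illustration (not part of the patch), the two headline restrictions encoded above — `.satfinite` only for integer types, and fixed `row`/`col` layouts for everything except the f16 `m8n8k4` variant — can be sketched in Python; the checker below is deliberately simplified and hypothetical:

```python
def mma_variant_ok(geom, a_type, layout_a, layout_b, satf):
    # .satf is only valid for (sub-)integer A types.
    if satf and a_type not in ("s8", "u8", "s4", "u4"):
        return False
    # Only m8n8k4 with f16 accepts layouts other than row/col.
    if (layout_a, layout_b) != ("row", "col") and \
            not (geom == "m8n8k4" and a_type == "f16"):
        return False
    return True

assert mma_variant_ok("m16n8k16", "s8", "row", "col", satf=True)
assert not mma_variant_ok("m16n8k8", "f16", "row", "col", satf=True)    # no .satf for fp
assert not mma_variant_ok("m16n8k16", "bf16", "col", "row", satf=False) # layouts are fixed
assert mma_variant_ok("m8n8k4", "f16", "col", "row", satf=False)
```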
@ -4271,36 +4449,59 @@ class NVVM_WMMA_ST<WMMA_REGS Frag, string Layout, int WithStride>
foreach layout = ["row", "col"] in {
foreach stride = [0, 1] in {
foreach frag = NVVM_MMA_OPS.all_ld_ops in
if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
def WMMA_NAME_LDST<"load", frag, layout, stride>.record
: NVVM_WMMA_LD<frag, layout, stride>;
foreach frag = NVVM_MMA_OPS.all_st_ops in
if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
def WMMA_NAME_LDST<"store", frag, layout, stride>.record
: NVVM_WMMA_ST<frag, layout, stride>;
}
}
// WMMA.MMA
class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite,
class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite, string rnd,
WMMA_REGS A, WMMA_REGS B,
WMMA_REGS C, WMMA_REGS D>
: Intrinsic<D.regs,
!listconcat(A.regs, B.regs, C.regs),
[IntrNoMem],
WMMA_NAME_MMA<ALayout, BLayout, Satfinite, A, B, C, D>.llvm>;
WMMA_NAME<ALayout, BLayout, Satfinite, rnd, A, B, C, D>.llvm>;
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
foreach satf = [0, 1] in {
foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
foreach op = NVVM_MMA_OPS.all_wmma_ops in {
if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
def WMMA_NAME<layout_a, layout_b, satf, rnd,
op[0], op[1], op[2], op[3]>.record
: NVVM_WMMA_MMA<layout_a, layout_b, satf, rnd,
op[0], op[1], op[2], op[3]>;
}
} // op
} // rnd
} // satf
} // layout_b
} // layout_a
// MMA
class NVVM_MMA<string ALayout, string BLayout, int Satfinite,
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
: Intrinsic<D.regs,
!listconcat(A.regs, B.regs, C.regs),
[IntrNoMem],
MMA_NAME<ALayout, BLayout, Satfinite, A, B, C, D>.llvm>;
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
foreach satf = [0, 1] in {
foreach op = NVVM_MMA_OPS.all_mma_ops in {
if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
def WMMA_NAME_MMA<layout_a, layout_b, satf,
op[0], op[1], op[2], op[3]>.record
: NVVM_WMMA_MMA<layout_a, layout_b, satf,
op[0], op[1], op[2], op[3]>;
def MMA_NAME<layout_a, layout_b, satf, op[0], op[1], op[2], op[3]>.record
: NVVM_MMA<layout_a, layout_b, satf, op[0], op[1], op[2], op[3]>;
}
}
} // op
} // satf
} // layout_b
} // layout_a

View File

@ -3490,6 +3490,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col:
case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride:
case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row:
case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
@ -3497,7 +3501,11 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row: {
case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col:
case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride:
case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row:
case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v2i32;
Info.ptrVal = I.getArgOperand(0);
@ -3515,6 +3523,14 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row:
case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride:
case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col:
case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride:
case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row:
case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride:
case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
@ -3523,7 +3539,15 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row: {
case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride:
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v4i32;
Info.ptrVal = I.getArgOperand(0);
@ -3603,7 +3627,11 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride: {
case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride:
case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col:
case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row:
case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride:
case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);
@ -3613,6 +3641,16 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
return true;
}
case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col:
case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride:
case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row:
case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride:
case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col:
case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride:
case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row:
case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
@ -3651,6 +3689,37 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
return true;
}
case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col:
case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride:
case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row:
case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride:
case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col:
case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride:
case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row:
case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::f64;
Info.ptrVal = I.getArgOperand(0);
Info.offset = 0;
Info.flags = MachineMemOperand::MOLoad;
Info.align = Align(8);
return true;
}
case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col:
case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride:
case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row:
case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v2f64;
Info.ptrVal = I.getArgOperand(0);
Info.offset = 0;
Info.flags = MachineMemOperand::MOLoad;
Info.align = Align(16);
return true;
}
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
@ -3683,7 +3752,11 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride: {
case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride:
case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col:
case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row:
case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride:
case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);
@ -3731,6 +3804,19 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
return true;
}
case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col:
case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride:
case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row:
case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v2f64;
Info.ptrVal = I.getArgOperand(0);
Info.offset = 0;
Info.flags = MachineMemOperand::MOStore;
Info.align = Align(16);
return true;
}
case Intrinsic::nvvm_atomic_load_inc_32:
case Intrinsic::nvvm_atomic_load_dec_32:

View File

@ -144,6 +144,7 @@ def hasPTX60 : Predicate<"Subtarget->getPTXVersion() >= 60">;
def hasPTX61 : Predicate<"Subtarget->getPTXVersion() >= 61">;
def hasPTX63 : Predicate<"Subtarget->getPTXVersion() >= 63">;
def hasPTX64 : Predicate<"Subtarget->getPTXVersion() >= 64">;
def hasPTX65 : Predicate<"Subtarget->getPTXVersion() >= 65">;
def hasPTX70 : Predicate<"Subtarget->getPTXVersion() >= 70">;
def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">;

View File

@ -1943,21 +1943,21 @@ multiclass VLDU_G_ELE_V2<string TyStr, NVPTXRegClass regclass> {
!strconcat("ldu.global.", TyStr), []>;
}
multiclass VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass> {
def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
regclass:$dst4), (ins Int32Regs:$src),
!strconcat("ldu.global.", TyStr), []>;
def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
regclass:$dst4), (ins Int64Regs:$src),
!strconcat("ldu.global.", TyStr), []>;
def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
regclass:$dst4), (ins MEMri:$src),
!strconcat("ldu.global.", TyStr), []>;
def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
regclass:$dst4), (ins MEMri64:$src),
!strconcat("ldu.global.", TyStr), []>;
def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
regclass:$dst4), (ins imemAny:$src),
!strconcat("ldu.global.", TyStr), []>;
}
@ -1997,7 +1997,7 @@ defm INT_PTX_LDU_G_v4f32_ELE
//-----------------------------------
// Support for ldg on sm_35 or later
//-----------------------------------
// Don't annotate ld.global.nc as mayLoad, because these loads go through the
@ -2045,7 +2045,7 @@ defm INT_PTX_LDG_GLOBAL_p64
// vector
// Elementized vector ldg
multiclass VLDG_G_ELE_V2<string TyStr, NVPTXRegClass regclass> {
def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2),
(ins Int32Regs:$src),
@ -2064,21 +2064,21 @@ multiclass VLDG_G_ELE_V2<string TyStr, NVPTXRegClass regclass> {
!strconcat("ld.global.nc.", TyStr), []>;
}
multiclass VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> {
def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
regclass:$dst4), (ins Int32Regs:$src),
!strconcat("ld.global.nc.", TyStr), []>;
def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
regclass:$dst4), (ins Int64Regs:$src),
!strconcat("ld.global.nc.", TyStr), []>;
def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
regclass:$dst4), (ins MEMri:$src),
!strconcat("ld.global.nc.", TyStr), []>;
def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
regclass:$dst4), (ins MEMri64:$src),
!strconcat("ld.global.nc.", TyStr), []>;
def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
regclass:$dst4), (ins imemAny:$src),
!strconcat("ld.global.nc.", TyStr), []>;
}
@ -7568,12 +7568,15 @@ def INT_PTX_SREG_WARPSIZE :
// In addition to target-independent fields provided by WMMA_REGS, it adds
// the fields commonly used to implement specific PTX instruction -- register
// types and names, constraints, parts of assembly, etc.
class WMMA_REGINFO<WMMA_REGS r>
class WMMA_REGINFO<WMMA_REGS r, string op>
: WMMA_REGS<r.geom, r.frag, r.ptx_elt_type> {
// NVPTX register types used to carry fragment data.
NVPTXRegClass regclass = !cond(
!eq(ptx_elt_type, "f16") : Float16x2Regs,
!eq(ptx_elt_type, "f32") : Float32Regs,
!eq(ptx_elt_type, "f64") : Float64Regs,
!eq(ptx_elt_type, "bf16") : Int32Regs,
!eq(ptx_elt_type, "tf32") : Int32Regs,
!eq(ptx_elt_type, "s32") : Int32Regs,
!eq(ptx_elt_type, "s8") : Int32Regs,
!eq(ptx_elt_type, "u8") : Int32Regs,
@ -7602,6 +7605,9 @@ class WMMA_REGINFO<WMMA_REGS r>
!or(!eq(ptx_elt_type, "f16"),
!eq(ptx_elt_type, "f32"))) : [hasSM70, hasPTX60],
!and(!eq(geom,"m8n8k4"),
!eq(ptx_elt_type, "f64")) : [hasSM80, hasPTX70],
// fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
!and(!or(!eq(geom, "m8n32k16"),
!eq(geom, "m32n8k16")),
@ -7616,11 +7622,46 @@ class WMMA_REGINFO<WMMA_REGS r>
!eq(ptx_elt_type, "s8"),
!eq(ptx_elt_type, "s32"))) : [hasSM72, hasPTX63],
// u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
!or(!eq(geom,"m8n8k128"),
!eq(geom,"m8n8k32")) : [hasSM75, hasPTX63],
!and(!or(!eq(geom,"m16n16k16"),
!eq(geom,"m8n32k16"),
!eq(geom,"m32n8k16")),
!eq(ptx_elt_type, "bf16")) : [hasSM80, hasPTX70],
!eq(geom, "m8n8k4") : [hasSM70, hasPTX64]);
!and(!eq(geom,"m16n16k8"),
!eq(ptx_elt_type, "tf32")) : [hasSM80, hasPTX70],
!and(!eq(geom,"m16n16k8"),
!eq(ptx_elt_type, "f32")) : [hasSM80, hasPTX70],
// b1 -> s32 @ m8n8k128(b1)
!and(!ne(op,"mma"),
!eq(geom,"m8n8k128")) : [hasSM75, hasPTX63],
// u4/s4 -> s32 @ m8n8k32 (u4/s4)
!and(!ne(op,"mma"),
!eq(geom,"m8n8k32")) : [hasSM75, hasPTX63],
!or(!eq(geom,"m16n8k8"),
!eq(geom,"m8n8k16")) : [hasSM75, hasPTX65],
!and(!ne(ptx_elt_type,"f64"),
!eq(geom, "m8n8k4")) : [hasSM70, hasPTX64],
// mma m8n8k32 requires higher PTX version
!and(!eq(op,"mma"),
!eq(geom,"m8n8k32")) : [hasSM75, hasPTX65],
!and(!eq(ptx_elt_type,"f64"),
!eq(geom, "m8n8k4")) : [hasSM80, hasPTX70],
!and(!eq(op,"mma"),
!or(!eq(geom, "m16n8k16"),
!eq(geom, "m16n8k4"),
!eq(geom, "m16n8k32"),
!eq(geom, "m16n8k64"),
!eq(geom, "m8n8k128"),
!eq(geom, "m16n8k128"),
!eq(geom, "m16n8k256"))) : [hasSM80, hasPTX70]);
// template DAGs for instruction inputs/output.
dag Outs = !dag(outs, ptx_regs, reg_names);
@ -7744,11 +7785,11 @@ defset list<WMMA_INSTR> MMA_LDSTs = {
foreach space = [".global", ".shared", ""] in {
foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
foreach frag = NVVM_MMA_OPS.all_ld_ops in
if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
def : WMMA_LOAD<WMMA_REGINFO<frag>, layout, space, stride, addr>;
if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
def : WMMA_LOAD<WMMA_REGINFO<frag, "load">, layout, space, stride, addr>;
foreach frag = NVVM_MMA_OPS.all_st_ops in
if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
def : WMMA_STORE_D<WMMA_REGINFO<frag>, layout, space, stride, addr>;
if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
def : WMMA_STORE_D<WMMA_REGINFO<frag, "store">, layout, space, stride, addr>;
} // addr
} // space
} // stride
@ -7758,46 +7799,84 @@ defset list<WMMA_INSTR> MMA_LDSTs = {
// WMMA.MMA
class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
string ALayout, string BLayout, int Satfinite>
: WMMA_INSTR<WMMA_NAME_MMA<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record,
[FragA.Ins, FragB.Ins, FragC.Ins]>,
string ALayout, string BLayout, int Satfinite, string rnd>
: WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, FragA, FragB, FragC, FragD>.record,
[FragA.Ins, FragB.Ins, FragC.Ins]>,
// Requires does not seem to have effect on Instruction w/o Patterns.
// We set it here anyways and propagate to the Pat<> we construct below.
Requires<FragA.Predicates> {
let OutOperandList = FragD.Outs;
let InOperandList = !con(Args, (ins MmaCode:$ptx));
string TypeList = !cond(
!eq(FragD.geom, "m8n8k4") : "." # FragD.ptx_elt_type
# ".f16.f16."
# FragC.ptx_elt_type,
!eq(FragD.ptx_elt_type, "s32") : ".s32"
# "." # FragA.ptx_elt_type
# "." # FragB.ptx_elt_type
# ".s32",
1: "." # FragD.ptx_elt_type # "." # FragC.ptx_elt_type,
!eq(FragA.ptx_elt_type, "f16") : "." # FragD.ptx_elt_type
# "." # FragC.ptx_elt_type,
1: "." # FragD.ptx_elt_type
# "." # FragA.ptx_elt_type
# "." # FragB.ptx_elt_type
# "." # FragC.ptx_elt_type,
);
let AsmString = !if(!eq(FragA.geom, "m8n8k4"),
"mma.sync.aligned.m8n8k4"
# "." # ALayout
# "." # BLayout
# TypeList # "\n\t\t"
# FragD.regstring # ",\n\t\t"
# FragA.regstring # ",\n\t\t"
# FragB.regstring # ",\n\t\t"
# FragC.regstring # ";",
"wmma.mma"
# !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
# ".sync"
# "${ptx:aligned}"
# "." # ALayout
# "." # BLayout
# "." # FragA.geom
# TypeList
# !if(Satfinite, ".satfinite", "") # "\n\t\t"
# FragD.regstring # ",\n\t\t"
# FragA.regstring # ",\n\t\t"
# FragB.regstring # ",\n\t\t"
# FragC.regstring # ";");
let AsmString = "wmma.mma"
# !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
# ".sync"
# "${ptx:aligned}"
# "." # ALayout
# "." # BLayout
# "." # FragA.geom
# !if(!ne(rnd, ""), !strconcat(".", rnd), "")
# TypeList
# !if(Satfinite, ".satfinite", "") # "\n\t\t"
# FragD.regstring # ",\n\t\t"
# FragA.regstring # ",\n\t\t"
# FragB.regstring # ",\n\t\t"
# FragC.regstring # ";";
}
defset list<WMMA_INSTR> WMMAs = {
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
foreach satf = [0, 1] in {
foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
foreach op = NVVM_MMA_OPS.all_wmma_ops in {
if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
def : WMMA_MMA<WMMA_REGINFO<op[0], "wmma.mma">,
WMMA_REGINFO<op[1], "wmma.mma">,
WMMA_REGINFO<op[2], "wmma.mma">,
WMMA_REGINFO<op[3], "wmma.mma">,
layout_a, layout_b, satf, rnd>;
}
} // op
} // rnd
} // satf
} // layout_b
} // layout_a
} // defset
// MMA
class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
string ALayout, string BLayout, int Satfinite>
: WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record,
[FragA.Ins, FragB.Ins, FragC.Ins]>,
// Requires does not seem to have effect on Instruction w/o Patterns.
// We set it here anyways and propagate to the Pat<> we construct below.
Requires<FragA.Predicates> {
let OutOperandList = FragD.Outs;
let InOperandList = !con(Args, (ins MmaCode:$ptx));
string TypeList = "." # FragD.ptx_elt_type
# "." # FragA.ptx_elt_type
# "." # FragB.ptx_elt_type
# "." # FragC.ptx_elt_type;
let AsmString = "mma.sync.aligned."
# FragA.geom
# "." # ALayout
# "." # BLayout
# !if(Satfinite, ".satfinite", "")
# TypeList
# !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "") # "\n\t\t"
# FragD.regstring # ",\n\t\t"
# FragA.regstring # ",\n\t\t"
# FragB.regstring # ",\n\t\t"
# FragC.regstring # ";";
}
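
For reference, a rough Python sketch (not part of the patch) of the mnemonic the `AsmString` above expands to, shown for the integer `m16n8k16` case with `.satfinite`; the helper is hypothetical:

```python
def mma_mnemonic(geom, alayout, blayout, satf, a, b, c, d):
    type_list = ".%s.%s.%s.%s" % (d, a, b, c)  # D.A.B.C, as in TypeList above
    return ("mma.sync.aligned." + geom
            + "." + alayout + "." + blayout
            + (".satfinite" if satf else "")
            + type_list
            + (".xor.popc" if a == "b1" else ""))

print(mma_mnemonic("m16n8k16", "row", "col", True, "s8", "u8", "s32", "s32"))
# mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.u8.s32
```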
defset list<WMMA_INSTR> MMAs = {
@ -7806,11 +7885,11 @@ defset list<WMMA_INSTR> MMAs = {
foreach satf = [0, 1] in {
foreach op = NVVM_MMA_OPS.all_mma_ops in {
if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
def : WMMA_MMA<WMMA_REGINFO<op[0]>,
WMMA_REGINFO<op[1]>,
WMMA_REGINFO<op[2]>,
WMMA_REGINFO<op[3]>,
layout_a, layout_b, satf>;
def : MMA<WMMA_REGINFO<op[0], "mma">,
WMMA_REGINFO<op[1], "mma">,
WMMA_REGINFO<op[2], "mma">,
WMMA_REGINFO<op[3], "mma">,
layout_a, layout_b, satf>;
}
} // op
} // satf
@ -7822,12 +7901,12 @@ defset list<WMMA_INSTR> MMAs = {
// Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
// dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
// the instruction record.
class WMMA_PAT<WMMA_INSTR wi>
class MMA_PAT<WMMA_INSTR wi>
: Pat<wi.IntrinsicPattern,
!con(!foreach(tmp, wi.Args, !subst(ins, wi, tmp)),
(wi ptx.version))>,
Requires<wi.Predicates>;
// Build intrinsic->instruction patterns for all MMA instructions.
foreach mma = !listconcat(MMAs, MMA_LDSTs) in
def : WMMA_PAT<mma>;
foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs) in
def : MMA_PAT<mma>;

View File

@ -1,2 +1,3 @@
if not 'NVPTX' in config.root.targets:
config.unsupported = True
config.suffixes.add('.py')

View File

@ -6,7 +6,7 @@
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA
# RUN: --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
# RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
# RUN: | FileCheck %t-ptx60-sm_70.ll
@ -15,7 +15,7 @@
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA
# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
# RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
# RUN: | FileCheck %t-ptx61-sm_70.ll
@ -24,7 +24,7 @@
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN: --check-prefixes=INTRINSICS,NOSUBINT,NOMMA
# RUN: --check-prefixes=INTRINSICS,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
# RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
# RUN: | FileCheck %t-ptx63-sm_72.ll
@ -33,7 +33,7 @@
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN: --check-prefixes=INTRINSICS,NOMMA
# RUN: --check-prefixes=INTRINSICS,NOMMA,NODOUBLE,NOALTFLOAT
# RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
# RUN: | FileCheck %t-ptx63-sm_75.ll
@ -42,10 +42,28 @@
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,MMA
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT
# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NODOUBLE,NOALTFLOAT
# RUN: llc < %t-ptx64-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 \
# RUN: | FileCheck %t-ptx64-sm_70.ll
# Check all variants of instructions supported by PTX65 on SM75+
# RUN: python %s --ptx=65 --gpu-arch=75 > %t-ptx65-sm_75.ll
# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,PTX65MMA
# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
# RUN: --check-prefixes=INTRINSICS
# RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
# RUN: | FileCheck %t-ptx65-sm_75.ll
# Check all variants of instructions supported by PTX70 on SM80+
# RUN: python %s --ptx=70 --gpu-arch=80 > %t-ptx70-sm_80.ll
# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX70MMA
# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \
# RUN: --check-prefixes=INTRINSICS
# RUN: llc < %t-ptx70-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 \
# RUN: | FileCheck %t-ptx70-sm_80.ll
from __future__ import print_function
import argparse
@ -56,19 +74,23 @@ class MMAType:
def __init__(self, ptx_type):
self.ptx_type = ptx_type
self.llvm_type = {
"f16" : "<2 x half>",
"f32" : "float",
"s32" : "i32",
"s8" : "i32",
"u8" : "i32",
"s4" : "i32",
"u4" : "i32",
"b1" : "i32",
"f16" : "<2 x half>",
"f32" : "float",
"f64" : "double",
"s32" : "i32",
"s8" : "i32",
"u8" : "i32",
"s4" : "i32",
"u4" : "i32",
"b1" : "i32",
"bf16" : "i32",
"tf32" : "i32",
}[ptx_type];
self.ptx_reg_pattern = {
"f16" : "%hh[0-9]+",
"f32" : "%f[0-9]+",
"f64" : "%fd[0-9]+",
}.get(ptx_type, "%r[0-9]+")
def __repr__(self):
@ -78,16 +100,8 @@ class MMAFrag:
def __init__(self, geom, frag, ptx_elt_type):
self.geom = geom
self.frag = frag
self.is_mma = True if geom == "m8n8k4" else False;
self.mma_type = MMAType(ptx_elt_type);
self.nregs = {
"a:f16" : 2 if self.is_mma else 8,
"b:f16" : 2 if self.is_mma else 8,
"c:f16" : 4,
"d:f16" : 4,
"c:f32" : 8,
"d:f32" : 8,
}.get("%s:%s" % (frag, ptx_elt_type), {
# u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
"m16n16k16:a:u8" : 2,
"m16n16k16:a:s8" : 2,
@ -110,18 +124,123 @@ class MMAFrag:
"m32n8k16:c:s32" : 8,
"m32n8k16:d:s32" : 8,
# u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
"m8n8k128:a:b1" : 1,
"m8n8k16:a:u8": 1,
"m8n8k16:a:s8": 1,
"m8n8k16:b:u8": 1,
"m8n8k16:b:s8": 1,
"m8n8k16:c:s32": 2,
"m8n8k16:d:s32": 2,
"m16n8k16:a:u8": 2,
"m16n8k16:a:s8": 2,
"m16n8k16:b:u8": 1,
"m16n8k16:b:s8": 1,
"m16n8k16:c:s32": 4,
"m16n8k16:d:s32": 4,
"m16n8k32:a:u8": 4,
"m16n8k32:a:s8": 4,
"m16n8k32:b:u8": 2,
"m16n8k32:b:s8": 2,
"m16n8k32:c:s32": 4,
"m16n8k32:d:s32": 4,
# u4/s4 -> s32 @ m8n8k32 (u4/s4)
"m8n8k32:a:u4" : 1,
"m8n8k32:a:s4" : 1,
"m8n8k128:b:b1" : 1,
"m8n8k32:b:u4" : 1,
"m8n8k32:b:s4" : 1,
"m8n8k128:c:s32" : 2,
"m8n8k128:d:s32" : 2,
"m8n8k32:c:s32" : 2,
"m8n8k32:d:s32" : 2,
}.get("%s:%s:%s" % (geom, frag, ptx_elt_type), None));
"m16n8k32:a:u4" : 2,
"m16n8k32:a:s4" : 2,
"m16n8k32:b:u4" : 1,
"m16n8k32:b:s4" : 1,
"m16n8k32:c:s32" : 4,
"m16n8k32:d:s32" : 4,
"m16n8k64:a:u4" : 4,
"m16n8k64:a:s4" : 4,
"m16n8k64:b:u4" : 2,
"m16n8k64:b:s4" : 2,
"m16n8k64:c:s32" : 4,
"m16n8k64:d:s32" : 4,
# b1 -> s32 @ m8n8k128(b1)
"m8n8k128:a:b1" : 1,
"m8n8k128:b:b1" : 1,
"m8n8k128:c:s32" : 2,
"m8n8k128:d:s32" : 2,
"m16n8k128:a:b1" : 2,
"m16n8k128:b:b1" : 1,
"m16n8k128:c:s32" : 4,
"m16n8k128:d:s32" : 4,
"m16n8k256:a:b1" : 4,
"m16n8k256:b:b1" : 2,
"m16n8k256:c:s32" : 4,
"m16n8k256:d:s32" : 4,
# bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
"m16n16k16:a:bf16" : 4,
"m16n16k16:b:bf16" : 4,
"m8n32k16:a:bf16" : 2,
"m8n32k16:b:bf16" : 8,
"m32n8k16:a:bf16" : 8,
"m32n8k16:b:bf16" : 2,
"m16n8k16:a:bf16" : 4,
"m16n8k16:b:bf16" : 2,
"m16n8k16:c:f32" : 4,
"m16n8k16:d:f32" : 4,
"m16n8k8:a:bf16" : 2,
"m16n8k8:b:bf16" : 1,
"m16n8k8:c:f32" : 4,
"m16n8k8:d:f32" : 4,
"m8n8k4:a:f64" : 1,
"m8n8k4:b:f64" : 1,
"m8n8k4:c:f64" : 2,
"m8n8k4:d:f64" : 2,
# tf32 -> s32 @ m16n16k8
"m16n16k8:a:tf32" : 4,
"m16n16k8:b:tf32" : 4,
"m16n8k4:a:tf32" : 2,
"m16n8k4:b:tf32" : 1,
"m16n8k4:c:f32" : 4,
"m16n8k4:d:f32" : 4,
"m16n8k8:a:tf32" : 4,
"m16n8k8:b:tf32" : 2,
"m16n8k8:c:f32" : 4,
"m16n8k8:d:f32" : 4,
"m8n8k4:a:f16": 2,
"m8n8k4:b:f16": 2,
"m16n8k8:a:f16": 2,
"m16n8k8:b:f16": 1,
"m16n8k8:c:f16": 2,
"m16n8k8:d:f16": 2,
"m16n8k8:c:f32": 4,
"m16n8k8:d:f32": 4,
"m16n8k16:a:f16": 4,
"m16n8k16:b:f16": 2,
"m16n8k16:c:f16": 2,
"m16n8k16:d:f16": 2,
"m16n8k16:c:f32": 4,
"m16n8k16:d:f32": 4,
}.get("%s:%s:%s" % (geom, frag, ptx_elt_type), {
# All other FP shape/fragment/type combinations have the same size
"a:f16" : 8,
"b:f16" : 8,
"c:f16" : 4,
"d:f16" : 4,
"c:f32" : 8,
"d:f32" : 8,
}.get("%s:%s" % (frag, ptx_elt_type), None))
assert(self.nregs);
def __repr__(self):
@ -153,9 +272,13 @@ def make_ldst_ops(geoms, frags, types):
return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
in product(geoms, frags, types)]
def get_mma_ops():
return (make_mma_ops(["m8n8k4"],
["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
def get_wmma_ops():
return (make_mma_ops(["m16n16k8"],
["tf32"], [], ["f32"], []) +
make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
["bf16"], [], ["f32"], []) +
make_mma_ops(["m8n8k4"],
["f64"], [], ["f64"], []) +
make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
@ -164,20 +287,38 @@ def get_mma_ops():
["s4", "u4"], [], ["s32"], []) +
make_mma_ops(["m8n8k128"],
["b1"], [], ["s32"], []))
def get_mma_ops():
return (make_mma_ops(["m8n8k4"],
["f64"], [], ["f64"], []) +
make_mma_ops(["m16n8k4", "m16n8k8"],
["tf32"], [], ["f32"], []) +
make_mma_ops(["m16n8k16", "m16n8k8"],
["bf16"], [], ["f32"], []) +
make_mma_ops(["m8n8k4", "m16n8k8", "m16n8k16"],
["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
make_mma_ops(["m8n8k16", "m16n8k16", "m16n8k32"],
["s8", "u8"], ["s8", "u8"], ["s32"], []) +
make_mma_ops(["m8n8k32", "m16n8k32", "m16n8k64"],
["s4", "u4"], ["s4", "u4"], ["s32"], []) +
make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"],
["b1"], [], ["s32"], []))
def get_ldst_ops(kind):
ldst_ops = (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
["a", "b"], ["f16", "u8", "s8"]) +
["a", "b"], ["f16", "u8", "s8", "bf16"]) +
make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
["c", "d"], ["f16", "f32", "s32"]) +
make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4","u4"]) +
make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]))
make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]) +
make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"]) +
make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"]) +
make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"]))
return [ x for x in ldst_ops if (x.frag == "d") == (kind == "store")]
def is_geom_supported(geom):
def is_wmma_geom_supported(geom):
# geometries for FP and ints.
if geom == "m8n8k4":
return ptx_version >= 64
if geom in ["m8n32k16", "m32n8k16"]:
return ptx_version >= 61
# geometries for sub-ints.
@ -185,6 +326,21 @@ def is_geom_supported(geom):
return ptx_version >= 63 and gpu_arch >= 75
if geom == "m16n16k16":
return ptx_version >= 60
if geom == "m16n8k8":
return ptx_version >= 65
if geom in ["m16n16k8", "m8n8k4"]:
return ptx_version >= 70
assert(False) # Unexpected geometry.
def is_mma_geom_supported(geom):
# geometries for FP and ints.
if geom == "m8n8k4":
return ptx_version >= 64
if geom in ["m16n8k8", "m8n8k16", "m8n8k32"]:
return ptx_version >= 65
if geom in ["m16n8k16", "m16n8k4", "m16n8k32", "m16n8k64", "m8n8k128",
"m16n8k128", "m16n8k256"]:
return ptx_version >= 70
assert(False) # Unexpected geometry.
def is_type_supported(ptx_type):
@ -192,30 +348,63 @@ def is_type_supported(ptx_type):
return ptx_version >= 63 and gpu_arch >= 72
if ptx_type in ["s4", "u4", "b1"]:
return ptx_version >= 63 and gpu_arch >= 75
if ptx_type in ["bf16", "tf32", "f64"]:
return ptx_version >= 70
return ptx_version >= 60 and gpu_arch >= 70
def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
if not (is_type_supported(op.a.mma_type.ptx_type)
and is_wmma_geom_supported(op.a.geom)):
return False
# rnd is only supported for FP64 WMMA
if rnd and op.a.mma_type.ptx_type != "f64":
return False
if satf:
# satfinite for floating points was removed in PTX 6.5
if op.a.mma_type.ptx_type == "f16" and ptx_version >= 65:
return False
if not op.a.mma_type.ptx_type in ["f16", "s8", "u8", "s4", "u4"]:
return False
# sub-integer require row/col layout.
if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
return layout_a == "row" and layout_b == "col"
return True
def is_mma_variant_supported(op, layout_a, layout_b, satf):
if not (is_type_supported(op.a.mma_type.ptx_type)
and is_geom_supported(op.a.geom)):
and is_mma_geom_supported(op.a.geom)):
return False
if op.a.geom == "m8n8k4":
if satf:
return False
if op.c.mma_type.ptx_type == "f32":
# If C is f32, D must be, too.
return op.d.mma_type.ptx_type == "f32"
# sub-integer require row/col layout, and no satf.
if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
if op.a.mma_type.ptx_type == "b1" and satf:
if satf and not op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4"]:
return False
# If the type of C is f32 then so must the type of D
if (op.a.geom == "m8n8k4" and op.c.mma_type.ptx_type == "f32"
and op.d.mma_type.ptx_type != "f32"):
return False
# A and B type must be the same. C and D type must be the same
if (op.a.geom == "m16n8k8"
and (op.a.mma_type.ptx_type != op.b.mma_type.ptx_type
or op.c.mma_type.ptx_type != op.d.mma_type.ptx_type)):
return False
# C and D type must be the same
if (op.a.geom == "m16n8k16"
and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type):
return False
# Require row/col layout for all MMA except m8n8k4 on FP16
if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
return layout_a == "row" and layout_b == "col"
return True
def is_ldst_variant_supported(frag, layout):
if not (is_type_supported(frag.mma_type.ptx_type)
and is_geom_supported(frag.geom)):
and is_wmma_geom_supported(frag.geom)):
return False
if frag.mma_type.ptx_type in ["s4", "u4", "b1"]:
# sub-integer require sm_75 and ptx63, row/col layout for a/b.
@ -396,24 +585,37 @@ define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
return generated_items
def mma_signature(op):
if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
# int and sub-int ops are identified by input type.
return op.a.mma_type.ptx_type
else:
# the rest are FP ops identified by accumulator & result type.
if op.a.mma_type.ptx_type == "f16":
# FP16 ops identified by accumulator & result type.
return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
elif op.a.mma_type.ptx_type != op.b.mma_type.ptx_type:
# other ops are identified by input types.
return "%s.%s" % (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
else:
# if input types are the same, it only appears once.
return op.a.mma_type.ptx_type
def mma_ptx_signature(op):
if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
# int and sub-int instructions encode all four types as D.A.B.C
return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
if op.a.geom == "m8n8k4":
return "%s.f16.f16.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
else:
# the rest are FP instructions use D.C
return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
# Encode all four types as D.A.B.C
return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
def gen_wmma_mma_tests():
def wmma_signature(op):
if op.a.mma_type.ptx_type == "f16":
# FP16 ops identified by accumulator & result type.
return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
else:
# other ops are identified by input type.
return op.a.mma_type.ptx_type
def wmma_ptx_signature(op):
if op.a.mma_type.ptx_type == "f16":
# FP16 instructions use D.C
return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
else:
# other instructions encode all four types as D.A.B.C
return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
mma_template = """
declare ${ret_ty} @${intrinsic}(
${args});
@ -431,10 +633,61 @@ define ${ret_ty} @test_${function}(
ret ${ret_ty} %r;
}
"""
wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
wmma_instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}.${ptx_signature}${satf}"
mma_intrinsic_template = "llvm.nvvm.mma.${geom}.${alayout}.${blayout}.${intrinsic_signature}"
mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}.${ptx_signature}"
test_params = params
test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
test_params["function"] = test_params["intrinsic"].replace(".", "_")
test_params["instruction"] = Template(instruction_template).substitute(params)
test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
test_params["check_a"] = check_pattern(op.a)
test_params["check_b"] = check_pattern(op.b)
test_params["check_c"] = check_pattern(op.c)
test_params["check_d"] = check_pattern(op.d)
args = ",\n ".join(make_wmma_slice_args(frag)
for frag in (op.a, op.b, op.c))
test_params["args"] = args
print(Template(mma_template).substitute(test_params))
return (test_params["intrinsic"], test_params["instruction"])
def gen_wmma_mma_tests():
wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
wmma_instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"
generated_items=[]
for op, alayout, blayout, rnd, satf in product(
get_wmma_ops(),
["row","col"],
["row","col"],
[".rn", ".rz", ".rm", ".rp", ""],
[".satfinite", ""]):
if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
continue
params = {
"aligned" : ".aligned" if ptx_version >= 63 else "",
"alayout" : alayout,
"blayout" : blayout,
"intrinsic_signature" : wmma_signature(op),
"ptx_signature" : wmma_ptx_signature(op),
"satf" : satf,
"rnd" : rnd,
"geom" : op.a.geom,
"mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
}
intrinsic_template = wmma_intrinsic_template
instruction_template = wmma_instruction_template
generated_items.append(common_mma_test_gen(params, op,
intrinsic_template, instruction_template))
return generated_items
def gen_mma_tests():
mma_intrinsic_template = "llvm.nvvm.mma.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${mma_variant}"
generated_items=[]
@ -458,28 +711,11 @@ define ${ret_ty} @test_${function}(
"mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
}
if op.a.geom == "m8n8k4":
intrinsic_template = mma_intrinsic_template
instruction_template = mma_instruction_template
else:
intrinsic_template = wmma_intrinsic_template
instruction_template = wmma_instruction_template
intrinsic_template = mma_intrinsic_template
instruction_template = mma_instruction_template
test_params = params
test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
test_params["function"] = test_params["intrinsic"].replace(".", "_")
test_params["instruction"] = Template(instruction_template).substitute(params)
test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
test_params["check_a"] = check_pattern(op.a)
test_params["check_b"] = check_pattern(op.b)
test_params["check_c"] = check_pattern(op.c)
test_params["check_d"] = check_pattern(op.d)
args = ",\n ".join(make_wmma_slice_args(frag)
for frag in (op.a, op.b, op.c))
test_params["args"] = args
print(Template(mma_template).substitute(test_params))
generated_items.append((test_params["intrinsic"],
test_params["instruction"]))
generated_items.append(common_mma_test_gen(params, op,
intrinsic_template, instruction_template))
return generated_items
@ -497,6 +733,8 @@ def gen_check_unsupported_ops(items):
; NOINT-NOT: .{{s32|s8}}
; NOSUBINT-NOT: {{s4|u4|b1}}
; NOMMA-NOT: .m8n8k4.
; NOALTFLOAT-NOT: .{{bf16|tf32}}
; NODOUBLE-NOT: .f64
; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@ -543,10 +781,61 @@ def gen_check_unsupported_ops(items):
; SUBINT-DAG: m8n8k32.mma.{{.*}}.s4
; SUBINT-DAG: m8n8k128.mma.{{.*}}.b1
; ALTFLOAT-DAG: m16n16k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m8n32k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m32n8k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m16n16k8.load.{{[ab].*}}.tf32.p
; ALTFLOAT-DAG: m16n16k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m8n32k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m32n8k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m16n16k8.mma.{{.*}}.tf32
; DOUBLE-DAG: m8n8k4.load.{{[abc].*}}.f64.p
; DOUBLE-DAG: m8n8k4.store.d.{{.*}}.f64.p
; DOUBLE-DAG: m8n8k4.mma.{{.*}}.f64
; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f32
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f32
; PTX65MMA-DAG: mma.m16n8k8.row.col.f16.f16
; PTX65MMA-DAG: mma.m16n8k8.row.col.f32.f32
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.s8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.s8
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.s4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4
; PTX70MMA-DAG: mma.m8n8k4.row.col.f64
; PTX70MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX70MMA-DAG: mma.m16n8k8.row.col.tf32
; PTX70MMA-DAG: mma.m16n8k16.row.col.bf16
; PTX70MMA-DAG: mma.m16n8k8.row.col.bf16
; PTX70MMA-DAG: mma.m16n8k16.row.col.f16.f16
; PTX70MMA-DAG: mma.m16n8k16.row.col.f32.f32
; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
; PTX70MMA-DAG: mma.m8n8k128.row.col.b1
; PTX70MMA-DAG: mma.m16n8k128.row.col.b1
; PTX70MMA-DAG: mma.m16n8k256.row.col.b1
;
""")
@ -561,6 +850,7 @@ def gen_tests():
items = gen_wmma_load_tests()
items += gen_wmma_store_tests()
items += gen_wmma_mma_tests()
items += gen_mma_tests()
gen_check_unsupported_ops(items)
parser = argparse.ArgumentParser()