[Clang][NVPTX] Add NVPTX intrinsics and builtins for CUDA PTX 6.5 and 7.0 WMMA and MMA instructions
Adds NVPTX builtins and intrinsics for the CUDA PTX `wmma.load`, `wmma.store`, `wmma.mma`, and `mma` instructions added in PTX 6.5 and 7.0.

PTX ISA descriptions:
- `wmma.load`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-wmma-ld
- `wmma.store`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-wmma-st
- `wmma.mma`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-wmma-mma
- `mma`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma

Overview of the `wmma.mma` and `mma` matrix shape/type combinations added with specific PTX versions: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-shape

Authored-by: Steffen Larsen <steffen.larsen@codeplay.com>
Co-Authored-by: Stuart Adams <stuart.adams@codeplay.com>
Reviewed By: tra
Differential Revision: https://reviews.llvm.org/D104847
This commit is contained in:
parent 22a7347912
commit e40bb79a12
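The diff below is easier to follow with a concrete instance of the naming and fragment scheme it implements. As a minimal sketch (not part of the patch), the m16n8k8 f16 variant defined by MMA_NAME and WMMA_REGS uses two <2 x half> registers for the A fragment, one for B, two for C, and returns two for D; the function name @do_mma and the value names are illustrative only, while the intrinsic name and fragment shapes follow the TableGen definitions in the diff:

```llvm
; Sketch of one of the new PTX 6.5 MMA intrinsics, derived from the
; fragment sizes in WMMA_REGS and the naming scheme in MMA_SIGNATURE/MMA_NAME.
; A = 2 x <2 x half>, B = 1 x <2 x half>, C = D = 2 x <2 x half>.
declare { <2 x half>, <2 x half> }
  @llvm.nvvm.mma.m16n8k8.row.col.f16.f16(
    <2 x half>, <2 x half>,   ; A fragment
    <2 x half>,               ; B fragment
    <2 x half>, <2 x half>)   ; C fragment

define { <2 x half>, <2 x half> } @do_mma(<2 x half> %a0, <2 x half> %a1,
                                          <2 x half> %b0,
                                          <2 x half> %c0, <2 x half> %c1) {
  ; Per the MMA instruction class added in NVPTXIntrinsics.td, this should
  ; lower to "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16".
  %d = call { <2 x half>, <2 x half> }
      @llvm.nvvm.mma.m16n8k8.row.col.f16.f16(
          <2 x half> %a0, <2 x half> %a1, <2 x half> %b0,
          <2 x half> %c0, <2 x half> %c1)
  ret { <2 x half>, <2 x half> } %d
}
```

The new wmma.* variants follow the same pattern via WMMA_NAME, with an optional rounding-mode component inserted into the name for the f64 shapes.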
@@ -52,13 +52,27 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
string gft = Geom#":"#Frag#":"#ptx_elt_type;
string ft = frag#":"#ptx_elt_type;
list<LLVMType> regs = !cond(
// mma.sync.m8n8k4 uses smaller a/b fragments than wmma fp ops
// mma fp ops use smaller fragments than wmma fp ops
!eq(gft,"m8n8k4:a:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m8n8k4:b:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k8:a:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k8:b:f16") : [llvm_v2f16_ty],
!eq(gft,"m16n8k8:c:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k8:d:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k8:c:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k8:d:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k16:a:f16") : !listsplat(llvm_v2f16_ty, 4),
!eq(gft,"m16n8k16:b:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k16:c:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k16:d:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k16:c:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k16:d:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k4:c:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k4:d:f32") : !listsplat(llvm_float_ty, 4),

// fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
// All currently supported geometries use the same fragment format,
// so we only need to consider {fragment, type}.
// wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
// All other supported geometries use the same fragment format for f32 and
// f16, so we only need to consider {fragment, type}.
!eq(ft,"a:f16") : !listsplat(llvm_v2f16_ty, 8),
!eq(ft,"b:f16") : !listsplat(llvm_v2f16_ty, 8),
!eq(ft,"c:f16") : !listsplat(llvm_v2f16_ty, 4),
@@ -66,7 +80,36 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
!eq(ft,"c:f32") : !listsplat(llvm_float_ty, 8),
!eq(ft,"d:f32") : !listsplat(llvm_float_ty, 8),

// u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
|
||||
// wmma tf32 -> s32 @ m16n16k8
|
||||
!eq(gft,"m16n16k8:a:tf32") : !listsplat(llvm_i32_ty, 4),
|
||||
!eq(gft,"m16n16k8:b:tf32") : !listsplat(llvm_i32_ty, 4),
|
||||
|
||||
// mma tf32 -> s32 @ m16n16k8/m16n8k8
|
||||
!eq(gft,"m16n8k4:a:tf32") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m16n8k4:b:tf32") : [llvm_i32_ty],
|
||||
!eq(gft,"m16n8k8:a:tf32") : !listsplat(llvm_i32_ty, 4),
|
||||
!eq(gft,"m16n8k8:b:tf32") : !listsplat(llvm_i32_ty, 2),
|
||||
|
||||
!eq(gft,"m8n8k4:a:f64") : [llvm_double_ty],
|
||||
!eq(gft,"m8n8k4:b:f64") : [llvm_double_ty],
|
||||
!eq(gft,"m8n8k4:c:f64") : !listsplat(llvm_double_ty, 2),
|
||||
!eq(gft,"m8n8k4:d:f64") : !listsplat(llvm_double_ty, 2),
|
||||
|
||||
// wmma bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
|
||||
!eq(gft,"m16n16k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
|
||||
!eq(gft,"m16n16k16:b:bf16") : !listsplat(llvm_i32_ty, 4),
|
||||
!eq(gft,"m8n32k16:a:bf16") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m8n32k16:b:bf16") : !listsplat(llvm_i32_ty, 8),
|
||||
!eq(gft,"m32n8k16:a:bf16") : !listsplat(llvm_i32_ty, 8),
|
||||
!eq(gft,"m32n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
|
||||
|
||||
// mma bf16 -> s32 @ m16n8k16/m16n8k8
|
||||
!eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
|
||||
!eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty],
|
||||
|
||||
// wmma u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
|
||||
!eq(gft,"m16n16k16:a:u8") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m16n16k16:a:s8") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m16n16k16:b:u8") : !listsplat(llvm_i32_ty, 2),
|
||||
@ -88,17 +131,65 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
|
||||
!eq(gft,"m32n8k16:c:s32") : !listsplat(llvm_i32_ty, 8),
|
||||
!eq(gft,"m32n8k16:d:s32") : !listsplat(llvm_i32_ty, 8),
|
||||
|
||||
// u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
|
||||
!eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
|
||||
// mma u8/s8 -> s32 @ m8n8k16/m16n8k16/m16n8k32
|
||||
!eq(gft,"m8n8k16:a:u8") : [llvm_i32_ty],
|
||||
!eq(gft,"m8n8k16:a:s8") : [llvm_i32_ty],
|
||||
!eq(gft,"m8n8k16:b:u8") : [llvm_i32_ty],
|
||||
!eq(gft,"m8n8k16:b:s8") : [llvm_i32_ty],
|
||||
!eq(gft,"m8n8k16:c:s32") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m8n8k16:d:s32") : !listsplat(llvm_i32_ty, 2),
|
||||
|
||||
!eq(gft,"m16n8k16:a:u8") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m16n8k16:a:s8") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m16n8k16:b:u8") : [llvm_i32_ty],
|
||||
!eq(gft,"m16n8k16:b:s8") : [llvm_i32_ty],
|
||||
!eq(gft,"m16n8k16:c:s32") : !listsplat(llvm_i32_ty, 4),
|
||||
!eq(gft,"m16n8k16:d:s32") : !listsplat(llvm_i32_ty, 4),
|
||||
|
||||
!eq(gft,"m16n8k32:a:u8") : !listsplat(llvm_i32_ty, 4),
|
||||
!eq(gft,"m16n8k32:a:s8") : !listsplat(llvm_i32_ty, 4),
|
||||
!eq(gft,"m16n8k32:b:u8") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m16n8k32:b:s8") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m16n8k32:c:s32") : !listsplat(llvm_i32_ty, 4),
|
||||
!eq(gft,"m16n8k32:d:s32") : !listsplat(llvm_i32_ty, 4),
|
||||
|
||||
// wmma/mma u4/s4 -> s32 @ m8n8k32 (u4/s4)
|
||||
!eq(gft,"m8n8k32:a:u4") : [llvm_i32_ty],
|
||||
!eq(gft,"m8n8k32:a:s4") : [llvm_i32_ty],
|
||||
!eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
|
||||
!eq(gft,"m8n8k32:b:u4") : [llvm_i32_ty],
|
||||
!eq(gft,"m8n8k32:b:s4") : [llvm_i32_ty],
|
||||
!eq(gft,"m8n8k128:c:s32") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m8n8k128:d:s32") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m8n8k32:c:s32") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m8n8k32:d:s32") : !listsplat(llvm_i32_ty, 2),
|
||||
|
||||
!eq(gft,"m16n8k32:a:u4") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m16n8k32:a:s4") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m16n8k32:b:u4") : [llvm_i32_ty],
|
||||
!eq(gft,"m16n8k32:b:s4") : [llvm_i32_ty],
|
||||
!eq(gft,"m16n8k32:c:s32") : !listsplat(llvm_i32_ty, 4),
|
||||
!eq(gft,"m16n8k32:d:s32") : !listsplat(llvm_i32_ty, 4),
|
||||
|
||||
!eq(gft,"m16n8k64:a:u4") : !listsplat(llvm_i32_ty, 4),
|
||||
!eq(gft,"m16n8k64:a:s4") : !listsplat(llvm_i32_ty, 4),
|
||||
!eq(gft,"m16n8k64:b:u4") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m16n8k64:b:s4") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m16n8k64:c:s32") : !listsplat(llvm_i32_ty, 4),
|
||||
!eq(gft,"m16n8k64:d:s32") : !listsplat(llvm_i32_ty, 4),
|
||||
|
||||
// wmma/mma b1 -> s32 @ m8n8k128(b1)
|
||||
!eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
|
||||
!eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
|
||||
!eq(gft,"m8n8k128:c:s32") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m8n8k128:d:s32") : !listsplat(llvm_i32_ty, 2),
|
||||
|
||||
!eq(gft,"m16n8k128:a:b1") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m16n8k128:b:b1") : [llvm_i32_ty],
|
||||
!eq(gft,"m16n8k128:c:s32") : !listsplat(llvm_i32_ty, 4),
|
||||
!eq(gft,"m16n8k128:d:s32") : !listsplat(llvm_i32_ty, 4),
|
||||
|
||||
!eq(gft,"m16n8k256:a:b1") : !listsplat(llvm_i32_ty, 4),
|
||||
!eq(gft,"m16n8k256:b:b1") : !listsplat(llvm_i32_ty, 2),
|
||||
!eq(gft,"m16n8k256:c:s32") : !listsplat(llvm_i32_ty, 4),
|
||||
!eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -125,39 +216,44 @@ class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {

class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
list<WMMA_REGS> id_frags = !cond(
// int and sub-int ops are identified by input type.
!eq(A.ptx_elt_type, "s8") : [A],
!eq(A.ptx_elt_type, "u8") : [A],
!eq(A.ptx_elt_type, "s4") : [A],
!eq(A.ptx_elt_type, "u4") : [A],
!eq(A.ptx_elt_type, "b1") : [A],
// the rest are FP ops identified by accumulator & result type.
true: [D, C]
// FP16 ops are identified by accumulator & result type.
!eq(A.ptx_elt_type, "f16") : [D, C],
// other ops are identified by input types.
!ne(A.ptx_elt_type, B.ptx_elt_type): [A, B],
true: [A]
);
string ret = !foldl("", id_frags, a, b, !strconcat(a, ".", b.ptx_elt_type));
}

class WMMA_NAME_MMA<string ALayout, string BLayout, int Satfinite,
|
||||
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
|
||||
class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd,
|
||||
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
|
||||
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
|
||||
string llvm = !if(
|
||||
!eq(A.geom, "m8n8k4"),
|
||||
"llvm.nvvm.mma.m8n8k4"
|
||||
# "." # ALayout
|
||||
# "." # BLayout
|
||||
# signature,
|
||||
"llvm.nvvm.wmma."
|
||||
# A.geom
|
||||
# ".mma"
|
||||
# "." # ALayout
|
||||
# "." # BLayout
|
||||
# signature
|
||||
# !if(Satfinite, ".satfinite", ""));
|
||||
string llvm = "llvm.nvvm.wmma."
|
||||
# A.geom
|
||||
# ".mma"
|
||||
# "." # ALayout
|
||||
# "." # BLayout
|
||||
# !if(!ne(Rnd, ""), !strconcat(".", Rnd), "")
|
||||
# signature
|
||||
# !if(Satfinite, ".satfinite", "");
|
||||
|
||||
string record = !subst(".", "_",
|
||||
!subst("llvm.", "int_", llvm));
|
||||
}
|
||||
|
||||
class MMA_NAME<string ALayout, string BLayout, int Satfinite,
|
||||
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
|
||||
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
|
||||
string llvm = "llvm.nvvm.mma."
|
||||
# A.geom
|
||||
# "." # ALayout
|
||||
# "." # BLayout
|
||||
# !if(Satfinite, ".satfinite", "")
|
||||
# signature;
|
||||
string record = !subst(".", "_",
|
||||
!subst("llvm.", "int_", llvm));
|
||||
}
|
||||
|
||||
// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
|
||||
// Geom: list of supported geometries.
|
||||
// TypeN: PTX type of the corresponding fragment's element.
|
||||
@ -188,14 +284,18 @@ class MMA_LDST_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
|
||||
list<string> ops = !foreach(x, ret, x.gft);
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Creates list of valid combinations of fragments. This is the master list that
|
||||
// drives generation of corresponding intrinsics and instructions.
|
||||
class NVVM_MMA_OPS<int _ = 0> {
|
||||
list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
|
||||
list<list<WMMA_REGS>> tf32_wmma_ops = MMA_OPS<
|
||||
["m16n16k8"],
|
||||
["tf32"], [], ["f32"], []>.ret;
|
||||
list<list<WMMA_REGS>> bf16_wmma_ops = MMA_OPS<
|
||||
["m16n16k16", "m32n8k16", "m8n32k16"],
|
||||
["bf16"], [], ["f32"], []>.ret;
|
||||
list<list<WMMA_REGS>> f64_wmma_ops = MMA_OPS<
|
||||
["m8n8k4"],
|
||||
["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
|
||||
["f64"], [], ["f64"], []>.ret;
|
||||
list<list<WMMA_REGS>> fp_wmma_ops = MMA_OPS<
|
||||
["m16n16k16", "m32n8k16", "m8n32k16"],
|
||||
["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
|
||||
@ -208,16 +308,50 @@ class NVVM_MMA_OPS<int _ = 0> {
|
||||
list<list<WMMA_REGS>> bit_wmma_ops = MMA_OPS<
|
||||
["m8n8k128"],
|
||||
["b1"], [], ["s32"], []>.ret;
|
||||
list<list<WMMA_REGS>> all_wmma_ops = !listconcat(
|
||||
tf32_wmma_ops, bf16_wmma_ops, f64_wmma_ops,
|
||||
fp_wmma_ops, int_wmma_ops, subint_wmma_ops, bit_wmma_ops);
|
||||
|
||||
list<list<WMMA_REGS>> tf32_mma_ops = MMA_OPS<
|
||||
["m16n8k4", "m16n8k8"],
|
||||
["tf32"], [], ["f32"], []>.ret;
|
||||
list<list<WMMA_REGS>> bf16_mma_ops = MMA_OPS<
|
||||
["m16n8k16", "m16n8k8"],
|
||||
["bf16"], [], ["f32"], []>.ret;
|
||||
list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
|
||||
["m8n8k4"],
|
||||
["f64"], [], ["f64"], []>.ret;
|
||||
list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
|
||||
["m8n8k4", "m16n8k8", "m16n8k16"],
|
||||
["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
|
||||
list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
|
||||
["m8n8k16", "m16n8k16", "m16n8k32"],
|
||||
["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
|
||||
list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
|
||||
["m8n8k32", "m16n8k32", "m16n8k64"],
|
||||
["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
|
||||
list<list<WMMA_REGS>> bit_mma_ops = MMA_OPS<
|
||||
["m8n8k128", "m16n8k128", "m16n8k256"],
|
||||
["b1"], [], ["s32"], []>.ret;
|
||||
list<list<WMMA_REGS>> all_mma_ops = !listconcat(
|
||||
fp_mma_ops, fp_wmma_ops, int_wmma_ops,
|
||||
subint_wmma_ops, bit_wmma_ops);
|
||||
tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
|
||||
fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
|
||||
|
||||
list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
|
||||
["m16n16k16", "m32n8k16", "m8n32k16"],
|
||||
["a", "b"], ["f16", "u8", "s8"]>.ret;
|
||||
["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
|
||||
list<WMMA_REGS> ldst_cd_ops = MMA_LDST_OPS<
|
||||
["m16n16k16", "m32n8k16", "m8n32k16"],
|
||||
["c", "d"], ["f16", "f32", "s32"]>.ret;
|
||||
list<WMMA_REGS> ldst_tf32_ab_ops = MMA_LDST_OPS<
|
||||
["m16n16k8"],
|
||||
["a", "b"], ["tf32"]>.ret;
|
||||
list<WMMA_REGS> ldst_tf32_cd_ops = MMA_LDST_OPS<
|
||||
["m16n16k8"],
|
||||
["c", "d"], ["f32"]>.ret;
|
||||
list<WMMA_REGS> ldst_f64_abcd_ops = MMA_LDST_OPS<
|
||||
["m8n8k4"],
|
||||
["a", "b", "c", "d"], ["f64"]>.ret;
|
||||
list<WMMA_REGS> ldst_subint_ab_ops = MMA_LDST_OPS<
|
||||
["m8n8k32"], ["a", "b"], ["s4","u4"]>.ret;
|
||||
list<WMMA_REGS> ldst_bit_ab_ops = MMA_LDST_OPS<
|
||||
@ -225,6 +359,9 @@ class NVVM_MMA_OPS<int _ = 0> {
|
||||
list<WMMA_REGS> ldst_subint_cd_ops = MMA_LDST_OPS<
|
||||
["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]>.ret;
|
||||
list<WMMA_REGS> all_ldst_ops = !listconcat(ldst_ab_ops, ldst_cd_ops,
|
||||
ldst_tf32_ab_ops,
|
||||
ldst_tf32_cd_ops,
|
||||
ldst_f64_abcd_ops,
|
||||
ldst_subint_ab_ops,
|
||||
ldst_bit_ab_ops,
|
||||
ldst_subint_cd_ops);
|
||||
@ -235,69 +372,110 @@ class NVVM_MMA_OPS<int _ = 0> {
|
||||
|
||||
def NVVM_MMA_OPS : NVVM_MMA_OPS;
|
||||
|
||||
// Returns true if this combination of layout/satf is supported; false otherwise.
|
||||
// MMA ops must provide all parameters. Loads and stores -- only frags and layout_a.
|
||||
// The class is used to prevent generation of records for the unsupported variants.
|
||||
|
||||
// Returns true if this combination of fragment and layout for WMMA load/store
|
||||
// ops is supported; false otherwise.
|
||||
// E.g.
|
||||
// if NVVM_WMMA_LDST_SUPPORTED<...>.ret then
|
||||
// def : FOO<>; // The record will only be defined for supported ops.
|
||||
//
|
||||
class NVVM_WMMA_LDST_SUPPORTED<WMMA_REGS frag, string layout> {
|
||||
string f = frag.frag;
|
||||
string t = frag.ptx_elt_type;
|
||||
|
||||
bit ret = !cond(
|
||||
// Sub-int load and store requires A fragment to be of row layout and B
|
||||
// fragments to be of column layout.
|
||||
!and(!or(!eq(t, "b1"),
|
||||
!eq(t, "u4"),
|
||||
!eq(t, "s4")),
|
||||
!or(!and(!eq(f, "a"),
|
||||
!ne(layout, "row")),
|
||||
!and(!eq(f, "b"),
|
||||
!ne(layout, "col")))) : false,
|
||||
true: true
|
||||
);
|
||||
}
|
||||
|
||||
// Returns true if this combination of layout/satf/rnd for WMMA ops is
|
||||
// supported; false otherwise.
|
||||
// E.g.
|
||||
// if NVVM_WMMA_SUPPORTED<...>.ret then
|
||||
// def : FOO<>; // The record will only be defined for supported ops.
|
||||
//
|
||||
class NVVM_WMMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf, string rnd> {
|
||||
// WMMA ops check both layouts.
|
||||
string layout = layout_a # ":" # layout_b;
|
||||
string t = frags[0].ptx_elt_type;
|
||||
|
||||
bit ret = !cond(
|
||||
// only f64 wmma functions support rnd options
|
||||
// any non f64 type that uses a rnd value is invalid
|
||||
!and(!ne(t, "f64"), !ne(rnd, "")) : false,
|
||||
|
||||
// satf is only valid for select types
|
||||
!and(!eq(satf, 1),
|
||||
!ne(t, "s8"),
|
||||
!ne(t, "u8"),
|
||||
!ne(t, "s4"),
|
||||
!ne(t, "u4"),
|
||||
!ne(t, "f16")): false,
|
||||
|
||||
// Sub-int wmma requires row/column layout
|
||||
!and(!or(!eq(t, "s4"),
|
||||
!eq(t, "u4"),
|
||||
!eq(t, "b1")),
|
||||
!ne(layout, "row:col")) : false,
|
||||
true: true
|
||||
);
|
||||
}
|
||||
|
||||
// Returns true if this combination of layout/satf for MMA ops is supported;
|
||||
// false otherwise.
|
||||
// E.g.
|
||||
// if NVVM_MMA_SUPPORTED<...>.ret then
|
||||
// def : FOO<>; // The record will only be defined for supported ops.
|
||||
//
|
||||
class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b="-", int satf=-1> {
|
||||
class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
|
||||
// MMA ops check both layouts.
|
||||
string mma = frags[0].ptx_elt_type
|
||||
# ":" # layout_a
|
||||
# ":" # layout_b;
|
||||
// Load ops only need type/fragment/layout.
|
||||
string ld = frags[0].ptx_elt_type
|
||||
# ":" # frags[0].frag
|
||||
# ":" # layout_a
|
||||
;
|
||||
string ldf = frags[0].ptx_elt_type
|
||||
# ":" # frags[0].frag
|
||||
;
|
||||
string t = frags[0].ptx_elt_type;
|
||||
string layout = layout_a # ":" # layout_b;
|
||||
string a_type = frags[0].ptx_elt_type;
|
||||
string b_type = frags[1].ptx_elt_type;
|
||||
string c_type = frags[2].ptx_elt_type;
|
||||
string d_type = frags[3].ptx_elt_type;
|
||||
string geom = frags[0].geom;
|
||||
|
||||
// gcd is a shortcut used to identify instructions that depend on
|
||||
// geom+frag_c+frag_d. Not all instances of this class have all fragments
|
||||
// specified. If there are not enough fragments, the tail evaluates to '?'.
|
||||
string gcd = frags[0].geom
|
||||
# ":"
|
||||
# !if(!eq(!size(frags), 4),
|
||||
frags[2].ptx_elt_type # frags[3].ptx_elt_type,
|
||||
"?");
|
||||
// geom+frag_c+frag_d.
|
||||
string gcd = geom # ":" # c_type # d_type;
|
||||
bit ret = !cond(
|
||||
// Sub-int MMA only supports fixed A/B layout.
|
||||
// b1 does not support .satf.
|
||||
!eq(mma#":"#satf, "b1:row:col:0") : true,
|
||||
// mma.m8n8k4 has no .satf modifier.
|
||||
!and(!eq(frags[0].geom, "m8n8k4"),
|
||||
!ne(satf, 0)): false,
|
||||
|
||||
// mma.m8n8k4 has no C=f32 D=f16 variant.
|
||||
// Limit satf to valid types
|
||||
!and(!eq(satf, 1),
|
||||
!ne(a_type, "s8"),
|
||||
!ne(a_type, "u8"),
|
||||
!ne(a_type, "s4"),
|
||||
!ne(a_type, "u4")): false,
|
||||
|
||||
// m8n8k4 has no C=f32 D=f16 variant.
|
||||
!eq(gcd, "m8n8k4:f32f16"): false,
|
||||
!eq(mma, "s4:row:col") : true,
|
||||
!eq(mma, "u4:row:col") : true,
|
||||
!eq(mma, "s4:row:col") : true,
|
||||
!eq(mma, "u4:row:col") : true,
|
||||
// Sub-int load/stores have fixed layout for A and B.
|
||||
!and(!eq(layout_b, "-"), // It's a Load or Store op
|
||||
!or(!eq(ld, "b1:a:row"),
|
||||
!eq(ld, "b1:b:col"),
|
||||
!eq(ldf, "b1:c"),
|
||||
!eq(ldf, "b1:d"),
|
||||
!eq(ld, "s4:a:row"),
|
||||
!eq(ld, "s4:b:col"),
|
||||
!eq(ldf, "s4:c"),
|
||||
!eq(ldf, "s4:d"),
|
||||
!eq(ld, "u4:a:row"),
|
||||
!eq(ld, "u4:b:col"),
|
||||
!eq(ldf, "u4:c"),
|
||||
!eq(ldf, "u4:d"))) : true,
|
||||
// All other sub-int ops are not supported.
|
||||
!eq(t, "b1") : false,
|
||||
!eq(t, "s4") : false,
|
||||
!eq(t, "u4") : false,
|
||||
// All other (non sub-int) are OK.
|
||||
|
||||
// only m8n8k4 for f16 does not require row:col layout
|
||||
!and(!ne(layout, "row:col"),
|
||||
!or(!ne(geom, "m8n8k4"),
|
||||
!ne(a_type, "f16"))) : false,
|
||||
|
||||
// m16n8k8 requires A and B to be the same type and C and D to be the same
|
||||
// type.
|
||||
!and(!eq(geom, "m16n8k8"),
|
||||
!or(!ne(a_type, b_type),
|
||||
!ne(c_type, d_type))): false,
|
||||
|
||||
// m16n8k8 requires C and D to be the same type.
|
||||
!and(!eq(geom, "m16n8k8"),
|
||||
!ne(c_type, d_type)): false,
|
||||
|
||||
// All other are OK.
|
||||
true: true
|
||||
);
|
||||
}
|
||||
@ -4271,36 +4449,59 @@ class NVVM_WMMA_ST<WMMA_REGS Frag, string Layout, int WithStride>
|
||||
foreach layout = ["row", "col"] in {
|
||||
foreach stride = [0, 1] in {
|
||||
foreach frag = NVVM_MMA_OPS.all_ld_ops in
|
||||
if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
|
||||
if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
|
||||
def WMMA_NAME_LDST<"load", frag, layout, stride>.record
|
||||
: NVVM_WMMA_LD<frag, layout, stride>;
|
||||
foreach frag = NVVM_MMA_OPS.all_st_ops in
|
||||
if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
|
||||
if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
|
||||
def WMMA_NAME_LDST<"store", frag, layout, stride>.record
|
||||
: NVVM_WMMA_ST<frag, layout, stride>;
|
||||
}
|
||||
}
|
||||
|
||||
// WMMA.MMA
|
||||
class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite,
|
||||
class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite, string rnd,
|
||||
WMMA_REGS A, WMMA_REGS B,
|
||||
WMMA_REGS C, WMMA_REGS D>
|
||||
: Intrinsic<D.regs,
|
||||
!listconcat(A.regs, B.regs, C.regs),
|
||||
[IntrNoMem],
|
||||
WMMA_NAME_MMA<ALayout, BLayout, Satfinite, A, B, C, D>.llvm>;
|
||||
WMMA_NAME<ALayout, BLayout, Satfinite, rnd, A, B, C, D>.llvm>;
|
||||
|
||||
foreach layout_a = ["row", "col"] in {
|
||||
foreach layout_b = ["row", "col"] in {
|
||||
foreach satf = [0, 1] in {
|
||||
foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
|
||||
foreach op = NVVM_MMA_OPS.all_wmma_ops in {
|
||||
if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
|
||||
def WMMA_NAME<layout_a, layout_b, satf, rnd,
|
||||
op[0], op[1], op[2], op[3]>.record
|
||||
: NVVM_WMMA_MMA<layout_a, layout_b, satf, rnd,
|
||||
op[0], op[1], op[2], op[3]>;
|
||||
}
|
||||
} // op
|
||||
} // rnd
|
||||
} // satf
|
||||
} // layout_b
|
||||
} // layout_a
|
||||
|
||||
// MMA
|
||||
class NVVM_MMA<string ALayout, string BLayout, int Satfinite,
|
||||
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
|
||||
: Intrinsic<D.regs,
|
||||
!listconcat(A.regs, B.regs, C.regs),
|
||||
[IntrNoMem],
|
||||
MMA_NAME<ALayout, BLayout, Satfinite, A, B, C, D>.llvm>;
|
||||
|
||||
foreach layout_a = ["row", "col"] in {
|
||||
foreach layout_b = ["row", "col"] in {
|
||||
foreach satf = [0, 1] in {
|
||||
foreach op = NVVM_MMA_OPS.all_mma_ops in {
|
||||
if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
|
||||
def WMMA_NAME_MMA<layout_a, layout_b, satf,
|
||||
op[0], op[1], op[2], op[3]>.record
|
||||
: NVVM_WMMA_MMA<layout_a, layout_b, satf,
|
||||
op[0], op[1], op[2], op[3]>;
|
||||
def MMA_NAME<layout_a, layout_b, satf, op[0], op[1], op[2], op[3]>.record
|
||||
: NVVM_MMA<layout_a, layout_b, satf, op[0], op[1], op[2], op[3]>;
|
||||
}
|
||||
}
|
||||
} // op
|
||||
} // satf
|
||||
} // layout_b
|
||||
} // layout_a
|
||||
|
@ -3490,6 +3490,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
|
||||
@ -3497,7 +3501,11 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row: {
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
|
||||
case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col:
|
||||
case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row:
|
||||
case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: {
|
||||
Info.opc = ISD::INTRINSIC_W_CHAIN;
|
||||
Info.memVT = MVT::v2i32;
|
||||
Info.ptrVal = I.getArgOperand(0);
|
||||
@ -3515,6 +3523,14 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
|
||||
case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
|
||||
case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
|
||||
case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col:
|
||||
case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row:
|
||||
case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride:
|
||||
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
|
||||
@ -3523,7 +3539,15 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row: {
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
|
||||
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
|
||||
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride: {
|
||||
Info.opc = ISD::INTRINSIC_W_CHAIN;
|
||||
Info.memVT = MVT::v4i32;
|
||||
Info.ptrVal = I.getArgOperand(0);
|
||||
@ -3603,7 +3627,11 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride: {
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col:
|
||||
case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row:
|
||||
case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: {
|
||||
Info.opc = ISD::INTRINSIC_W_CHAIN;
|
||||
Info.memVT = MVT::v8f32;
|
||||
Info.ptrVal = I.getArgOperand(0);
|
||||
@ -3613,6 +3641,16 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
|
||||
return true;
|
||||
}
|
||||
|
||||
case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col:
|
||||
case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row:
|
||||
case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride:
|
||||
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride:
|
||||
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
|
||||
@ -3651,6 +3689,37 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
|
||||
return true;
|
||||
}
|
||||
|
||||
case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col:
|
||||
case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row:
|
||||
case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride:
|
||||
|
||||
case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col:
|
||||
case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row:
|
||||
case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: {
|
||||
Info.opc = ISD::INTRINSIC_W_CHAIN;
|
||||
Info.memVT = MVT::f64;
|
||||
Info.ptrVal = I.getArgOperand(0);
|
||||
Info.offset = 0;
|
||||
Info.flags = MachineMemOperand::MOLoad;
|
||||
Info.align = Align(8);
|
||||
return true;
|
||||
}
|
||||
|
||||
case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col:
|
||||
case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row:
|
||||
case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: {
|
||||
Info.opc = ISD::INTRINSIC_W_CHAIN;
|
||||
Info.memVT = MVT::v2f64;
|
||||
Info.ptrVal = I.getArgOperand(0);
|
||||
Info.offset = 0;
|
||||
Info.flags = MachineMemOperand::MOLoad;
|
||||
Info.align = Align(16);
|
||||
return true;
|
||||
}
|
||||
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
|
||||
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
|
||||
@ -3683,7 +3752,11 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride: {
|
||||
case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col:
|
||||
case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row:
|
||||
case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: {
|
||||
Info.opc = ISD::INTRINSIC_VOID;
|
||||
Info.memVT = MVT::v8f32;
|
||||
Info.ptrVal = I.getArgOperand(0);
|
||||
@ -3731,6 +3804,19 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
|
||||
return true;
|
||||
}
|
||||
|
||||
case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col:
|
||||
case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride:
|
||||
case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row:
|
||||
case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: {
|
||||
Info.opc = ISD::INTRINSIC_VOID;
|
||||
Info.memVT = MVT::v2f64;
|
||||
Info.ptrVal = I.getArgOperand(0);
|
||||
Info.offset = 0;
|
||||
Info.flags = MachineMemOperand::MOStore;
|
||||
Info.align = Align(16);
|
||||
return true;
|
||||
}
|
||||
|
||||
case Intrinsic::nvvm_atomic_load_inc_32:
|
||||
case Intrinsic::nvvm_atomic_load_dec_32:
|
||||
|
||||
|
@ -144,6 +144,7 @@ def hasPTX60 : Predicate<"Subtarget->getPTXVersion() >= 60">;
|
||||
def hasPTX61 : Predicate<"Subtarget->getPTXVersion() >= 61">;
|
||||
def hasPTX63 : Predicate<"Subtarget->getPTXVersion() >= 63">;
|
||||
def hasPTX64 : Predicate<"Subtarget->getPTXVersion() >= 64">;
|
||||
def hasPTX65 : Predicate<"Subtarget->getPTXVersion() >= 65">;
|
||||
def hasPTX70 : Predicate<"Subtarget->getPTXVersion() >= 70">;
|
||||
|
||||
def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">;
|
||||
|
@ -1943,21 +1943,21 @@ multiclass VLDU_G_ELE_V2<string TyStr, NVPTXRegClass regclass> {
|
||||
!strconcat("ldu.global.", TyStr), []>;
|
||||
}
|
||||
|
||||
multiclass VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass> {
|
||||
multiclass VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass> {
|
||||
def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
|
||||
regclass:$dst4), (ins Int32Regs:$src),
|
||||
regclass:$dst4), (ins Int32Regs:$src),
|
||||
!strconcat("ldu.global.", TyStr), []>;
|
||||
def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
|
||||
regclass:$dst4), (ins Int64Regs:$src),
|
||||
regclass:$dst4), (ins Int64Regs:$src),
|
||||
!strconcat("ldu.global.", TyStr), []>;
|
||||
def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
|
||||
regclass:$dst4), (ins MEMri:$src),
|
||||
regclass:$dst4), (ins MEMri:$src),
|
||||
!strconcat("ldu.global.", TyStr), []>;
|
||||
def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
|
||||
regclass:$dst4), (ins MEMri64:$src),
|
||||
regclass:$dst4), (ins MEMri64:$src),
|
||||
!strconcat("ldu.global.", TyStr), []>;
|
||||
def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
|
||||
regclass:$dst4), (ins imemAny:$src),
|
||||
regclass:$dst4), (ins imemAny:$src),
|
||||
!strconcat("ldu.global.", TyStr), []>;
|
||||
}
|
||||
|
||||
@ -1997,7 +1997,7 @@ defm INT_PTX_LDU_G_v4f32_ELE
|
||||
|
||||
|
||||
//-----------------------------------
|
||||
// Support for ldg on sm_35 or later
|
||||
// Support for ldg on sm_35 or later
|
||||
//-----------------------------------
|
||||
|
||||
// Don't annotate ld.global.nc as mayLoad, because these loads go through the
|
||||
@ -2045,7 +2045,7 @@ defm INT_PTX_LDG_GLOBAL_p64
|
||||
|
||||
// vector
|
||||
|
||||
// Elementized vector ldg
|
||||
// Elementized vector ldg
|
||||
multiclass VLDG_G_ELE_V2<string TyStr, NVPTXRegClass regclass> {
|
||||
def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2),
|
||||
(ins Int32Regs:$src),
|
||||
@ -2064,21 +2064,21 @@ multiclass VLDG_G_ELE_V2<string TyStr, NVPTXRegClass regclass> {
|
||||
!strconcat("ld.global.nc.", TyStr), []>;
|
||||
}
|
||||
|
||||
multiclass VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> {
|
||||
multiclass VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> {
|
||||
def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
|
||||
regclass:$dst4), (ins Int32Regs:$src),
|
||||
regclass:$dst4), (ins Int32Regs:$src),
|
||||
!strconcat("ld.global.nc.", TyStr), []>;
|
||||
def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
|
||||
regclass:$dst4), (ins Int64Regs:$src),
|
||||
regclass:$dst4), (ins Int64Regs:$src),
|
||||
!strconcat("ld.global.nc.", TyStr), []>;
|
||||
def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
|
||||
regclass:$dst4), (ins MEMri:$src),
|
||||
regclass:$dst4), (ins MEMri:$src),
|
||||
!strconcat("ld.global.nc.", TyStr), []>;
|
||||
def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
|
||||
regclass:$dst4), (ins MEMri64:$src),
|
||||
regclass:$dst4), (ins MEMri64:$src),
|
||||
!strconcat("ld.global.nc.", TyStr), []>;
|
||||
def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
|
||||
regclass:$dst4), (ins imemAny:$src),
|
||||
regclass:$dst4), (ins imemAny:$src),
|
||||
!strconcat("ld.global.nc.", TyStr), []>;
|
||||
}
|
||||
|
||||
@ -7568,12 +7568,15 @@ def INT_PTX_SREG_WARPSIZE :
|
||||
// In addition to target-independent fields provided by WMMA_REGS, it adds
|
||||
// the fields commonly used to implement specific PTX instruction -- register
|
||||
// types and names, constraints, parts of assembly, etc.
|
||||
class WMMA_REGINFO<WMMA_REGS r>
|
||||
class WMMA_REGINFO<WMMA_REGS r, string op>
|
||||
: WMMA_REGS<r.geom, r.frag, r.ptx_elt_type> {
|
||||
// NVPTX register types used to carry fragment data.
|
||||
NVPTXRegClass regclass = !cond(
|
||||
!eq(ptx_elt_type, "f16") : Float16x2Regs,
|
||||
!eq(ptx_elt_type, "f32") : Float32Regs,
|
||||
!eq(ptx_elt_type, "f64") : Float64Regs,
|
||||
!eq(ptx_elt_type, "bf16") : Int32Regs,
|
||||
!eq(ptx_elt_type, "tf32") : Int32Regs,
|
||||
!eq(ptx_elt_type, "s32") : Int32Regs,
|
||||
!eq(ptx_elt_type, "s8") : Int32Regs,
|
||||
!eq(ptx_elt_type, "u8") : Int32Regs,
|
||||
@ -7602,6 +7605,9 @@ class WMMA_REGINFO<WMMA_REGS r>
|
||||
!or(!eq(ptx_elt_type, "f16"),
|
||||
!eq(ptx_elt_type, "f32"))) : [hasSM70, hasPTX60],
|
||||
|
||||
!and(!eq(geom,"m8n8k4"),
|
||||
!eq(ptx_elt_type, "f64")) : [hasSM80, hasPTX70],
|
||||
|
||||
// fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
|
||||
!and(!or(!eq(geom, "m8n32k16"),
|
||||
!eq(geom, "m32n8k16")),
|
||||
@ -7616,11 +7622,46 @@ class WMMA_REGINFO<WMMA_REGS r>
|
||||
!eq(ptx_elt_type, "s8"),
|
||||
!eq(ptx_elt_type, "s32"))) : [hasSM72, hasPTX63],
|
||||
|
||||
// u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
|
||||
!or(!eq(geom,"m8n8k128"),
|
||||
!eq(geom,"m8n8k32")) : [hasSM75, hasPTX63],
|
||||
!and(!or(!eq(geom,"m16n16k16"),
|
||||
!eq(geom,"m8n32k16"),
|
||||
!eq(geom,"m32n8k16")),
|
||||
!eq(ptx_elt_type, "bf16")) : [hasSM80, hasPTX70],
|
||||
|
||||
!eq(geom, "m8n8k4") : [hasSM70, hasPTX64]);
|
||||
!and(!eq(geom,"m16n16k8"),
|
||||
!eq(ptx_elt_type, "tf32")) : [hasSM80, hasPTX70],
|
||||
|
||||
!and(!eq(geom,"m16n16k8"),
|
||||
!eq(ptx_elt_type, "f32")) : [hasSM80, hasPTX70],
|
||||
|
||||
// b1 -> s32 @ m8n8k128(b1)
|
||||
!and(!ne(op,"mma"),
|
||||
!eq(geom,"m8n8k128")) : [hasSM75, hasPTX63],
|
||||
|
||||
// u4/s4 -> s32 @ m8n8k32 (u4/s4)
|
||||
!and(!ne(op,"mma"),
|
||||
!eq(geom,"m8n8k32")) : [hasSM75, hasPTX63],
|
||||
|
||||
!or(!eq(geom,"m16n8k8"),
|
||||
!eq(geom,"m8n8k16")) : [hasSM75, hasPTX65],
|
||||
|
||||
!and(!ne(ptx_elt_type,"f64"),
|
||||
!eq(geom, "m8n8k4")) : [hasSM70, hasPTX64],
|
||||
|
||||
// mma m8n8k32 requires higher PTX version
|
||||
!and(!eq(op,"mma"),
|
||||
!eq(geom,"m8n8k32")) : [hasSM75, hasPTX65],
|
||||
|
||||
!and(!eq(ptx_elt_type,"f64"),
|
||||
!eq(geom, "m8n8k4")) : [hasSM80, hasPTX70],
|
||||
|
||||
!and(!eq(op,"mma"),
|
||||
!or(!eq(geom, "m16n8k16"),
|
||||
!eq(geom, "m16n8k4"),
|
||||
!eq(geom, "m16n8k32"),
|
||||
!eq(geom, "m16n8k64"),
|
||||
!eq(geom, "m8n8k128"),
|
||||
!eq(geom, "m16n8k128"),
|
||||
!eq(geom, "m16n8k256"))) : [hasSM80, hasPTX70]);
|
||||
|
||||
// template DAGs for instruction inputs/output.
|
||||
dag Outs = !dag(outs, ptx_regs, reg_names);
|
||||
@ -7744,11 +7785,11 @@ defset list<WMMA_INSTR> MMA_LDSTs = {
|
||||
foreach space = [".global", ".shared", ""] in {
|
||||
foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
|
||||
foreach frag = NVVM_MMA_OPS.all_ld_ops in
|
||||
if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
|
||||
def : WMMA_LOAD<WMMA_REGINFO<frag>, layout, space, stride, addr>;
|
||||
if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
|
||||
def : WMMA_LOAD<WMMA_REGINFO<frag, "load">, layout, space, stride, addr>;
|
||||
foreach frag = NVVM_MMA_OPS.all_st_ops in
|
||||
if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
|
||||
def : WMMA_STORE_D<WMMA_REGINFO<frag>, layout, space, stride, addr>;
|
||||
if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
|
||||
def : WMMA_STORE_D<WMMA_REGINFO<frag, "store">, layout, space, stride, addr>;
|
||||
} // addr
|
||||
} // space
|
||||
} // stride
|
||||
@ -7758,46 +7799,84 @@ defset list<WMMA_INSTR> MMA_LDSTs = {
|
||||
// WMMA.MMA
|
||||
class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
|
||||
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
|
||||
string ALayout, string BLayout, int Satfinite>
|
||||
: WMMA_INSTR<WMMA_NAME_MMA<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record,
|
||||
[FragA.Ins, FragB.Ins, FragC.Ins]>,
|
||||
string ALayout, string BLayout, int Satfinite, string rnd>
|
||||
: WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, FragA, FragB, FragC, FragD>.record,
|
||||
[FragA.Ins, FragB.Ins, FragC.Ins]>,
|
||||
// Requires does not seem to have effect on Instruction w/o Patterns.
|
||||
// We set it here anyways and propagate to the Pat<> we construct below.
|
||||
Requires<FragA.Predicates> {
|
||||
let OutOperandList = FragD.Outs;
|
||||
let InOperandList = !con(Args, (ins MmaCode:$ptx));
|
||||
string TypeList = !cond(
|
||||
!eq(FragD.geom, "m8n8k4") : "." # FragD.ptx_elt_type
|
||||
# ".f16.f16."
|
||||
# FragC.ptx_elt_type,
|
||||
!eq(FragD.ptx_elt_type, "s32") : ".s32"
|
||||
# "." # FragA.ptx_elt_type
|
||||
# "." # FragB.ptx_elt_type
|
||||
# ".s32",
|
||||
1: "." # FragD.ptx_elt_type # "." # FragC.ptx_elt_type,
|
||||
!eq(FragA.ptx_elt_type, "f16") : "." # FragD.ptx_elt_type
|
||||
# "." # FragC.ptx_elt_type,
|
||||
1: "." # FragD.ptx_elt_type
|
||||
# "." # FragA.ptx_elt_type
|
||||
# "." # FragB.ptx_elt_type
|
||||
# "." # FragC.ptx_elt_type,
|
||||
);
|
||||
let AsmString = !if(!eq(FragA.geom, "m8n8k4"),
|
||||
"mma.sync.aligned.m8n8k4"
|
||||
# "." # ALayout
|
||||
# "." # BLayout
|
||||
# TypeList # "\n\t\t"
|
||||
# FragD.regstring # ",\n\t\t"
|
||||
# FragA.regstring # ",\n\t\t"
|
||||
# FragB.regstring # ",\n\t\t"
|
||||
# FragC.regstring # ";",
|
||||
"wmma.mma"
|
||||
# !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
|
||||
# ".sync"
|
||||
# "${ptx:aligned}"
|
||||
# "." # ALayout
|
||||
# "." # BLayout
|
||||
# "." # FragA.geom
|
||||
# TypeList
|
||||
# !if(Satfinite, ".satfinite", "") # "\n\t\t"
|
||||
# FragD.regstring # ",\n\t\t"
|
||||
# FragA.regstring # ",\n\t\t"
|
||||
# FragB.regstring # ",\n\t\t"
|
||||
# FragC.regstring # ";");
|
||||
let AsmString = "wmma.mma"
|
||||
# !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
|
||||
# ".sync"
|
||||
# "${ptx:aligned}"
|
||||
# "." # ALayout
|
||||
# "." # BLayout
|
||||
# "." # FragA.geom
|
||||
# !if(!ne(rnd, ""), !strconcat(".", rnd), "")
|
||||
# TypeList
|
||||
# !if(Satfinite, ".satfinite", "") # "\n\t\t"
|
||||
# FragD.regstring # ",\n\t\t"
|
||||
# FragA.regstring # ",\n\t\t"
|
||||
# FragB.regstring # ",\n\t\t"
|
||||
# FragC.regstring # ";";
|
||||
}
|
||||
|
||||
defset list<WMMA_INSTR> WMMAs = {
|
||||
foreach layout_a = ["row", "col"] in {
|
||||
foreach layout_b = ["row", "col"] in {
|
||||
foreach satf = [0, 1] in {
|
||||
foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
|
||||
foreach op = NVVM_MMA_OPS.all_wmma_ops in {
|
||||
if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
|
||||
def : WMMA_MMA<WMMA_REGINFO<op[0], "wmma.mma">,
|
||||
WMMA_REGINFO<op[1], "wmma.mma">,
|
||||
WMMA_REGINFO<op[2], "wmma.mma">,
|
||||
WMMA_REGINFO<op[3], "wmma.mma">,
|
||||
layout_a, layout_b, satf, rnd>;
|
||||
}
|
||||
} // op
|
||||
} // rnd
|
||||
} // satf
|
||||
} // layout_b
|
||||
} // layout_a
|
||||
} // defset
|
||||
|
||||
// MMA
|
||||
class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
|
||||
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
|
||||
string ALayout, string BLayout, int Satfinite>
|
||||
: WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record,
|
||||
[FragA.Ins, FragB.Ins, FragC.Ins]>,
|
||||
// Requires does not seem to have effect on Instruction w/o Patterns.
|
||||
// We set it here anyways and propagate to the Pat<> we construct below.
|
||||
Requires<FragA.Predicates> {
|
||||
let OutOperandList = FragD.Outs;
|
||||
let InOperandList = !con(Args, (ins MmaCode:$ptx));
|
||||
string TypeList = "." # FragD.ptx_elt_type
|
||||
# "." # FragA.ptx_elt_type
|
||||
# "." # FragB.ptx_elt_type
|
||||
# "." # FragC.ptx_elt_type;
|
||||
let AsmString = "mma.sync.aligned."
|
||||
# FragA.geom
|
||||
# "." # ALayout
|
||||
# "." # BLayout
|
||||
# !if(Satfinite, ".satfinite", "")
|
||||
# TypeList
|
||||
# !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "") # "\n\t\t"
|
||||
# FragD.regstring # ",\n\t\t"
|
||||
# FragA.regstring # ",\n\t\t"
|
||||
# FragB.regstring # ",\n\t\t"
|
||||
# FragC.regstring # ";";
|
||||
}
|
||||
|
||||
defset list<WMMA_INSTR> MMAs = {
|
||||
@ -7806,11 +7885,11 @@ defset list<WMMA_INSTR> MMAs = {
|
||||
foreach satf = [0, 1] in {
|
||||
foreach op = NVVM_MMA_OPS.all_mma_ops in {
|
||||
if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
|
||||
def : WMMA_MMA<WMMA_REGINFO<op[0]>,
|
||||
WMMA_REGINFO<op[1]>,
|
||||
WMMA_REGINFO<op[2]>,
|
||||
WMMA_REGINFO<op[3]>,
|
||||
layout_a, layout_b, satf>;
|
||||
def : MMA<WMMA_REGINFO<op[0], "mma">,
|
||||
WMMA_REGINFO<op[1], "mma">,
|
||||
WMMA_REGINFO<op[2], "mma">,
|
||||
WMMA_REGINFO<op[3], "mma">,
|
||||
layout_a, layout_b, satf>;
|
||||
}
|
||||
} // op
|
||||
} // satf
|
||||
@ -7822,12 +7901,12 @@ defset list<WMMA_INSTR> MMAs = {
|
||||
// Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
|
||||
// dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
|
||||
// the instruction record.
|
||||
class WMMA_PAT<WMMA_INSTR wi>
|
||||
class MMA_PAT<WMMA_INSTR wi>
|
||||
: Pat<wi.IntrinsicPattern,
|
||||
!con(!foreach(tmp, wi.Args, !subst(ins, wi, tmp)),
|
||||
(wi ptx.version))>,
|
||||
Requires<wi.Predicates>;
|
||||
|
||||
// Build intrinsic->instruction patterns for all MMA instructions.
|
||||
foreach mma = !listconcat(MMAs, MMA_LDSTs) in
|
||||
def : WMMA_PAT<mma>;
|
||||
foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs) in
|
||||
def : MMA_PAT<mma>;
|
||||
|
@@ -1,2 +1,3 @@
if not 'NVPTX' in config.root.targets:
    config.unsupported = True
config.suffixes.add('.py')
@ -6,7 +6,7 @@
|
||||
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
|
||||
# RUN: --check-prefixes=INTRINSICS,M16N16
|
||||
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
|
||||
# RUN: --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA
|
||||
# RUN: --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
|
||||
# RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
|
||||
# RUN: | FileCheck %t-ptx60-sm_70.ll
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
|
||||
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM
|
||||
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
|
||||
# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA
|
||||
# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
|
||||
# RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
|
||||
# RUN: | FileCheck %t-ptx61-sm_70.ll
|
||||
|
||||
@ -24,7 +24,7 @@
|
||||
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
|
||||
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT
|
||||
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
|
||||
# RUN: --check-prefixes=INTRINSICS,NOSUBINT,NOMMA
|
||||
# RUN: --check-prefixes=INTRINSICS,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
|
||||
# RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
|
||||
# RUN: | FileCheck %t-ptx63-sm_72.ll
|
||||
|
||||
@ -33,7 +33,7 @@
|
||||
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
|
||||
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT
|
||||
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
|
||||
# RUN: --check-prefixes=INTRINSICS,NOMMA
|
||||
# RUN: --check-prefixes=INTRINSICS,NOMMA,NODOUBLE,NOALTFLOAT
|
||||
# RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
|
||||
# RUN: | FileCheck %t-ptx63-sm_75.ll
|
||||
|
||||
@ -42,10 +42,28 @@
|
||||
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
|
||||
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,MMA
|
||||
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
|
||||
# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT
|
||||
# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NODOUBLE,NOALTFLOAT
|
||||
# RUN: llc < %t-ptx64-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 \
|
||||
# RUN: | FileCheck %t-ptx64-sm_70.ll
|
||||
|
||||
# Check all variants of instructions supported by PTX65 on SM75+
|
||||
# RUN: python %s --ptx=65 --gpu-arch=75 > %t-ptx65-sm_75.ll
|
||||
# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
|
||||
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,PTX65MMA
|
||||
# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
|
||||
# RUN: --check-prefixes=INTRINSICS
|
||||
# RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
|
||||
# RUN: | FileCheck %t-ptx65-sm_75.ll
|
||||
|
||||
# Check all variants of instructions supported by PTX70 on SM80+
|
||||
# RUN: python %s --ptx=70 --gpu-arch=80 > %t-ptx70-sm_80.ll
|
||||
# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \
|
||||
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX70MMA
|
||||
# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \
|
||||
# RUN: --check-prefixes=INTRINSICS
|
||||
# RUN: llc < %t-ptx70-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 \
|
||||
# RUN: | FileCheck %t-ptx70-sm_80.ll
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
@ -56,19 +74,23 @@ class MMAType:
|
||||
def __init__(self, ptx_type):
|
||||
self.ptx_type = ptx_type
|
||||
self.llvm_type = {
|
||||
"f16" : "<2 x half>",
|
||||
"f32" : "float",
|
||||
"s32" : "i32",
|
||||
"s8" : "i32",
|
||||
"u8" : "i32",
|
||||
"s4" : "i32",
|
||||
"u4" : "i32",
|
||||
"b1" : "i32",
|
||||
"f16" : "<2 x half>",
|
||||
"f32" : "float",
|
||||
"f64" : "double",
|
||||
"s32" : "i32",
|
||||
"s8" : "i32",
|
||||
"u8" : "i32",
|
||||
"s4" : "i32",
|
||||
"u4" : "i32",
|
||||
"b1" : "i32",
|
||||
"bf16" : "i32",
|
||||
"tf32" : "i32",
|
||||
}[ptx_type];
|
||||
|
||||
self.ptx_reg_pattern = {
|
||||
"f16" : "%hh[0-9]+",
|
||||
"f32" : "%f[0-9]+",
|
||||
"f64" : "%fd[0-9]+",
|
||||
}.get(ptx_type, "%r[0-9]+")
|
||||
|
||||
def __repr__(self):
|
||||
@ -78,16 +100,8 @@ class MMAFrag:
|
||||
def __init__(self, geom, frag, ptx_elt_type):
|
||||
self.geom = geom
|
||||
self.frag = frag
|
||||
self.is_mma = True if geom == "m8n8k4" else False;
|
||||
self.mma_type = MMAType(ptx_elt_type);
|
||||
self.nregs = {
|
||||
"a:f16" : 2 if self.is_mma else 8,
|
||||
"b:f16" : 2 if self.is_mma else 8,
|
||||
"c:f16" : 4,
|
||||
"d:f16" : 4,
|
||||
"c:f32" : 8,
|
||||
"d:f32" : 8,
|
||||
}.get("%s:%s" % (frag, ptx_elt_type), {
|
||||
# u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
|
||||
"m16n16k16:a:u8" : 2,
|
||||
"m16n16k16:a:s8" : 2,
|
||||
@ -110,18 +124,123 @@ class MMAFrag:
|
||||
"m32n8k16:c:s32" : 8,
|
||||
"m32n8k16:d:s32" : 8,
|
||||
|
||||
# u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
|
||||
"m8n8k128:a:b1" : 1,
|
||||
"m8n8k16:a:u8": 1,
|
||||
"m8n8k16:a:s8": 1,
|
||||
"m8n8k16:b:u8": 1,
|
||||
"m8n8k16:b:s8": 1,
|
||||
"m8n8k16:c:s32": 2,
|
||||
"m8n8k16:d:s32": 2,
|
||||
|
||||
"m16n8k16:a:u8": 2,
|
||||
"m16n8k16:a:s8": 2,
|
||||
"m16n8k16:b:u8": 1,
|
||||
"m16n8k16:b:s8": 1,
|
||||
"m16n8k16:c:s32": 4,
|
||||
"m16n8k16:d:s32": 4,
|
||||
|
||||
"m16n8k32:a:u8": 4,
|
||||
"m16n8k32:a:s8": 4,
|
||||
"m16n8k32:b:u8": 2,
|
||||
"m16n8k32:b:s8": 2,
|
||||
"m16n8k32:c:s32": 4,
|
||||
"m16n8k32:d:s32": 4,
|
||||
|
||||
# u4/s4 -> s32 @ m8n8k32 (u4/s4)
|
||||
"m8n8k32:a:u4" : 1,
|
||||
"m8n8k32:a:s4" : 1,
|
||||
"m8n8k128:b:b1" : 1,
|
||||
"m8n8k32:b:u4" : 1,
|
||||
"m8n8k32:b:s4" : 1,
|
||||
"m8n8k128:c:s32" : 2,
|
||||
"m8n8k128:d:s32" : 2,
|
||||
"m8n8k32:c:s32" : 2,
|
||||
"m8n8k32:d:s32" : 2,
|
||||
}.get("%s:%s:%s" % (geom, frag, ptx_elt_type), None));
|
||||
|
||||
"m16n8k32:a:u4" : 2,
|
||||
"m16n8k32:a:s4" : 2,
|
||||
"m16n8k32:b:u4" : 1,
|
||||
"m16n8k32:b:s4" : 1,
|
||||
"m16n8k32:c:s32" : 4,
|
||||
"m16n8k32:d:s32" : 4,
|
||||
|
||||
"m16n8k64:a:u4" : 4,
|
||||
"m16n8k64:a:s4" : 4,
|
||||
"m16n8k64:b:u4" : 2,
|
||||
"m16n8k64:b:s4" : 2,
|
||||
"m16n8k64:c:s32" : 4,
|
||||
"m16n8k64:d:s32" : 4,
|
||||
|
||||
# b1 -> s32 @ m8n8k128(b1)
|
||||
"m8n8k128:a:b1" : 1,
|
||||
"m8n8k128:b:b1" : 1,
|
||||
"m8n8k128:c:s32" : 2,
|
||||
"m8n8k128:d:s32" : 2,
|
||||
|
||||
"m16n8k128:a:b1" : 2,
|
||||
"m16n8k128:b:b1" : 1,
|
||||
"m16n8k128:c:s32" : 4,
|
||||
"m16n8k128:d:s32" : 4,
|
||||
|
||||
"m16n8k256:a:b1" : 4,
|
||||
"m16n8k256:b:b1" : 2,
|
||||
"m16n8k256:c:s32" : 4,
|
||||
"m16n8k256:d:s32" : 4,
|
||||
|
||||
# bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
|
||||
"m16n16k16:a:bf16" : 4,
|
||||
"m16n16k16:b:bf16" : 4,
|
||||
"m8n32k16:a:bf16" : 2,
|
||||
"m8n32k16:b:bf16" : 8,
|
||||
"m32n8k16:a:bf16" : 8,
|
||||
"m32n8k16:b:bf16" : 2,
|
||||
|
||||
"m16n8k16:a:bf16" : 4,
|
||||
"m16n8k16:b:bf16" : 2,
|
||||
"m16n8k16:c:f32" : 4,
|
||||
"m16n8k16:d:f32" : 4,
|
||||
"m16n8k8:a:bf16" : 2,
|
||||
"m16n8k8:b:bf16" : 1,
|
||||
"m16n8k8:c:f32" : 4,
|
||||
"m16n8k8:d:f32" : 4,
|
||||
|
||||
"m8n8k4:a:f64" : 1,
|
||||
"m8n8k4:b:f64" : 1,
|
||||
"m8n8k4:c:f64" : 2,
|
||||
"m8n8k4:d:f64" : 2,
|
||||
|
||||
# tf32 -> s32 @ m16n16k8
|
||||
"m16n16k8:a:tf32" : 4,
|
||||
"m16n16k8:b:tf32" : 4,
|
||||
|
||||
"m16n8k4:a:tf32" : 2,
|
||||
"m16n8k4:b:tf32" : 1,
|
||||
"m16n8k4:c:f32" : 4,
|
||||
"m16n8k4:d:f32" : 4,
|
||||
"m16n8k8:a:tf32" : 4,
|
||||
"m16n8k8:b:tf32" : 2,
|
||||
"m16n8k8:c:f32" : 4,
|
||||
"m16n8k8:d:f32" : 4,
|
||||
|
||||
"m8n8k4:a:f16": 2,
|
||||
"m8n8k4:b:f16": 2,
|
||||
"m16n8k8:a:f16": 2,
|
||||
"m16n8k8:b:f16": 1,
|
||||
"m16n8k8:c:f16": 2,
|
||||
"m16n8k8:d:f16": 2,
|
||||
"m16n8k8:c:f32": 4,
|
||||
"m16n8k8:d:f32": 4,
|
||||
"m16n8k16:a:f16": 4,
|
||||
"m16n8k16:b:f16": 2,
|
||||
"m16n8k16:c:f16": 2,
|
||||
"m16n8k16:d:f16": 2,
|
||||
"m16n8k16:c:f32": 4,
|
||||
"m16n8k16:d:f32": 4,
|
||||
}.get("%s:%s:%s" % (geom, frag, ptx_elt_type), {
|
||||
# All other FP shape/fragment/type combinations have the same size
|
||||
"a:f16" : 8,
|
||||
"b:f16" : 8,
|
||||
"c:f16" : 4,
|
||||
"d:f16" : 4,
|
||||
"c:f32" : 8,
|
||||
"d:f32" : 8,
|
||||
}.get("%s:%s" % (frag, ptx_elt_type), None))
|
||||
assert(self.nregs);
|
||||
|
||||
def __repr__(self):
|
||||
@ -153,9 +272,13 @@ def make_ldst_ops(geoms, frags, types):
|
||||
return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
|
||||
in product(geoms, frags, types)]
|
||||
|
||||
def get_mma_ops():
|
||||
return (make_mma_ops(["m8n8k4"],
|
||||
["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
|
||||
def get_wmma_ops():
|
||||
return (make_mma_ops(["m16n16k8"],
|
||||
["tf32"], [], ["f32"], []) +
|
||||
make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
|
||||
["bf16"], [], ["f32"], []) +
|
||||
make_mma_ops(["m8n8k4"],
|
||||
["f64"], [], ["f64"], []) +
|
||||
make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
|
||||
["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
|
||||
make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
|
||||
@ -164,20 +287,38 @@ def get_mma_ops():
|
||||
["s4", "u4"], [], ["s32"], []) +
|
||||
make_mma_ops(["m8n8k128"],
|
||||
["b1"], [], ["s32"], []))
|
||||
|
||||
def get_mma_ops():
|
||||
return (make_mma_ops(["m8n8k4"],
|
||||
["f64"], [], ["f64"], []) +
|
||||
make_mma_ops(["m16n8k4", "m16n8k8"],
|
||||
["tf32"], [], ["f32"], []) +
|
||||
make_mma_ops(["m16n8k16", "m16n8k8"],
|
||||
["bf16"], [], ["f32"], []) +
|
||||
make_mma_ops(["m8n8k4", "m16n8k8", "m16n8k16"],
|
||||
["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
|
||||
make_mma_ops(["m8n8k16", "m16n8k16", "m16n8k32"],
|
||||
["s8", "u8"], ["s8", "u8"], ["s32"], []) +
|
||||
make_mma_ops(["m8n8k32", "m16n8k32", "m16n8k64"],
|
||||
["s4", "u4"], ["s4", "u4"], ["s32"], []) +
|
||||
make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"],
|
||||
["b1"], [], ["s32"], []))
|
||||
|
||||
def get_ldst_ops(kind):
|
||||
ldst_ops = (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
|
||||
["a", "b"], ["f16", "u8", "s8"]) +
|
||||
["a", "b"], ["f16", "u8", "s8", "bf16"]) +
|
||||
make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
|
||||
["c", "d"], ["f16", "f32", "s32"]) +
|
||||
make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4","u4"]) +
|
||||
make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
|
||||
make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]))
|
||||
make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]) +
|
||||
make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"]) +
|
||||
make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"]) +
|
||||
make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"]))
|
||||
return [ x for x in ldst_ops if (x.frag == "d") == (kind == "store")]
|
||||
|
def is_geom_supported(geom):
def is_wmma_geom_supported(geom):
# geometries for FP and ints.
if geom == "m8n8k4":
return ptx_version >= 64
if geom in ["m8n32k16", "m32n8k16"]:
return ptx_version >= 61
# geometries for sub-ints.
@ -185,6 +326,21 @@ def is_geom_supported(geom):
return ptx_version >= 63 and gpu_arch >= 75
if geom == "m16n16k16":
return ptx_version >= 60
if geom == "m16n8k8":
return ptx_version >= 65
if geom in ["m16n16k8", "m8n8k4"]:
return ptx_version >= 70
assert(False) # Unexpected geometry.

def is_mma_geom_supported(geom):
# geometries for FP and ints.
if geom == "m8n8k4":
return ptx_version >= 64
if geom in ["m16n8k8", "m8n8k16", "m8n8k32"]:
return ptx_version >= 65
if geom in ["m16n8k16", "m16n8k4", "m16n8k32", "m16n8k64", "m8n8k128",
"m16n8k128", "m16n8k256"]:
return ptx_version >= 70
assert(False) # Unexpected geometry.

def is_type_supported(ptx_type):
@ -192,30 +348,63 @@ def is_type_supported(ptx_type):
return ptx_version >= 63 and gpu_arch >= 72
if ptx_type in ["s4", "u4", "b1"]:
return ptx_version >= 63 and gpu_arch >= 75
if ptx_type in ["bf16", "tf32", "f64"]:
return ptx_version >= 70
return ptx_version >= 60 and gpu_arch >= 70

def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
if not (is_type_supported(op.a.mma_type.ptx_type)
and is_wmma_geom_supported(op.a.geom)):
return False

# rnd is only supported for FP64 WMMA
if rnd and op.a.mma_type.ptx_type != "f64":
return False

if satf:
# satfinite for floating points was removed in PTX 6.5
if op.a.mma_type.ptx_type == "f16" and ptx_version >= 65:
return False
if not op.a.mma_type.ptx_type in ["f16", "s8", "u8", "s4", "u4"]:
return False

# sub-integer require row/col layout.
if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
return layout_a == "row" and layout_b == "col"
return True

def is_mma_variant_supported(op, layout_a, layout_b, satf):
if not (is_type_supported(op.a.mma_type.ptx_type)
and is_geom_supported(op.a.geom)):
and is_mma_geom_supported(op.a.geom)):
return False
if op.a.geom == "m8n8k4":
if satf:
return False
if op.c.mma_type.ptx_type == "f32":
# If C is f32, D must be, too.
return op.d.mma_type.ptx_type == "f32"

# sub-integer require row/col layout, and no satf.
if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
if op.a.mma_type.ptx_type == "b1" and satf:
if satf and not op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4"]:
return False

# If the type of C is f32 then so must the type of D
if (op.a.geom == "m8n8k4" and op.c.mma_type.ptx_type == "f32"
and op.d.mma_type.ptx_type != "f32"):
return False

# A and B type must be the same. C and D type must be the same
if (op.a.geom == "m16n8k8"
and (op.a.mma_type.ptx_type != op.b.mma_type.ptx_type
or op.c.mma_type.ptx_type != op.d.mma_type.ptx_type)):
return False

# C and D type must be the same
if (op.a.geom == "m16n8k16"
and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type):
return False

# Require row/col layout for all MMA except m8n8k4 on FP16
if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
return layout_a == "row" and layout_b == "col"
return True

def is_ldst_variant_supported(frag, layout):
if not (is_type_supported(frag.mma_type.ptx_type)
and is_geom_supported(frag.geom)):
and is_wmma_geom_supported(frag.geom)):
return False
if frag.mma_type.ptx_type in ["s4", "u4", "b1"]:
# sub-integer require sm_75 and ptx63, row/col layout for a/b.
@ -396,24 +585,37 @@ define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
return generated_items

def mma_signature(op):
if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
# int and sub-int ops are identified by input type.
return op.a.mma_type.ptx_type
else:
# the rest are FP ops identified by accumulator & result type.
if op.a.mma_type.ptx_type == "f16":
# FP16 ops identified by accumulator & result type.
return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
elif op.a.mma_type.ptx_type != op.b.mma_type.ptx_type:
# other ops are identified by input types.
return "%s.%s" % (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
else:
# if input types are the same, it only appears once.
return op.a.mma_type.ptx_type

def mma_ptx_signature(op):
if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
# int and sub-int instructions encode all four types as D.A.B.C
return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
if op.a.geom == "m8n8k4":
return "%s.f16.f16.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
else:
# the rest are FP instructions use D.C
return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
# Encode all four types as D.A.B.C
return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))

def gen_wmma_mma_tests():
def wmma_signature(op):
if op.a.mma_type.ptx_type == "f16":
# FP16 ops identified by accumulator & result type.
return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
else:
# other ops are identified by input type.
return op.a.mma_type.ptx_type

def wmma_ptx_signature(op):
if op.a.mma_type.ptx_type == "f16":
# FP16 instructions use D.C
return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
else:
# other instructions encode all four types as D.A.B.C
return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))

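The split between the mma_* and wmma_* signature helpers above decides how intrinsic and PTX names are suffixed. A hedged sketch, not part of the patch, of the outputs expected for two representative ops; Ty, Frag, Op and frag are hypothetical stand-ins for the generator's real classes:

from collections import namedtuple

# Hypothetical stand-ins for the generator's MMAType/MMAFrag/MMAOp classes.
Ty = namedtuple("Ty", ["ptx_type"])
Frag = namedtuple("Frag", ["geom", "mma_type"])
Op = namedtuple("Op", ["a", "b", "c", "d"])

def frag(geom, ty):
    return Frag(geom, Ty(ty))

# mma.m16n8k16 s8 x u8 -> s32: A/B types differ, so the intrinsic suffix is
# "A.B", while the PTX instruction spells out all four types as D.A.B.C.
int_op = Op(frag("m16n8k16", "s8"), frag("m16n8k16", "u8"),
            frag("m16n8k16", "s32"), frag("m16n8k16", "s32"))
intrinsic_sig = "%s.%s" % (int_op.a.mma_type.ptx_type, int_op.b.mma_type.ptx_type)
ptx_sig = ".".join(x.mma_type.ptx_type for x in (int_op.d, int_op.a, int_op.b, int_op.c))
assert intrinsic_sig == "s8.u8"
assert ptx_sig == "s32.s8.u8.s32"

# wmma f16 x f16 with f32 accumulator and f16 result: both helpers use "D.C".
fp_op = Op(frag("m16n16k16", "f16"), frag("m16n16k16", "f16"),
           frag("m16n16k16", "f32"), frag("m16n16k16", "f16"))
assert "%s.%s" % (fp_op.d.mma_type.ptx_type, fp_op.c.mma_type.ptx_type) == "f16.f32"
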
def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
mma_template = """
declare ${ret_ty} @${intrinsic}(
${args});
@ -431,10 +633,61 @@ define ${ret_ty} @test_${function}(
ret ${ret_ty} %r;
}
"""
wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
wmma_instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}.${ptx_signature}${satf}"
mma_intrinsic_template = "llvm.nvvm.mma.${geom}.${alayout}.${blayout}.${intrinsic_signature}"
mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}.${ptx_signature}"

test_params = params
test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
test_params["function"] = test_params["intrinsic"].replace(".", "_")
test_params["instruction"] = Template(instruction_template).substitute(params)
test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
test_params["check_a"] = check_pattern(op.a)
test_params["check_b"] = check_pattern(op.b)
test_params["check_c"] = check_pattern(op.c)
test_params["check_d"] = check_pattern(op.d)
args = ",\n ".join(make_wmma_slice_args(frag)
for frag in (op.a, op.b, op.c))
test_params["args"] = args
print(Template(mma_template).substitute(test_params))
return (test_params["intrinsic"], test_params["instruction"])

def gen_wmma_mma_tests():
wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
wmma_instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"

generated_items=[]

for op, alayout, blayout, rnd, satf in product(
get_wmma_ops(),
["row","col"],
["row","col"],
[".rn", ".rz", ".rm", ".rp", ""],
[".satfinite", ""]):

if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
continue

params = {
"aligned" : ".aligned" if ptx_version >= 63 else "",
"alayout" : alayout,
"blayout" : blayout,
"intrinsic_signature" : wmma_signature(op),
"ptx_signature" : wmma_ptx_signature(op),
"satf" : satf,
"rnd" : rnd,
"geom" : op.a.geom,
"mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
}

intrinsic_template = wmma_intrinsic_template
instruction_template = wmma_instruction_template

generated_items.append(common_mma_test_gen(params, op,
intrinsic_template, instruction_template))

return generated_items

def gen_mma_tests():
mma_intrinsic_template = "llvm.nvvm.mma.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${mma_variant}"

generated_items=[]

@ -458,28 +711,11 @@ define ${ret_ty} @test_${function}(
"mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
}

if op.a.geom == "m8n8k4":
intrinsic_template = mma_intrinsic_template
instruction_template = mma_instruction_template
else:
intrinsic_template = wmma_intrinsic_template
instruction_template = wmma_instruction_template
intrinsic_template = mma_intrinsic_template
instruction_template = mma_instruction_template

test_params = params
test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
test_params["function"] = test_params["intrinsic"].replace(".", "_")
test_params["instruction"] = Template(instruction_template).substitute(params)
test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
test_params["check_a"] = check_pattern(op.a)
test_params["check_b"] = check_pattern(op.b)
test_params["check_c"] = check_pattern(op.c)
test_params["check_d"] = check_pattern(op.d)
args = ",\n ".join(make_wmma_slice_args(frag)
for frag in (op.a, op.b, op.c))
test_params["args"] = args
print(Template(mma_template).substitute(test_params))
generated_items.append((test_params["intrinsic"],
test_params["instruction"]))
generated_items.append(common_mma_test_gen(params, op,
intrinsic_template, instruction_template))

return generated_items

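Putting the pieces together, one iteration of gen_mma_tests substitutes its params dictionary into the new mma templates shown above. A sketch, outside the patch, of what that substitution yields for one PTX 7.0 variant; the parameter values are those the helpers above would compute for this op:

from string import Template

# Values a single loop iteration would produce for an m16n8k16 s8 x u8 -> s32
# op with .satfinite; "aligned" assumes ptx_version >= 63.
params = {
    "aligned": ".aligned",
    "geom": "m16n8k16",
    "alayout": "row",
    "blayout": "col",
    "satf": ".satfinite",
    "intrinsic_signature": "s8.u8",      # mma_signature: A and B types differ
    "ptx_signature": "s32.s8.u8.s32",    # mma_ptx_signature: D.A.B.C
    "mma_variant": "",                   # ".xor.popc" is only used for b1 ops
}

intrinsic = Template(
    "llvm.nvvm.mma.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
).substitute(params)
instruction = Template(
    "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${mma_variant}"
).substitute(params)

assert intrinsic == "llvm.nvvm.mma.m16n8k16.row.col.satfinite.s8.u8"
assert instruction == "mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.u8.s32"
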
@ -497,6 +733,8 @@ def gen_check_unsupported_ops(items):
; NOINT-NOT: .{{s32|s8}}
; NOSUBINT-NOT: {{s4|u4|b1}}
; NOMMA-NOT: .m8n8k4.
; NOALTFLOAT-NOT: .{{bf16|tf32}}
; NODOUBLE-NOT: .f64

; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@ -543,10 +781,61 @@ def gen_check_unsupported_ops(items):
; SUBINT-DAG: m8n8k32.mma.{{.*}}.s4
; SUBINT-DAG: m8n8k128.mma.{{.*}}.b1

; ALTFLOAT-DAG: m16n16k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m8n32k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m32n8k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m16n16k8.load.{{[ab].*}}.tf32.p
; ALTFLOAT-DAG: m16n16k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m8n32k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m32n8k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m16n16k8.mma.{{.*}}.tf32

; DOUBLE-DAG: m8n8k4.load.{{[abc].*}}.f64.p
; DOUBLE-DAG: m8n8k4.store.d.{{.*}}.f64.p
; DOUBLE-DAG: m8n8k4.mma.{{.*}}.f64

; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f32
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f32

; PTX65MMA-DAG: mma.m16n8k8.row.col.f16.f16
; PTX65MMA-DAG: mma.m16n8k8.row.col.f32.f32
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.s8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.s8
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.s4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4

; PTX70MMA-DAG: mma.m8n8k4.row.col.f64
; PTX70MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX70MMA-DAG: mma.m16n8k8.row.col.tf32
; PTX70MMA-DAG: mma.m16n8k16.row.col.bf16
; PTX70MMA-DAG: mma.m16n8k8.row.col.bf16
; PTX70MMA-DAG: mma.m16n8k16.row.col.f16.f16
; PTX70MMA-DAG: mma.m16n8k16.row.col.f32.f32
; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
; PTX70MMA-DAG: mma.m8n8k128.row.col.b1
; PTX70MMA-DAG: mma.m16n8k128.row.col.b1
; PTX70MMA-DAG: mma.m16n8k256.row.col.b1
;

""")
@ -561,6 +850,7 @@ def gen_tests():
items = gen_wmma_load_tests()
items += gen_wmma_store_tests()
items += gen_wmma_mma_tests()
items += gen_mma_tests()
gen_check_unsupported_ops(items)

parser = argparse.ArgumentParser()