From 983b1c673591b09e179c0a40187e96bf69f4e4ab Mon Sep 17 00:00:00 2001 From: "Liu, Chen3" Date: Wed, 24 Feb 2021 11:25:35 +0800 Subject: [PATCH] [X86] Support amx-bf16 intrinsic. Adding support for intrinsics of AMX-BF16. This patch alse fix a bug that AMX-INT8 instructions will be selected with wrong predicate. Differential Revision: https://reviews.llvm.org/D97358 --- include/llvm/IR/IntrinsicsX86.td | 6 ++++++ lib/Target/X86/X86ExpandPseudo.cpp | 12 +++++++----- lib/Target/X86/X86ISelDAGToDAG.cpp | 2 +- lib/Target/X86/X86InstrAMX.td | 10 ++++++++++ lib/Target/X86/X86LowerAMXType.cpp | 3 ++- lib/Target/X86/X86PreTileConfig.cpp | 2 ++ lib/Target/X86/X86RegisterInfo.cpp | 1 + test/CodeGen/X86/AMX/amx-tile-basic.ll | 7 +++++-- 8 files changed, 34 insertions(+), 9 deletions(-) diff --git a/include/llvm/IR/IntrinsicsX86.td b/include/llvm/IR/IntrinsicsX86.td index 2c1202cc2a0..643018b0eed 100644 --- a/include/llvm/IR/IntrinsicsX86.td +++ b/include/llvm/IR/IntrinsicsX86.td @@ -5079,6 +5079,12 @@ let TargetPrefix = "x86" in { GCCBuiltin<"__builtin_ia32_tilezero_internal">, Intrinsic<[llvm_x86amx_ty], [llvm_i16_ty, llvm_i16_ty], []>; + def int_x86_tdpbf16ps_internal : + GCCBuiltin<"__builtin_ia32_tdpbf16ps_internal">, + Intrinsic<[llvm_x86amx_ty], + [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, + llvm_x86amx_ty, llvm_x86amx_ty, + llvm_x86amx_ty], []>; } //===----------------------------------------------------------------------===// diff --git a/lib/Target/X86/X86ExpandPseudo.cpp b/lib/Target/X86/X86ExpandPseudo.cpp index fc4e9eb4a4b..ab0062e027a 100644 --- a/lib/Target/X86/X86ExpandPseudo.cpp +++ b/lib/Target/X86/X86ExpandPseudo.cpp @@ -470,16 +470,18 @@ bool X86ExpandPseudo::ExpandMI(MachineBasicBlock &MBB, case X86::PTDPBSSDV: case X86::PTDPBSUDV: case X86::PTDPBUSDV: - case X86::PTDPBUUDV: { + case X86::PTDPBUUDV: + case X86::PTDPBF16PSV: { MI.untieRegOperand(4); for (unsigned i = 3; i > 0; --i) MI.RemoveOperand(i); unsigned Opc; switch (Opcode) { - case X86::PTDPBSSDV: Opc = X86::TDPBSSD; break; - case X86::PTDPBSUDV: Opc = X86::TDPBSUD; break; - case X86::PTDPBUSDV: Opc = X86::TDPBUSD; break; - case X86::PTDPBUUDV: Opc = X86::TDPBUUD; break; + case X86::PTDPBSSDV: Opc = X86::TDPBSSD; break; + case X86::PTDPBSUDV: Opc = X86::TDPBSUD; break; + case X86::PTDPBUSDV: Opc = X86::TDPBUSD; break; + case X86::PTDPBUUDV: Opc = X86::TDPBUUD; break; + case X86::PTDPBF16PSV: Opc = X86::TDPBF16PS; break; default: llvm_unreachable("Impossible Opcode!"); } MI.setDesc(TII->get(Opc)); diff --git a/lib/Target/X86/X86ISelDAGToDAG.cpp b/lib/Target/X86/X86ISelDAGToDAG.cpp index bebd430af6a..f34d34f8a34 100644 --- a/lib/Target/X86/X86ISelDAGToDAG.cpp +++ b/lib/Target/X86/X86ISelDAGToDAG.cpp @@ -4626,7 +4626,7 @@ void X86DAGToDAGISel::Select(SDNode *Node) { case Intrinsic::x86_tdpbsud_internal: case Intrinsic::x86_tdpbusd_internal: case Intrinsic::x86_tdpbuud_internal: { - if (!Subtarget->hasAMXTILE()) + if (!Subtarget->hasAMXINT8()) break; SDValue Chain = Node->getOperand(0); unsigned Opc; diff --git a/lib/Target/X86/X86InstrAMX.td b/lib/Target/X86/X86InstrAMX.td index b93aab30161..6731599b909 100644 --- a/lib/Target/X86/X86InstrAMX.td +++ b/lib/Target/X86/X86InstrAMX.td @@ -138,6 +138,16 @@ let Predicates = [HasAMXBF16, In64BitMode] in { "tdpbf16ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", []>, VEX_4V, T8XS; + // Pseduo instruction for RA. + let Constraints = "$src4 = $dst" in + def PTDPBF16PSV : PseudoI<(outs TILE: $dst), (ins GR16:$src1, + GR16:$src2, GR16:$src3, TILE:$src4, + TILE:$src5, TILE:$src6), + [(set TILE: $dst, + (int_x86_tdpbf16ps_internal GR16:$src1, + GR16:$src2, GR16:$src3, TILE:$src4, + TILE:$src5, TILE:$src6))]>; + let usesCustomInserter = 1 in { // Pseudo instructions, using immediates instead of tile registers. // To be translated to the actual instructions in X86ISelLowering.cpp diff --git a/lib/Target/X86/X86LowerAMXType.cpp b/lib/Target/X86/X86LowerAMXType.cpp index 3fdcf1607d2..5e844a083e7 100644 --- a/lib/Target/X86/X86LowerAMXType.cpp +++ b/lib/Target/X86/X86LowerAMXType.cpp @@ -70,7 +70,8 @@ static std::pair getShape(IntrinsicInst *II, unsigned OpNo) { case Intrinsic::x86_tdpbssd_internal: case Intrinsic::x86_tdpbsud_internal: case Intrinsic::x86_tdpbusd_internal: - case Intrinsic::x86_tdpbuud_internal: { + case Intrinsic::x86_tdpbuud_internal: + case Intrinsic::x86_tdpbf16ps_internal: { switch (OpNo) { case 3: Row = II->getArgOperand(0); diff --git a/lib/Target/X86/X86PreTileConfig.cpp b/lib/Target/X86/X86PreTileConfig.cpp index 90b421b44d7..1c91e87e69d 100644 --- a/lib/Target/X86/X86PreTileConfig.cpp +++ b/lib/Target/X86/X86PreTileConfig.cpp @@ -159,6 +159,7 @@ static ShapeT getShape(const MachineInstr &MI, MachineRegisterInfo *MRI) { case X86::PTDPBUSDV: case X86::PTDPBUUDV: case X86::PTILEZEROV: + case X86::PTDPBF16PSV: MachineOperand &MO1 = const_cast(MI.getOperand(1)); MachineOperand &MO2 = const_cast(MI.getOperand(2)); ShapeT Shape(&MO1, &MO2, MRI); @@ -256,6 +257,7 @@ static bool isAMXInstruction(MachineBasicBlock::iterator MII) { case X86::PTDPBUSDV: case X86::PTDPBUUDV: case X86::PTILEZEROV: + case X86::PTDPBF16PSV: return true; } } diff --git a/lib/Target/X86/X86RegisterInfo.cpp b/lib/Target/X86/X86RegisterInfo.cpp index 00bb73fa2d9..9865216ee36 100644 --- a/lib/Target/X86/X86RegisterInfo.cpp +++ b/lib/Target/X86/X86RegisterInfo.cpp @@ -888,6 +888,7 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM, case X86::PTDPBUSDV: case X86::PTDPBUUDV: case X86::PTILEZEROV: + case X86::PTDPBF16PSV: MachineOperand &MO1 = MI->getOperand(1); MachineOperand &MO2 = MI->getOperand(2); ShapeT Shape(&MO1, &MO2, MRI); diff --git a/test/CodeGen/X86/AMX/amx-tile-basic.ll b/test/CodeGen/X86/AMX/amx-tile-basic.ll index ebb6ee5bc23..095eb8e6ea8 100644 --- a/test/CodeGen/X86/AMX/amx-tile-basic.ll +++ b/test/CodeGen/X86/AMX/amx-tile-basic.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py -; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-tile -mattr=+avx512f -verify-machineinstrs | FileCheck %s +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-tile,+amx-int8,+amx-bf16,+avx512f -verify-machineinstrs | FileCheck %s define void @test_amx(i8* %pointer, i8* %base, i64 %stride) { ; CHECK-LABEL: test_amx: @@ -22,6 +22,7 @@ define void @test_amx(i8* %pointer, i8* %base, i64 %stride) { ; CHECK-NEXT: tdpbsud %tmm2, %tmm1, %tmm0 ; CHECK-NEXT: tdpbusd %tmm2, %tmm1, %tmm0 ; CHECK-NEXT: tdpbuud %tmm2, %tmm1, %tmm0 +; CHECK-NEXT: tdpbf16ps %tmm2, %tmm1, %tmm0 ; CHECK-NEXT: tilestored %tmm0, (%rdi,%rdx) ; CHECK-NEXT: tilerelease ; CHECK-NEXT: vzeroupper @@ -33,7 +34,8 @@ define void @test_amx(i8* %pointer, i8* %base, i64 %stride) { %d1 = call x86_amx @llvm.x86.tdpbsud.internal(i16 8, i16 8, i16 8, x86_amx %d0, x86_amx %a, x86_amx %b) %d2 = call x86_amx @llvm.x86.tdpbusd.internal(i16 8, i16 8, i16 8, x86_amx %d1, x86_amx %a, x86_amx %b) %d3 = call x86_amx @llvm.x86.tdpbuud.internal(i16 8, i16 8, i16 8, x86_amx %d2, x86_amx %a, x86_amx %b) - call void @llvm.x86.tilestored64.internal(i16 8, i16 8, i8* %pointer, i64 %stride, x86_amx %d3) + %d4 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 8, i16 8, i16 8, x86_amx %d3, x86_amx %a, x86_amx %b) + call void @llvm.x86.tilestored64.internal(i16 8, i16 8, i8* %pointer, i64 %stride, x86_amx %d4) ret void } @@ -44,4 +46,5 @@ declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_ declare x86_amx @llvm.x86.tdpbsud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) declare x86_amx @llvm.x86.tdpbusd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) +declare x86_amx @llvm.x86.tdpbf16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)