From 535fe28ef48fc2443220fd96f7c98a07f29255f9 Mon Sep 17 00:00:00 2001
From: Xiang1 Zhang
Date: Tue, 7 Jul 2020 09:50:17 +0800
Subject: [PATCH] [X86-64] Support Intel AMX Intrinsic

INTEL ADVANCED MATRIX EXTENSIONS (AMX).

AMX is a new programming paradigm: it provides a set of 2-dimensional
registers (TILES) representing sub-arrays from a larger 2-dimensional
memory image, together with instructions that operate on those TILES.

These intrinsics take the TMM register number directly as their
parameters.

The specification can be found in Chapter 3 of the Intel Architecture
Instruction Set Extensions Programming Reference:
https://software.intel.com/content/www/us/en/develop/download/intel-architecture-instruction-set-extensions-programming-reference.html

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D83111
---
 include/llvm/IR/IntrinsicsX86.td            | 29 +++++++++
 lib/Target/X86/X86ISelDAGToDAG.cpp          | 33 ++++++++++-
 lib/Target/X86/X86ISelLowering.cpp          | 65 +++++++++++++++++++++
 lib/Target/X86/X86InstrAMX.td               | 49 +++++++++++++++-
 test/CodeGen/X86/AMX/amx-bf16-intrinsics.ll | 13 +++++
 test/CodeGen/X86/AMX/amx-int8-intrinsics.ll | 24 ++++++++
 test/CodeGen/X86/AMX/amx-tile-intrinsics.ll | 36 ++++++++++++
 7 files changed, 245 insertions(+), 4 deletions(-)
 create mode 100644 test/CodeGen/X86/AMX/amx-bf16-intrinsics.ll
 create mode 100644 test/CodeGen/X86/AMX/amx-int8-intrinsics.ll
 create mode 100644 test/CodeGen/X86/AMX/amx-tile-intrinsics.ll

diff --git a/include/llvm/IR/IntrinsicsX86.td b/include/llvm/IR/IntrinsicsX86.td
index b3bf1872059..3f86fd075d3 100644
--- a/include/llvm/IR/IntrinsicsX86.td
+++ b/include/llvm/IR/IntrinsicsX86.td
@@ -4948,3 +4948,32 @@ let TargetPrefix = "x86" in {
   def int_x86_xresldtrk : GCCBuiltin<"__builtin_ia32_xresldtrk">,
               Intrinsic<[], [], []>;
 }
+//===----------------------------------------------------------------------===//
+// AMX - Intel AMX extensions
+
+let TargetPrefix = "x86" in {
+  def int_x86_ldtilecfg : GCCBuiltin<"__builtin_ia32_tile_loadconfig">,
+              Intrinsic<[], [llvm_ptr_ty], []>;
+  def int_x86_sttilecfg : GCCBuiltin<"__builtin_ia32_tile_storeconfig">,
+              Intrinsic<[], [llvm_ptr_ty], []>;
+  def int_x86_tilerelease : GCCBuiltin<"__builtin_ia32_tilerelease">,
+              Intrinsic<[], [], []>;
+  def int_x86_tilezero : GCCBuiltin<"__builtin_ia32_tilezero">,
+              Intrinsic<[], [llvm_i8_ty], []>;
+  def int_x86_tileloadd64 : GCCBuiltin<"__builtin_ia32_tileloadd64">,
+              Intrinsic<[], [llvm_i8_ty, llvm_ptr_ty, llvm_i64_ty], []>;
+  def int_x86_tileloaddt164 : GCCBuiltin<"__builtin_ia32_tileloaddt164">,
+              Intrinsic<[], [llvm_i8_ty, llvm_ptr_ty, llvm_i64_ty], []>;
+  def int_x86_tilestored64 : GCCBuiltin<"__builtin_ia32_tilestored64">,
+              Intrinsic<[], [llvm_i8_ty, llvm_ptr_ty, llvm_i64_ty], []>;
+  def int_x86_tdpbssd : GCCBuiltin<"__builtin_ia32_tdpbssd">,
+              Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty], []>;
+  def int_x86_tdpbsud : GCCBuiltin<"__builtin_ia32_tdpbsud">,
+              Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty], []>;
+  def int_x86_tdpbusd : GCCBuiltin<"__builtin_ia32_tdpbusd">,
+              Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty], []>;
+  def int_x86_tdpbuud : GCCBuiltin<"__builtin_ia32_tdpbuud">,
+              Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty], []>;
+  def int_x86_tdpbf16ps : GCCBuiltin<"__builtin_ia32_tdpbf16ps">,
+              Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty], []>;
+}
diff --git a/lib/Target/X86/X86ISelDAGToDAG.cpp b/lib/Target/X86/X86ISelDAGToDAG.cpp
index 5a57ca7646f..fb285376c58 100644
--- a/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -4435,8 +4435,39 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
       break;
     }
+    case 
Intrinsic::x86_tileloadd64: + case Intrinsic::x86_tileloaddt164: + case Intrinsic::x86_tilestored64: { + if (!Subtarget->hasAMXTILE()) + break; + unsigned Opc; + switch (IntNo) { + default: llvm_unreachable("Unexpected intrinsic!"); + case Intrinsic::x86_tileloadd64: Opc = X86::PTILELOADD; break; + case Intrinsic::x86_tileloaddt164: Opc = X86::PTILELOADDT1; break; + case Intrinsic::x86_tilestored64: Opc = X86::PTILESTORED; break; + } + // FIXME: Match displacement and scale. + unsigned TIndex = Node->getConstantOperandVal(2); + SDValue TReg = getI8Imm(TIndex, dl); + SDValue Base = Node->getOperand(3); + SDValue Scale = getI8Imm(1, dl); + SDValue Index = Node->getOperand(4); + SDValue Disp = CurDAG->getTargetConstant(0, dl, MVT::i32); + SDValue Segment = CurDAG->getRegister(0, MVT::i16); + SDValue Chain = Node->getOperand(0); + MachineSDNode *CNode; + if (Opc == X86::PTILESTORED) { + SDValue Ops[] = { Base, Scale, Index, Disp, Segment, TReg, Chain }; + CNode = CurDAG->getMachineNode(Opc, dl, MVT::Other, Ops); + } else { + SDValue Ops[] = { TReg, Base, Scale, Index, Disp, Segment, Chain }; + CNode = CurDAG->getMachineNode(Opc, dl, MVT::Other, Ops); + } + ReplaceNode(Node, CNode); + return; + } } - break; } case ISD::BRIND: { diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index 88a563720c2..d7a45f6fb7c 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -33044,6 +33044,10 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, const TargetInstrInfo *TII = Subtarget.getInstrInfo(); DebugLoc DL = MI.getDebugLoc(); + auto TMMImmToTMMReg = [](unsigned Imm) { + assert (Imm < 8 && "Illegal tmm index"); + return X86::TMM0 + Imm; + }; switch (MI.getOpcode()) { default: llvm_unreachable("Unexpected instr type to insert"); case X86::TLS_addr32: @@ -33326,6 +33330,67 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, MI.eraseFromParent(); return BB; } + case X86::PTDPBSSD: + case X86::PTDPBSUD: + case X86::PTDPBUSD: + case X86::PTDPBUUD: + case X86::PTDPBF16PS: { + const DebugLoc &DL = MI.getDebugLoc(); + unsigned Opc; + switch (MI.getOpcode()) { + case X86::PTDPBSSD: Opc = X86::TDPBSSD; break; + case X86::PTDPBSUD: Opc = X86::TDPBSUD; break; + case X86::PTDPBUSD: Opc = X86::TDPBUSD; break; + case X86::PTDPBUUD: Opc = X86::TDPBUUD; break; + case X86::PTDPBF16PS: Opc = X86::TDPBF16PS; break; + } + + MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc)); + MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Define); + MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Undef); + MIB.addReg(TMMImmToTMMReg(MI.getOperand(1).getImm()), RegState::Undef); + MIB.addReg(TMMImmToTMMReg(MI.getOperand(2).getImm()), RegState::Undef); + + MI.eraseFromParent(); // The pseudo is gone now. + return BB; + } + case X86::PTILEZERO: { + const DebugLoc &DL = MI.getDebugLoc(); + unsigned Imm = MI.getOperand(0).getImm(); + BuildMI(*BB, MI, DL, TII->get(X86::TILEZERO), TMMImmToTMMReg(Imm)); + MI.eraseFromParent(); // The pseudo is gone now. 
+ return BB; + } + case X86::PTILELOADD: + case X86::PTILELOADDT1: + case X86::PTILESTORED: { + const DebugLoc &DL = MI.getDebugLoc(); + unsigned Opc; + switch (MI.getOpcode()) { + case X86::PTILELOADD: Opc = X86::TILELOADD; break; + case X86::PTILELOADDT1: Opc = X86::TILELOADDT1; break; + case X86::PTILESTORED: Opc = X86::TILESTORED; break; + } + + MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc)); + unsigned CurOp = 0; + if (Opc != X86::TILESTORED) + MIB.addReg(TMMImmToTMMReg(MI.getOperand(CurOp++).getImm()), + RegState::Define); + + MIB.add(MI.getOperand(CurOp++)); // base + MIB.add(MI.getOperand(CurOp++)); // scale + MIB.add(MI.getOperand(CurOp++)); // index -- stride + MIB.add(MI.getOperand(CurOp++)); // displacement + MIB.add(MI.getOperand(CurOp++)); // segment + + if (Opc == X86::TILESTORED) + MIB.addReg(TMMImmToTMMReg(MI.getOperand(CurOp++).getImm()), + RegState::Undef); + + MI.eraseFromParent(); // The pseudo is gone now. + return BB; + } } } diff --git a/lib/Target/X86/X86InstrAMX.td b/lib/Target/X86/X86InstrAMX.td index deefb3eecf3..e26dd5050a2 100644 --- a/lib/Target/X86/X86InstrAMX.td +++ b/lib/Target/X86/X86InstrAMX.td @@ -18,9 +18,11 @@ let Predicates = [HasAMXTILE, In64BitMode] in { let SchedRW = [WriteSystem] in { let Defs = [TMM0,TMM1,TMM2,TMM3,TMM4,TMM5,TMM6,TMM7] in def LDTILECFG : I <0x49, MRM0m, (outs), (ins opaquemem:$src), - "ldtilecfg\t$src", []>, VEX, T8PS; + "ldtilecfg\t$src", + [(int_x86_ldtilecfg addr:$src)]>, VEX, T8PS; def STTILECFG : I <0x49, MRM0m, (outs), (ins opaquemem:$src), - "sttilecfg\t$src", []>, VEX, T8PD; + "sttilecfg\t$src", + [(int_x86_sttilecfg addr:$src)]>, VEX, T8PD; def TILELOADD : I<0x4b, MRMSrcMemFSIB, (outs TILE:$dst), (ins sibmem:$src), "tileloadd\t{$src, $dst|$dst, $src}", []>, @@ -31,7 +33,7 @@ let Predicates = [HasAMXTILE, In64BitMode] in { VEX, T8PD; let Defs = [TMM0,TMM1,TMM2,TMM3,TMM4,TMM5,TMM6,TMM7] in def TILERELEASE : I<0x49, MRM_C0, (outs), (ins), - "tilerelease", []>, VEX, T8PS; + "tilerelease", [(int_x86_tilerelease)]>, VEX, T8PS; def TILESTORED : I<0x4b, MRMDestMemFSIB, (outs), (ins sibmem:$dst, TILE:$src), "tilestored\t{$src, $dst|$dst, $src}", []>, @@ -39,6 +41,17 @@ let Predicates = [HasAMXTILE, In64BitMode] in { def TILEZERO : I<0x49, MRMr0, (outs TILE:$dst), (ins), "tilezero\t$dst", []>, VEX, T8XD; + + let usesCustomInserter = 1 in { + // Pseudo instructions, using immediates instead of tile registers. + // To be translated to the actual instructions in X86ISelLowering.cpp + def PTILELOADD : PseudoI<(outs), (ins u8imm:$src1, sibmem:$src2), []>; + def PTILELOADDT1 : PseudoI<(outs), (ins u8imm:$src1, + sibmem:$src2), []>; + def PTILESTORED : PseudoI<(outs), (ins i8mem:$dst, u8imm:$src), []>; + def PTILEZERO : PseudoI<(outs), (ins u8imm:$src), + [(int_x86_tilezero imm:$src)]>; + } } // SchedRW } // HasAMXTILE @@ -62,6 +75,27 @@ let Predicates = [HasAMXINT8, In64BitMode] in { "tdpbuud\t{$src3, $src2, $dst|$dst, $src2, $src3}", []>, VEX_4V, T8PS; } + + let usesCustomInserter = 1 in { + // Pseudo instructions, using immediates instead of tile registers. 
+ // To be translated to the actual instructions in X86ISelLowering.cpp + def PTDPBSSD : PseudoI<(outs), (ins u8imm:$src1, + u8imm:$src2, u8imm:$src3), + [(int_x86_tdpbssd imm:$src1, + imm:$src2, imm:$src3)]>; + def PTDPBSUD : PseudoI<(outs), (ins u8imm:$src1, + u8imm:$src2, u8imm:$src3), + [(int_x86_tdpbsud imm:$src1, + imm:$src2, imm:$src3)]>; + def PTDPBUSD : PseudoI<(outs), (ins u8imm:$src1, + u8imm:$src2, u8imm:$src3), + [(int_x86_tdpbusd imm:$src1, + imm:$src2, imm:$src3)]>; + def PTDPBUUD : PseudoI<(outs), (ins u8imm:$src1, + u8imm:$src2, u8imm:$src3), + [(int_x86_tdpbuud imm:$src1, + imm:$src2, imm:$src3)]>; + } } } // HasAMXTILE @@ -72,5 +106,14 @@ let Predicates = [HasAMXBF16, In64BitMode] in { (ins TILE:$src1, TILE:$src2, TILE:$src3), "tdpbf16ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", []>, VEX_4V, T8XS; + + let usesCustomInserter = 1 in { + // Pseudo instructions, using immediates instead of tile registers. + // To be translated to the actual instructions in X86ISelLowering.cpp + def PTDPBF16PS : PseudoI<(outs), (ins u8imm:$src1, + u8imm:$src2, u8imm:$src3), + [(int_x86_tdpbf16ps imm:$src1, + imm:$src2, imm:$src3)]>; + } } } // HasAMXTILE, HasAMXBF16 diff --git a/test/CodeGen/X86/AMX/amx-bf16-intrinsics.ll b/test/CodeGen/X86/AMX/amx-bf16-intrinsics.ll new file mode 100644 index 00000000000..a415d9c1524 --- /dev/null +++ b/test/CodeGen/X86/AMX/amx-bf16-intrinsics.ll @@ -0,0 +1,13 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-tile -mattr=+amx-bf16 -verify-machineinstrs | FileCheck %s + +define void @test_amx() { +; CHECK-LABEL: test_amx: +; CHECK: # %bb.0: +; CHECK-NEXT: tdpbf16ps %tmm7, %tmm4, %tmm3 +; CHECK-NEXT: retq + call void @llvm.x86.tdpbf16ps(i8 3, i8 4, i8 7) + ret void +} + +declare void @llvm.x86.tdpbf16ps(i8 %tile0, i8 %tile1, i8 %tile2) diff --git a/test/CodeGen/X86/AMX/amx-int8-intrinsics.ll b/test/CodeGen/X86/AMX/amx-int8-intrinsics.ll new file mode 100644 index 00000000000..49e69aeab51 --- /dev/null +++ b/test/CodeGen/X86/AMX/amx-int8-intrinsics.ll @@ -0,0 +1,24 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-int8 -verify-machineinstrs | FileCheck %s + +define void @test_amx() { +; CHECK-LABEL: test_amx: +; CHECK: # %bb.0: + call void @llvm.x86.tdpbssd(i8 3, i8 4, i8 7) +; CHECK-NEXT: tdpbssd %tmm7, %tmm4, %tmm3 + + call void @llvm.x86.tdpbsud(i8 3, i8 4, i8 7) +; CHECK-NEXT: tdpbsud %tmm7, %tmm4, %tmm3 + + call void @llvm.x86.tdpbusd(i8 3, i8 0, i8 7) +; CHECK-NEXT: tdpbusd %tmm7, %tmm0, %tmm3 + + call void @llvm.x86.tdpbuud(i8 3, i8 4, i8 1) +; CHECK-NEXT: tdpbuud %tmm1, %tmm4, %tmm3 + ret void +} + +declare void @llvm.x86.tdpbssd(i8 %tile0, i8 %tile1, i8 %tile2) +declare void @llvm.x86.tdpbsud(i8 %tile0, i8 %tile1, i8 %tile2) +declare void @llvm.x86.tdpbusd(i8 %tile0, i8 %tile1, i8 %tile2) +declare void @llvm.x86.tdpbuud(i8 %tile0, i8 %tile1, i8 %tile2) diff --git a/test/CodeGen/X86/AMX/amx-tile-intrinsics.ll b/test/CodeGen/X86/AMX/amx-tile-intrinsics.ll new file mode 100644 index 00000000000..6b8e040abb9 --- /dev/null +++ b/test/CodeGen/X86/AMX/amx-tile-intrinsics.ll @@ -0,0 +1,36 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-tile -verify-machineinstrs | FileCheck %s + +define void @test_amx(i8* %pointer, i8* %base, i64 %stride) { +; CHECK-LABEL: test_amx: +; CHECK: # 
%bb.0: + call void @llvm.x86.ldtilecfg(i8* %pointer) +; CHECK-NEXT: ldtilecfg (%rdi) + + call void @llvm.x86.sttilecfg(i8* %pointer) +; CHECK-NEXT: sttilecfg (%rdi) + + call void @llvm.x86.tilerelease() +; CHECK-NEXT: tilerelease + + call void @llvm.x86.tilezero(i8 3) +; CHECK-NEXT: tilezero %tmm3 + + call void @llvm.x86.tileloadd64(i8 3, i8* %base, i64 %stride) +; CHECK-NEXT: tileloadd (%rsi,%rdx), %tmm3 + + call void @llvm.x86.tileloaddt164(i8 3, i8* %base, i64 %stride) +; CHECK-NEXT: tileloaddt1 (%rsi,%rdx), %tmm3 + + call void @llvm.x86.tilestored64(i8 3, i8* %base, i64 %stride) +; CHECK-NEXT: tilestored %tmm3, (%rsi,%rdx) + ret void +} + +declare void @llvm.x86.tileloadd64(i8 %tile, i8* %base, i64 %stride) +declare void @llvm.x86.tileloaddt164(i8 %tile, i8* %base, i64 %stride) +declare void @llvm.x86.tilestored64(i8 %tile, i8* %base, i64 %stride) +declare void @llvm.x86.ldtilecfg(i8* %pointer) +declare void @llvm.x86.sttilecfg(i8* %pointer) +declare void @llvm.x86.tilerelease() +declare void @llvm.x86.tilezero(i8 %tile)
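
Usage sketch (illustrative, not part of the patch): the LLVM IR below shows how the new intrinsics might be chained for a single tile multiply-accumulate, in the same style as the tests above and with the same kind of llc flags (e.g. -mattr=+amx-tile,+amx-int8). The function name, tile numbers, and stride choice are assumptions for illustration only; the intrinsic signatures are the ones this patch declares in IntrinsicsX86.td.

define void @amx_dot_product(i8* %config, i8* %a, i8* %b, i8* %c, i64 %stride) {
  ; Program the tile palette from a caller-prepared configuration block.
  call void @llvm.x86.ldtilecfg(i8* %config)
  ; Load the two source tiles and clear the accumulator tile.
  call void @llvm.x86.tileloadd64(i8 0, i8* %a, i64 %stride)
  call void @llvm.x86.tileloadd64(i8 1, i8* %b, i64 %stride)
  call void @llvm.x86.tilezero(i8 2)
  ; tmm2 += tmm0 * tmm1, signed-byte products accumulated into dwords.
  call void @llvm.x86.tdpbssd(i8 2, i8 0, i8 1)
  ; Write the accumulator back to memory and release the tile state.
  call void @llvm.x86.tilestored64(i8 2, i8* %c, i64 %stride)
  call void @llvm.x86.tilerelease()
  ret void
}

declare void @llvm.x86.ldtilecfg(i8*)
declare void @llvm.x86.tileloadd64(i8, i8*, i64)
declare void @llvm.x86.tilezero(i8)
declare void @llvm.x86.tdpbssd(i8, i8, i8)
declare void @llvm.x86.tilestored64(i8, i8*, i64)
declare void @llvm.x86.tilerelease()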