mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2025-01-31 20:51:52 +01:00
[AArch32] Armv8.6-a Matrix Mult Assembly + Intrinsics
This patch upstreams support for the Armv8.6-a Matrix Multiplication Extension. A summary of the features can be found here: https://community.arm.com/developer/ip-products/processors/b/processors-ip-blog/posts/arm-architecture-developments-armv8-6-a

This patch includes:
- Assembly support for AArch32
- Intrinsics support for AArch32 Neon intrinsics for Matrix Multiplication

Note: these extensions are optional in the 8.6a architecture and so have to be enabled by default.

No additional IR types or C types are needed for this extension.

This is part of a patch series, starting with BFloat16 support and the other components in the armv8.6a extension (in previous patches linked in phabricator).

Based on work by:
- Luke Geeson
- Oliver Stannard
- Luke Cheeseman

Reviewers: t.p.northover, miyuki
Reviewed By: miyuki
Subscribers: miyuki, ostannard, kristof.beyls, hiraditya, danielkiss, cfe-commits
Tags: #clang
Differential Revision: https://reviews.llvm.org/D77872
This commit is contained in:
parent
2ce0a5d73d
commit
09ef958be1
@ -773,6 +773,19 @@ class Neon_Dot_Intrinsic
|
||||
def int_arm_neon_udot : Neon_Dot_Intrinsic;
|
||||
def int_arm_neon_sdot : Neon_Dot_Intrinsic;
|
||||
|
||||
// v8.6-A Matrix Multiply Intrinsics
|
||||
// Shared signature for the matrix multiply-accumulate intrinsics defined
// from this class.
//   result    : overloaded vector type (overload index 0)
//   operand 0 : same type as the result (LLVMMatchType<0>)
//   operand 1 : a second overloaded vector type (overload index 1)
//   operand 2 : same type as operand 1 (LLVMMatchType<1>)
// IntrNoMem marks the intrinsic as not accessing memory.
class Neon_MatMul_Intrinsic
|
||||
: Intrinsic<[llvm_anyvector_ty],
|
||||
[LLVMMatchType<0>, llvm_anyvector_ty,
|
||||
LLVMMatchType<1>],
|
||||
[IntrNoMem]>;
|
||||
// The three matrix multiply-accumulate intrinsics share the
// Neon_MatMul_Intrinsic signature.
// NOTE(review): the u/s/us prefixes presumably denote unsigned, signed and
// mixed-sign inputs, mirroring the instruction mnemonics - confirm against
// the Arm architecture reference.
def int_arm_neon_ummla : Neon_MatMul_Intrinsic;
|
||||
def int_arm_neon_smmla : Neon_MatMul_Intrinsic;
|
||||
def int_arm_neon_usmmla : Neon_MatMul_Intrinsic;
|
||||
// usdot reuses the existing dot-product signature (Neon_Dot_Intrinsic),
// like int_arm_neon_udot / int_arm_neon_sdot.
def int_arm_neon_usdot : Neon_Dot_Intrinsic;
|
||||
|
||||
// v8.6-A Bfloat Intrinsics
|
||||
|
||||
def int_arm_cls: Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem]>;
|
||||
def int_arm_cls64: Intrinsic<[llvm_i32_ty], [llvm_i64_ty], [IntrNoMem]>;
|
||||
|
||||
|
@ -428,6 +428,9 @@ def FeatureSB : SubtargetFeature<"sb", "HasSB", "true",
|
||||
def FeatureBF16 : SubtargetFeature<"bf16", "HasBF16", "true",
|
||||
"Enable support for BFloat16 instructions", [FeatureNEON]>;
|
||||
|
||||
// Subtarget feature for the Int8 matrix multiply extension. It is selected
// with the feature string "i8mm", sets the HasMatMulInt8 subtarget field to
// "true", and implies FeatureNEON.
def FeatureMatMulInt8 : SubtargetFeature<"i8mm", "HasMatMulInt8",
|
||||
"true", "Enable Matrix Multiply Int8 Extension", [FeatureNEON]>;
|
||||
|
||||
// Armv8.1-M extensions
|
||||
|
||||
def FeatureLOB : SubtargetFeature<"lob", "HasLOB", "true",
|
||||
@ -529,7 +532,8 @@ def HasV8_5aOps : SubtargetFeature<"v8.5a", "HasV8_5aOps", "true",
|
||||
|
||||
def HasV8_6aOps : SubtargetFeature<"v8.6a", "HasV8_6aOps", "true",
|
||||
"Support ARM v8.6a instructions",
|
||||
[HasV8_5aOps, FeatureBF16]>;
|
||||
[HasV8_5aOps, FeatureBF16,
|
||||
FeatureMatMulInt8]>;
|
||||
|
||||
def HasV8_1MMainlineOps : SubtargetFeature<
|
||||
"v8.1m.main", "HasV8_1MMainlineOps", "true",
|
||||
|
@ -4823,10 +4823,10 @@ def : Pat<(v4f32 (fma (fneg QPR:$Vn), QPR:$Vm, QPR:$src1)),
|
||||
// We put them in the VFPV8 decoder namespace because the ARM and Thumb
|
||||
// encodings are the same and thus no further bit twiddling is necessary
|
||||
// in the disassembler.
|
||||
class VDOT<bit op6, bit op4, RegisterClass RegTy, string Asm, string AsmTy,
|
||||
ValueType AccumTy, ValueType InputTy,
|
||||
class VDOT<bit op6, bit op4, bit op23, RegisterClass RegTy, string Asm,
|
||||
string AsmTy, ValueType AccumTy, ValueType InputTy,
|
||||
SDPatternOperator OpNode> :
|
||||
N3Vnp<0b11000, 0b10, 0b1101, op6, op4, (outs RegTy:$dst),
|
||||
N3Vnp<{0b1100, op23}, 0b10, 0b1101, op6, op4, (outs RegTy:$dst),
|
||||
(ins RegTy:$Vd, RegTy:$Vn, RegTy:$Vm), N3RegFrm, IIC_VDOTPROD,
|
||||
Asm, AsmTy,
|
||||
[(set (AccumTy RegTy:$dst),
|
||||
@ -4838,10 +4838,19 @@ class VDOT<bit op6, bit op4, RegisterClass RegTy, string Asm, string AsmTy,
|
||||
let Constraints = "$dst = $Vd";
|
||||
}
|
||||
|
||||
def VUDOTD : VDOT<0, 1, DPR, "vudot", "u8", v2i32, v8i8, int_arm_neon_udot>;
|
||||
def VSDOTD : VDOT<0, 0, DPR, "vsdot", "s8", v2i32, v8i8, int_arm_neon_sdot>;
|
||||
def VUDOTQ : VDOT<1, 1, QPR, "vudot", "u8", v4i32, v16i8, int_arm_neon_udot>;
|
||||
def VSDOTQ : VDOT<1, 0, QPR, "vsdot", "s8", v4i32, v16i8, int_arm_neon_sdot>;
|
||||
|
||||
// Variant of VDOT used for the vusdot instructions. Every template argument
// is forwarded unchanged to VDOT; the only difference is that the resulting
// instruction is marked as having no scheduling information.
class VUSDOT<bit op6, bit op4, bit op23, RegisterClass RegTy, string Asm,
|
||||
string AsmTy, ValueType AccumTy, ValueType InputTy,
|
||||
SDPatternOperator OpNode> :
|
||||
VDOT<op6, op4, op23, RegTy, Asm, AsmTy, AccumTy, InputTy, OpNode> {
|
||||
let hasNoSchedulingInfo = 1;
|
||||
|
||||
}
|
||||
|
||||
// Vector dot-product instructions. Arguments follow the VDOT signature
// <op6, op4, op23, RegTy, Asm, AsmTy, AccumTy, InputTy, OpNode>; op23 is 0
// for all four. The D forms use op6 = 0 with DPR and v2i32/v8i8 types, the
// Q forms use op6 = 1 with QPR and v4i32/v16i8 types.
def VUDOTD : VDOT<0, 1, 0, DPR, "vudot", "u8", v2i32, v8i8, int_arm_neon_udot>;
|
||||
def VSDOTD : VDOT<0, 0, 0, DPR, "vsdot", "s8", v2i32, v8i8, int_arm_neon_sdot>;
|
||||
def VUDOTQ : VDOT<1, 1, 0, QPR, "vudot", "u8", v4i32, v16i8, int_arm_neon_udot>;
|
||||
def VSDOTQ : VDOT<1, 0, 0, QPR, "vsdot", "s8", v4i32, v16i8, int_arm_neon_sdot>;
|
||||
|
||||
// Indexed dot product instructions:
|
||||
multiclass DOTI<string opc, string dt, bit Q, bit U, RegisterClass Ty,
|
||||
@ -4876,6 +4885,70 @@ defm VUDOTQI : DOTI<"vudot", "u8", 0b1, 0b1, QPR, v4i32, v16i8,
|
||||
defm VSDOTQI : DOTI<"vsdot", "s8", 0b1, 0b0, QPR, v4i32, v16i8,
|
||||
int_arm_neon_sdot, (EXTRACT_SUBREG QPR:$Vm, dsub_0)>;
|
||||
|
||||
// v8.6A matrix multiplication extension
|
||||
let Predicates = [HasMatMulInt8] in {
|
||||
// Base class for the matrix multiply-accumulate instructions.
//   B      : low bit of the first N3Vnp field ({0b1100, B})
//   U      : passed as the last bit argument of N3Vnp
//   Asm    : mnemonic
//   AsmTy  : data-type suffix
//   OpNode : node matched by the selection pattern
// All operands are QPR. The pattern matches OpNode applied to a v4i32
// accumulator ($Vd) and two v16i8 inputs ($Vn, $Vm), producing v4i32.
class N3VMatMul<bit B, bit U, string Asm, string AsmTy,
|
||||
SDPatternOperator OpNode>
|
||||
: N3Vnp<{0b1100, B}, 0b10, 0b1100, 1, U, (outs QPR:$dst),
|
||||
(ins QPR:$Vd, QPR:$Vn, QPR:$Vm), N3RegFrm, NoItinerary,
|
||||
Asm, AsmTy,
|
||||
[(set (v4i32 QPR:$dst), (OpNode (v4i32 QPR:$Vd),
|
||||
(v16i8 QPR:$Vn),
|
||||
(v16i8 QPR:$Vm)))]> {
|
||||
let DecoderNamespace = "VFPV8";
|
||||
// The accumulator input and the result share one register.
let Constraints = "$dst = $Vd";
|
||||
// No itinerary is supplied above (NoItinerary), so opt out of scheduling info.
let hasNoSchedulingInfo = 1;
|
||||
}
|
||||
|
||||
// Lane-indexed dot-product multiclass. Each instantiation produces:
//   1. an instruction named after the defm itself (def ""), taking
//      $Vd, $Vn, a DPR_VFP2 $Vm and a VectorIndex32 $lane, with an empty
//      selection pattern list;
//   2. an anonymous Pat that selects that instruction when OpNode is applied
//      to $Vd, $Vn and a third operand formed by duplicating one AccumTy lane
//      of $Vm (ARMvduplane) and bitconverting it to InputTy.
// Q and U are passed straight through as the two bit arguments of N3Vnp.
// RHS is the dag substituted for the $Vm operand in the selected instruction.
multiclass N3VMixedDotLane<bit Q, bit U, string Asm, string AsmTy, RegisterClass RegTy,
|
||||
ValueType AccumTy, ValueType InputTy, SDPatternOperator OpNode,
|
||||
dag RHS> {
|
||||
|
||||
def "" : N3Vnp<0b11101, 0b00, 0b1101, Q, U, (outs RegTy:$dst),
|
||||
(ins RegTy:$Vd, RegTy:$Vn, DPR_VFP2:$Vm, VectorIndex32:$lane), N3RegFrm,
|
||||
NoItinerary, Asm, AsmTy, []> {
|
||||
bit lane;
|
||||
let hasNoSchedulingInfo = 1;
|
||||
// The lane index is encoded in instruction bit 5.
let Inst{5} = lane;
|
||||
// Overrides the assembly string so the lane index is printed after $Vm.
let AsmString = !strconcat(Asm, ".", AsmTy, "\t$Vd, $Vn, $Vm$lane");
|
||||
let DecoderNamespace = "VFPV8";
|
||||
// The accumulator input and the result share one register.
let Constraints = "$dst = $Vd";
|
||||
}
|
||||
|
||||
// Select the instruction defined above (looked up via NAME) for OpNode
// whose third operand is a duplicated lane of $Vm.
def : Pat<
|
||||
(AccumTy (OpNode (AccumTy RegTy:$Vd),
|
||||
(InputTy RegTy:$Vn),
|
||||
(InputTy (bitconvert (AccumTy
|
||||
(ARMvduplane (AccumTy RegTy:$Vm),
|
||||
VectorIndex32:$lane)))))),
|
||||
(!cast<Instruction>(NAME) RegTy:$Vd, RegTy:$Vn, RHS, VectorIndex32:$lane)>;
|
||||
|
||||
}
|
||||
|
||||
// Lane-indexed vsudot. Inherits the instruction definition from
// N3VMixedDotLane with U = 1, mnemonic "vsudot" and type suffix "u8"; the
// inherited OpNode and RHS are both null_frag. The pattern added here matches
// int_arm_neon_usdot with the duplicated-lane operand in the SECOND input
// position and $Vn in the third, i.e. the two inputs are swapped relative to
// the pattern in N3VMixedDotLane.
multiclass SUDOTLane<bit Q, RegisterClass RegTy, ValueType AccumTy, ValueType InputTy, dag RHS>
|
||||
: N3VMixedDotLane<Q, 1, "vsudot", "u8", RegTy, AccumTy, InputTy, null_frag, null_frag> {
|
||||
def : Pat<
|
||||
(AccumTy (int_arm_neon_usdot (AccumTy RegTy:$Vd),
|
||||
(InputTy (bitconvert (AccumTy
|
||||
(ARMvduplane (AccumTy RegTy:$Vm),
|
||||
VectorIndex32:$lane)))),
|
||||
(InputTy RegTy:$Vn))),
|
||||
(!cast<Instruction>(NAME) RegTy:$Vd, RegTy:$Vn, RHS, VectorIndex32:$lane)>;
|
||||
}
|
||||
|
||||
// Matrix multiply-accumulate instructions; the first two arguments are the
// N3VMatMul B and U bits.
def VSMMLA : N3VMatMul<0, 0, "vsmmla", "s8", int_arm_neon_smmla>;
|
||||
def VUMMLA : N3VMatMul<0, 1, "vummla", "u8", int_arm_neon_ummla>;
|
||||
def VUSMMLA : N3VMatMul<1, 0, "vusmmla", "s8", int_arm_neon_usmmla>;
|
||||
// Vector (non-indexed) vusdot, D and Q forms; arguments are
// <op6, op4, op23, ...> with op23 = 1.
def VUSDOTD : VUSDOT<0, 0, 1, DPR, "vusdot", "s8", v2i32, v8i8, int_arm_neon_usdot>;
|
||||
def VUSDOTQ : VUSDOT<1, 0, 1, QPR, "vusdot", "s8", v4i32, v16i8, int_arm_neon_usdot>;
|
||||
|
||||
// Lane-indexed vusdot. The D form passes $Vm through as a v2i32 DPR_VFP2
// value; the Q form passes the dsub_0 subregister of QPR:$Vm.
defm VUSDOTDI : N3VMixedDotLane<0, 0, "vusdot", "s8", DPR, v2i32, v8i8,
|
||||
int_arm_neon_usdot, (v2i32 DPR_VFP2:$Vm)>;
|
||||
defm VUSDOTQI : N3VMixedDotLane<1, 0, "vusdot", "s8", QPR, v4i32, v16i8,
|
||||
int_arm_neon_usdot, (EXTRACT_SUBREG QPR:$Vm, dsub_0)>;
|
||||
// Lane-indexed vsudot, using the same RHS operands as the vusdot forms above.
defm VSUDOTDI : SUDOTLane<0, DPR, v2i32, v8i8, (v2i32 DPR_VFP2:$Vm)>;
|
||||
defm VSUDOTQI : SUDOTLane<1, QPR, v4i32, v16i8, (EXTRACT_SUBREG QPR:$Vm, dsub_0)>;
|
||||
}
|
||||
|
||||
// ARMv8.3 complex operations
|
||||
class BaseN3VCP8ComplexTied<bit op21, bit op4, bit s, bit q,
|
||||
|
@ -110,6 +110,8 @@ def HasFP16FML : Predicate<"Subtarget->hasFP16FML()">,
|
||||
AssemblerPredicate<(all_of FeatureFP16FML),"full half-float fml">;
|
||||
def HasBF16 : Predicate<"Subtarget->hasBF16()">,
|
||||
AssemblerPredicate<(all_of FeatureBF16),"BFloat16 floating point extension">;
|
||||
// Predicate for the Int8 matrix multiply extension: code generation tests
// Subtarget->hasMatMulInt8(), and the assembler requires FeatureMatMulInt8
// (described as "8-bit integer matrix multiply").
def HasMatMulInt8 : Predicate<"Subtarget->hasMatMulInt8()">,
|
||||
AssemblerPredicate<(all_of FeatureMatMulInt8),"8-bit integer matrix multiply">;
|
||||
def HasDivideInThumb : Predicate<"Subtarget->hasDivideInThumbMode()">,
|
||||
AssemblerPredicate<(all_of FeatureHWDivThumb), "divide in THUMB">;
|
||||
def HasDivideInARM : Predicate<"Subtarget->hasDivideInARMMode()">,
|
||||
|
@ -260,6 +260,9 @@ protected:
|
||||
/// HasBF16 - True if subtarget supports BFloat16 floating point operations
|
||||
bool HasBF16 = false;
|
||||
|
||||
/// HasMatMulInt8 - True if subtarget supports 8-bit integer matrix multiply
|
||||
bool HasMatMulInt8 = false;
|
||||
|
||||
/// HasD32 - True if subtarget has the full 32 double precision
|
||||
/// FP registers for VFPv3.
|
||||
bool HasD32 = false;
|
||||
@ -704,6 +707,8 @@ public:
|
||||
/// Return true if the CPU supports any kind of instruction fusion.
|
||||
bool hasFusion() const { return hasFuseAES() || hasFuseLiterals(); }
|
||||
|
||||
/// Return the value of the HasMatMulInt8 subtarget flag.
bool hasMatMulInt8() const { return HasMatMulInt8; }
|
||||
|
||||
const Triple &getTargetTriple() const { return TargetTriple; }
|
||||
|
||||
bool isTargetDarwin() const { return TargetTriple.isOSDarwin(); }
|
||||
|
83
test/CodeGen/ARM/arm-matmul.ll
Normal file
83
test/CodeGen/ARM/arm-matmul.ll
Normal file
@ -0,0 +1,83 @@
|
||||
; RUN: llc -mtriple=arm-none-linux-gnu -mattr=+neon,+i8mm -float-abi=hard < %s -o -| FileCheck %s
|
||||
|
||||
; A direct call to llvm.arm.neon.smmla on a <4 x i32> accumulator and two
; <16 x i8> inputs is expected to produce a single vsmmla.s8 instruction.
define <4 x i32> @smmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) {
|
||||
entry:
|
||||
; CHECK-LABEL: smmla.v4i32.v16i8
|
||||
; CHECK: vsmmla.s8 q0, q1, q2
|
||||
%vmmla1.i = tail call <4 x i32> @llvm.arm.neon.smmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) #3
|
||||
ret <4 x i32> %vmmla1.i
|
||||
}
|
||||
|
||||
; A direct call to llvm.arm.neon.ummla on a <4 x i32> accumulator and two
; <16 x i8> inputs is expected to produce a single vummla.u8 instruction.
define <4 x i32> @ummla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) {
|
||||
entry:
|
||||
; CHECK-LABEL: ummla.v4i32.v16i8
|
||||
; CHECK: vummla.u8 q0, q1, q2
|
||||
%vmmla1.i = tail call <4 x i32> @llvm.arm.neon.ummla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) #3
|
||||
ret <4 x i32> %vmmla1.i
|
||||
}
|
||||
|
||||
; A direct call to llvm.arm.neon.usmmla on a <4 x i32> accumulator and two
; <16 x i8> inputs is expected to produce a single vusmmla.s8 instruction.
define <4 x i32> @usmmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) {
|
||||
entry:
|
||||
; CHECK-LABEL: usmmla.v4i32.v16i8
|
||||
; CHECK: vusmmla.s8 q0, q1, q2
|
||||
%vusmmla1.i = tail call <4 x i32> @llvm.arm.neon.usmmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) #3
|
||||
ret <4 x i32> %vusmmla1.i
|
||||
}
|
||||
|
||||
; A direct call to llvm.arm.neon.usdot on a <2 x i32> accumulator and two
; <8 x i8> inputs is expected to produce the D-register vusdot.s8 form.
define <2 x i32> @usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %b) {
|
||||
entry:
|
||||
; CHECK-LABEL: usdot.v2i32.v8i8
|
||||
; CHECK: vusdot.s8 d0, d1, d2
|
||||
%vusdot1.i = tail call <2 x i32> @llvm.arm.neon.usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %b) #3
|
||||
ret <2 x i32> %vusdot1.i
|
||||
}
|
||||
|
||||
; Lane-indexed vusdot, D-register form. %b is reinterpreted as <2 x i32>,
; element 0 is splatted with a zero shuffle mask, and the result is cast back
; to <8 x i8> and passed as the LAST usdot operand. The expected output is the
; indexed instruction form (d2[0]) rather than a separate duplicate.
define <2 x i32> @usdot_lane.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %b) {
|
||||
entry:
|
||||
; CHECK-LABEL: usdot_lane.v2i32.v8i8
|
||||
; CHECK: vusdot.s8 d0, d1, d2[0]
|
||||
%0 = bitcast <8 x i8> %b to <2 x i32>
|
||||
%shuffle = shufflevector <2 x i32> %0, <2 x i32> undef, <2 x i32> zeroinitializer
|
||||
%1 = bitcast <2 x i32> %shuffle to <8 x i8>
|
||||
%vusdot1.i = tail call <2 x i32> @llvm.arm.neon.usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %1) #3
|
||||
ret <2 x i32> %vusdot1.i
|
||||
}
|
||||
|
||||
; Lane-indexed vsudot, D-register form. Same splat-of-element-0 construction
; as the usdot lane test, but the splatted value is passed as the FIRST input
; operand of usdot and %a as the last. With the inputs swapped this way the
; expected output is vsudot.u8 with an indexed operand.
define <2 x i32> @sudot_lane.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %b) {
|
||||
entry:
|
||||
; CHECK-LABEL: sudot_lane.v2i32.v8i8
|
||||
; CHECK: vsudot.u8 d0, d1, d2[0]
|
||||
%0 = bitcast <8 x i8> %b to <2 x i32>
|
||||
%shuffle = shufflevector <2 x i32> %0, <2 x i32> undef, <2 x i32> zeroinitializer
|
||||
%1 = bitcast <2 x i32> %shuffle to <8 x i8>
|
||||
%vusdot1.i = tail call <2 x i32> @llvm.arm.neon.usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> %1, <8 x i8> %a) #3
|
||||
ret <2 x i32> %vusdot1.i
|
||||
}
|
||||
|
||||
; Lane-indexed vusdot, Q-register form. %b is a 64-bit <8 x i8> vector; its
; element 0 (viewed as i32) is splatted to a <4 x i32>, cast to <16 x i8> and
; passed as the LAST usdot operand. The expected output uses Q registers for
; the accumulator and first input and an indexed D register (d4[0]).
define <4 x i32> @usdotq_lane.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <8 x i8> %b) {
|
||||
entry:
|
||||
; CHECK-LABEL: usdotq_lane.v4i32.v16i8
|
||||
; CHECK: vusdot.s8 q0, q1, d4[0]
|
||||
%0 = bitcast <8 x i8> %b to <2 x i32>
|
||||
%shuffle = shufflevector <2 x i32> %0, <2 x i32> undef, <4 x i32> zeroinitializer
|
||||
%1 = bitcast <4 x i32> %shuffle to <16 x i8>
|
||||
%vusdot1.i = tail call <4 x i32> @llvm.arm.neon.usdot.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %1) #3
|
||||
ret <4 x i32> %vusdot1.i
|
||||
}
|
||||
|
||||
; Lane-indexed vsudot, Q-register form. Same construction as the Q-register
; usdot lane test, but the splatted value is passed as the FIRST input operand
; of usdot and %a as the last. The expected output is vsudot.u8 with an
; indexed D register (d4[0]).
define <4 x i32> @sudotq_lane.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <8 x i8> %b) {
|
||||
entry:
|
||||
; CHECK-LABEL: sudotq_lane.v4i32.v16i8
|
||||
; CHECK: vsudot.u8 q0, q1, d4[0]
|
||||
%0 = bitcast <8 x i8> %b to <2 x i32>
|
||||
%shuffle = shufflevector <2 x i32> %0, <2 x i32> undef, <4 x i32> zeroinitializer
|
||||
%1 = bitcast <4 x i32> %shuffle to <16 x i8>
|
||||
%vusdot1.i = tail call <4 x i32> @llvm.arm.neon.usdot.v4i32.v16i8(<4 x i32> %r, <16 x i8> %1, <16 x i8> %a) #3
|
||||
ret <4 x i32> %vusdot1.i
|
||||
}
|
||||
|
||||
declare <4 x i32> @llvm.arm.neon.smmla.v4i32.v16i8(<4 x i32>, <16 x i8>, <16 x i8>) #2
|
||||
declare <4 x i32> @llvm.arm.neon.ummla.v4i32.v16i8(<4 x i32>, <16 x i8>, <16 x i8>) #2
|
||||
declare <4 x i32> @llvm.arm.neon.usmmla.v4i32.v16i8(<4 x i32>, <16 x i8>, <16 x i8>) #2
|
||||
declare <2 x i32> @llvm.arm.neon.usdot.v2i32.v8i8(<2 x i32>, <8 x i8>, <8 x i8>) #2
|
||||
declare <4 x i32> @llvm.arm.neon.usdot.v4i32.v16i8(<4 x i32>, <16 x i8>, <16 x i8>) #2
|
Loading…
x
Reference in New Issue
Block a user