diff --git a/include/llvm/IR/IntrinsicsAArch64.td b/include/llvm/IR/IntrinsicsAArch64.td
index 9019b9d3be5..483afe97cc6 100644
--- a/include/llvm/IR/IntrinsicsAArch64.td
+++ b/include/llvm/IR/IntrinsicsAArch64.td
@@ -178,6 +178,12 @@ let TargetPrefix = "aarch64" in {  // All intrinsics start with "llvm.aarch64.".
     : Intrinsic<[llvm_anyvector_ty],
                 [LLVMMatchType<0>, llvm_anyvector_ty, LLVMMatchType<1>],
                 [IntrNoMem]>;
+
+  class AdvSIMD_FML_Intrinsic
+    : Intrinsic<[llvm_anyvector_ty],
+                [LLVMMatchType<0>, llvm_anyvector_ty, LLVMMatchType<1>],
+                [IntrNoMem]>;
+
 }

 // Arithmetic ops
@@ -459,6 +465,11 @@ let TargetPrefix = "aarch64", IntrProperties = [IntrNoMem] in {
   def int_aarch64_neon_smmla : AdvSIMD_MatMul_Intrinsic;
   def int_aarch64_neon_usmmla : AdvSIMD_MatMul_Intrinsic;
   def int_aarch64_neon_usdot : AdvSIMD_Dot_Intrinsic;
+  def int_aarch64_neon_bfdot : AdvSIMD_Dot_Intrinsic;
+  def int_aarch64_neon_bfmmla : AdvSIMD_MatMul_Intrinsic;
+  def int_aarch64_neon_bfmlalb : AdvSIMD_FML_Intrinsic;
+  def int_aarch64_neon_bfmlalt : AdvSIMD_FML_Intrinsic;
+

   // v8.2-A FP16 Fused Multiply-Add Long
   def int_aarch64_neon_fmlal : AdvSIMD_FP16FML_Intrinsic;
diff --git a/lib/Target/AArch64/AArch64InstrFormats.td b/lib/Target/AArch64/AArch64InstrFormats.td
index 713bf0bf3ca..8f5202af96e 100644
--- a/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/lib/Target/AArch64/AArch64InstrFormats.td
@@ -7815,16 +7815,19 @@ let mayStore = 0, mayLoad = 0, hasSideEffects = 0 in {
 class BaseSIMDThreeSameVectorBFDot<bit Q, bit U, string asm, string kind1,
                                    string kind2, RegisterOperand RegType,
                                    ValueType AccumType, ValueType InputType>
-  : BaseSIMDThreeSameVectorTied<Q, U, 0b010, 0b11111, RegType, asm, kind1,
-                                []> {
+  : BaseSIMDThreeSameVectorTied<Q, U, 0b010, 0b11111, RegType, asm, kind1,
+                                [(set (AccumType RegType:$dst),
+                                      (int_aarch64_neon_bfdot (AccumType RegType:$Rd),
+                                                              (InputType RegType:$Rn),
+                                                              (InputType RegType:$Rm)))]> {
   let AsmString = !strconcat(asm,
                              "{\t$Rd" # kind1 # ", $Rn" # kind2 #
                                ", $Rm" # kind2 # "}");
 }

 multiclass SIMDThreeSameVectorBFDot<bit U, string asm> {
-  def v4f16  : BaseSIMDThreeSameVectorBFDot<0, U, asm, ".2s", ".4h", V64,
+  def v4bf16 : BaseSIMDThreeSameVectorBFDot<0, U, asm, ".2s", ".4h", V64,
                                             v2f32, v8i8>;
-  def v8f16  : BaseSIMDThreeSameVectorBFDot<1, U, asm, ".4s", ".8h", V128,
+  def v8bf16 : BaseSIMDThreeSameVectorBFDot<1, U, asm, ".4s", ".8h", V128,
                                             v4f32, v16i8>;
 }

@@ -7837,7 +7840,13 @@ class BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
   : BaseSIMDIndexedTied<Q, U, 0b0, 0b01, 0b1111,
                         RegType, RegType, V128, VectorIndexS,
                         asm, "", dst_kind, lhs_kind, rhs_kind,
-                        []> {
+                        [(set (AccumType RegType:$dst),
+                              (AccumType (int_aarch64_neon_bfdot
+                                         (AccumType RegType:$Rd),
+                                         (InputType RegType:$Rn),
+                                         (InputType (bitconvert (AccumType
+                                            (AArch64duplane32 (v4f32 V128:$Rm),
+                                             VectorIndexH:$idx)))))))]> {

   bits<2> idx;
   let Inst{21} = idx{0};  // L
@@ -7846,23 +7855,30 @@ class BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,

 multiclass SIMDThreeSameVectorBF16DotI<bit U, string asm> {
-  def v4f16  : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
+  def v4bf16 : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
                                                ".2h", V64, v2f32, v8i8>;
-  def v8f16  : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
+  def v8bf16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
                                                ".2h", V128, v4f32, v16i8>;
 }

-class SIMDBF16MLAL<bit Q, string asm>
+class SIMDBF16MLAL<bit Q, string asm, SDPatternOperator OpNode>
   : BaseSIMDThreeSameVectorTied<Q, 0b1, 0b110, 0b11111, V128, asm, ".4s",
-                                []> { // TODO: Add intrinsics
+                                [(set (v4f32 V128:$dst), (OpNode (v4f32 V128:$Rd),
+                                                                 (v16i8 V128:$Rn),
+                                                                 (v16i8 V128:$Rm)))]> {
   let AsmString = !strconcat(asm, "{\t$Rd.4s, $Rn.8h, $Rm.8h}");
 }

-class SIMDBF16MLALIndex<bit Q, string asm>
+class SIMDBF16MLALIndex<bit Q, string asm, SDPatternOperator OpNode>
   : I<(outs V128:$dst),
      (ins V128:$Rd, V128:$Rn, V128_lo:$Rm, VectorIndexH:$idx), asm,
      "{\t$Rd.4s, $Rn.8h, $Rm.h$idx}", "$Rd = $dst",
-     []>, // TODO: Add intrinsics
+     [(set (v4f32 V128:$dst),
+           (v4f32 (OpNode (v4f32 V128:$Rd),
+                          (v16i8 V128:$Rn),
+                          (v16i8 (bitconvert (v8bf16
+                             (AArch64duplane16 (v8bf16 V128_lo:$Rm),
+                              VectorIndexH:$idx)))))))]>,
    Sched<[WriteV]> {
   bits<5> Rd;
   bits<5> Rn;
@@ -7884,7 +7900,10 @@ class SIMDBF16MLALIndex<bit Q, string asm>

 class SIMDThreeSameVectorBF16MatrixMul<string asm>
   : BaseSIMDThreeSameVectorTied<1, 1, 0b010, 0b11101, V128, asm, ".4s",
-                                []> {
+                                [(set (v4f32 V128:$dst),
+                                      (int_aarch64_neon_bfmmla (v4f32 V128:$Rd),
+                                                               (v16i8 V128:$Rn),
+                                                               (v16i8 V128:$Rm)))]> {
   let AsmString = !strconcat(asm, "{\t$Rd", ".4s", ", $Rn", ".8h",
                                     ", $Rm", ".8h", "}");
 }
diff --git a/lib/Target/AArch64/AArch64InstrInfo.td b/lib/Target/AArch64/AArch64InstrInfo.td
index 8716ffb412d..b56c5d9ff85 100644
--- a/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/lib/Target/AArch64/AArch64InstrInfo.td
@@ -784,10 +784,10 @@ let Predicates = [HasBF16] in {
 defm BFDOT       : SIMDThreeSameVectorBFDot<1, "bfdot">;
 defm BF16DOTlane : SIMDThreeSameVectorBF16DotI<0, "bfdot">;
 def BFMMLA       : SIMDThreeSameVectorBF16MatrixMul<"bfmmla">;
-def BFMLALB      : SIMDBF16MLAL<0, "bfmlalb">;
-def BFMLALT      : SIMDBF16MLAL<1, "bfmlalt">;
-def BFMLALBIdx   : SIMDBF16MLALIndex<0, "bfmlalb">;
-def BFMLALTIdx   : SIMDBF16MLALIndex<1, "bfmlalt">;
+def BFMLALB      : SIMDBF16MLAL<0, "bfmlalb", int_aarch64_neon_bfmlalb>;
+def BFMLALT      : SIMDBF16MLAL<1, "bfmlalt", int_aarch64_neon_bfmlalt>;
+def BFMLALBIdx   : SIMDBF16MLALIndex<0, "bfmlalb", int_aarch64_neon_bfmlalb>;
+def BFMLALTIdx   : SIMDBF16MLALIndex<1, "bfmlalt", int_aarch64_neon_bfmlalt>;
 def BFCVTN       : SIMD_BFCVTN;
 def BFCVTN2      : SIMD_BFCVTN2;
 def BFCVT        : BF16ToSinglePrecision<"bfcvt">;
diff --git a/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll b/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
new file mode 100644
index 00000000000..96513115f2d
--- /dev/null
+++ b/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
@@ -0,0 +1,176 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple aarch64-arm-none-eabi -mattr=+bf16 %s -o - | FileCheck %s
+
+define <2 x float> @test_vbfdot_f32(<2 x float> %r, <4 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdot_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfdot v0.2s, v1.4h, v2.4h
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <4 x bfloat> %a to <8 x i8>
+  %1 = bitcast <4 x bfloat> %b to <8 x i8>
+  %vbfdot1.i = tail call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float> %r, <8 x i8> %0, <8 x i8> %1)
+  ret <2 x float> %vbfdot1.i
+}
+
+define <4 x float> @test_vbfdotq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdotq_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfdot v0.4s, v1.8h, v2.8h
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <8 x bfloat> %a to <16 x i8>
+  %1 = bitcast <8 x bfloat> %b to <16 x i8>
+  %vbfdot1.i = tail call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+  ret <4 x float> %vbfdot1.i
+}
+
+define <2 x float> @test_vbfdot_lane_f32(<2 x float> %r, <4 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdot_lane_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK:         bfdot v0.2s, v1.4h, v2.2h[0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <4 x bfloat> %b to <2 x float>
+  %shuffle = shufflevector <2 x float> %0, <2 x float> undef, <2 x i32> zeroinitializer
+  %1 = bitcast <4 x bfloat> %a to <8 x i8>
+  %2 = bitcast <2 x float> %shuffle to <8 x i8>
+  %vbfdot1.i = tail call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float> %r, <8 x i8> %1, <8 x i8> %2)
+  ret <2 x float> %vbfdot1.i
+}
+
+define <4 x float> @test_vbfdotq_laneq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdotq_laneq_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfdot v0.4s, v1.8h, v2.2h[3]
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <8 x bfloat> %b to <4 x float>
+  %shuffle = shufflevector <4 x float> %0, <4 x float> undef, <4 x i32> <i32 3, i32 3, i32 3, i32 3>
+  %1 = bitcast <8 x bfloat> %a to <16 x i8>
+  %2 = bitcast <4 x float> %shuffle to <16 x i8>
+  %vbfdot1.i = tail call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float> %r, <16 x i8> %1, <16 x i8> %2)
+  ret <4 x float> %vbfdot1.i
+}
+
+define <2 x float> @test_vbfdot_laneq_f32(<2 x float> %r, <4 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdot_laneq_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfdot v0.2s, v1.4h, v2.2h[3]
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <8 x bfloat> %b to <4 x float>
+  %shuffle = shufflevector <4 x float> %0, <4 x float> undef, <2 x i32> <i32 3, i32 3>
+  %1 = bitcast <4 x bfloat> %a to <8 x i8>
+  %2 = bitcast <2 x float> %shuffle to <8 x i8>
+  %vbfdot1.i = tail call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float> %r, <8 x i8> %1, <8 x i8> %2)
+  ret <2 x float> %vbfdot1.i
+}
+
+define <4 x float> @test_vbfdotq_lane_f32(<4 x float> %r, <8 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdotq_lane_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK:         bfdot v0.4s, v1.8h, v2.2h[0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <4 x bfloat> %b to <2 x float>
+  %shuffle = shufflevector <2 x float> %0, <2 x float> undef, <4 x i32> zeroinitializer
+  %1 = bitcast <8 x bfloat> %a to <16 x i8>
+  %2 = bitcast <4 x float> %shuffle to <16 x i8>
+  %vbfdot1.i = tail call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float> %r, <16 x i8> %1, <16 x i8> %2)
+  ret <4 x float> %vbfdot1.i
+}
+
+define <4 x float> @test_vbfmmlaq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmmlaq_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfmmla v0.4s, v1.8h, v2.8h
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <8 x bfloat> %a to <16 x i8>
+  %1 = bitcast <8 x bfloat> %b to <16 x i8>
+  %vbfmmla1.i = tail call <4 x float> @llvm.aarch64.neon.bfmmla.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+  ret <4 x float> %vbfmmla1.i
+}
+
+define <4 x float> @test_vbfmlalbq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlalbq_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfmlalb v0.4s, v1.8h, v2.8h
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <8 x bfloat> %a to <16 x i8>
+  %1 = bitcast <8 x bfloat> %b to <16 x i8>
+  %vbfmlalb1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+  ret <4 x float> %vbfmlalb1.i
+}
+
+define <4 x float> @test_vbfmlaltq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlaltq_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfmlalt v0.4s, v1.8h, v2.8h
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <8 x bfloat> %a to <16 x i8>
+  %1 = bitcast <8 x bfloat> %b to <16 x i8>
+  %vbfmlalt1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+  ret <4 x float> %vbfmlalt1.i
+}
+
+define <4 x float> @test_vbfmlalbq_lane_f32(<4 x float> %r, <8 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlalbq_lane_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK:         bfmlalb v0.4s, v1.8h, v2.h[0]
+; CHECK-NEXT:    ret
+entry:
+  %vecinit35 = shufflevector <4 x bfloat> %b, <4 x bfloat> undef, <8 x i32> zeroinitializer
+  %0 = bitcast <8 x bfloat> %a to <16 x i8>
+  %1 = bitcast <8 x bfloat> %vecinit35 to <16 x i8>
+  %vbfmlalb1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+  ret <4 x float> %vbfmlalb1.i
+}
+
+define <4 x float> @test_vbfmlalbq_laneq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlalbq_laneq_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfmlalb v0.4s, v1.8h, v2.h[3]
+; CHECK-NEXT:    ret
+entry:
+  %vecinit35 = shufflevector <8 x bfloat> %b, <8 x bfloat> undef, <8 x i32> <i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3>
+  %0 = bitcast <8 x bfloat> %a to <16 x i8>
+  %1 = bitcast <8 x bfloat> %vecinit35 to <16 x i8>
+  %vbfmlalb1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+  ret <4 x float> %vbfmlalb1.i
+}
+
+define <4 x float> @test_vbfmlaltq_lane_f32(<4 x float> %r, <8 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlaltq_lane_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK:         bfmlalt v0.4s, v1.8h, v2.h[0]
+; CHECK-NEXT:    ret
+entry:
+  %vecinit35 = shufflevector <4 x bfloat> %b, <4 x bfloat> undef, <8 x i32> zeroinitializer
+  %0 = bitcast <8 x bfloat> %a to <16 x i8>
+  %1 = bitcast <8 x bfloat> %vecinit35 to <16 x i8>
+  %vbfmlalt1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+  ret <4 x float> %vbfmlalt1.i
+}
+
+define <4 x float> @test_vbfmlaltq_laneq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlaltq_laneq_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfmlalt v0.4s, v1.8h, v2.h[3]
+; CHECK-NEXT:    ret
+entry:
+  %vecinit35 = shufflevector <8 x bfloat> %b, <8 x bfloat> undef, <8 x i32> <i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3>
+  %0 = bitcast <8 x bfloat> %a to <16 x i8>
+  %1 = bitcast <8 x bfloat> %vecinit35 to <16 x i8>
+  %vbfmlalt1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+  ret <4 x float> %vbfmlalt1.i
+}
+
+declare <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float>, <8 x i8>, <8 x i8>)
+declare <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float>, <16 x i8>, <16 x i8>)
+declare <4 x float> @llvm.aarch64.neon.bfmmla.v4f32.v16i8(<4 x float>, <16 x i8>, <16 x i8>)
+declare <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float>, <16 x i8>, <16 x i8>)
+declare <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float>, <16 x i8>, <16 x i8>)
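
Note (reviewer aid, not part of the patch): the indexed BFMLALB/BFMLALT and BFDOT patterns above do not match an extract-lane operand directly; they match a splat of the selected lane. The expected IR shape, as exercised by the tests, is a shufflevector splat of the bfloat vector followed by a bitcast to the byte-vector operand type, which instruction selection then folds into the ".h[idx]" / ".2h[idx]" register form via AArch64duplane16 / AArch64duplane32. Below is a minimal standalone sketch of that contract, using the same mangled intrinsic names as the test file; the function and value names here are illustrative only, and the expected assembly is an assumption based on the SIMDBF16MLALIndex pattern added above.

; Splat bfloat lane 2 of %b and feed it to bfmlalb; under the new indexed
; pattern this should select to:
;   bfmlalb v0.4s, v1.8h, v2.h[2]
define <4 x float> @bfmlalb_lane2_sketch(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
  %splat = shufflevector <8 x bfloat> %b, <8 x bfloat> undef, <8 x i32> <i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2>
  %a.i8 = bitcast <8 x bfloat> %a to <16 x i8>
  %s.i8 = bitcast <8 x bfloat> %splat to <16 x i8>
  %res = tail call <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float> %r, <16 x i8> %a.i8, <16 x i8> %s.i8)
  ret <4 x float> %res
}

declare <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float>, <16 x i8>, <16 x i8>)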