From 059fdca4daa51a8530d3edb9cf5dea28a4a64d85 Mon Sep 17 00:00:00 2001 From: David Green Date: Tue, 23 Feb 2021 20:31:01 +0000 Subject: [PATCH] [AArch64] Introduce UDOT/SDOT DAG nodes This is used to lower UDOT/SDOT instructions, as opposed to relying on the intrinsic. Subsequent optimizations will be able to optimize them more cleanly based on these nodes. --- lib/Target/AArch64/AArch64ISelLowering.cpp | 27 +++++++++++++--------- lib/Target/AArch64/AArch64ISelLowering.h | 4 ++++ lib/Target/AArch64/AArch64InstrInfo.td | 13 +++++++---- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/lib/Target/AArch64/AArch64ISelLowering.cpp b/lib/Target/AArch64/AArch64ISelLowering.cpp index 036932d90f7..f3ca6881080 100644 --- a/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1842,6 +1842,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::URHADD) MAKE_CASE(AArch64ISD::SHADD) MAKE_CASE(AArch64ISD::UHADD) + MAKE_CASE(AArch64ISD::SDOT) + MAKE_CASE(AArch64ISD::UDOT) MAKE_CASE(AArch64ISD::SMINV) MAKE_CASE(AArch64ISD::UMINV) MAKE_CASE(AArch64ISD::SMAXV) @@ -3860,14 +3862,19 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); } - + case Intrinsic::aarch64_neon_sabd: case Intrinsic::aarch64_neon_uabd: { - return DAG.getNode(AArch64ISD::UABD, dl, Op.getValueType(), - Op.getOperand(1), Op.getOperand(2)); + unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? AArch64ISD::UABD + : AArch64ISD::SABD; + return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1), + Op.getOperand(2)); } - case Intrinsic::aarch64_neon_sabd: { - return DAG.getNode(AArch64ISD::SABD, dl, Op.getValueType(), - Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_neon_sdot: + case Intrinsic::aarch64_neon_udot: { + unsigned Opcode = IntNo == Intrinsic::aarch64_neon_udot ? AArch64ISD::UDOT + : AArch64ISD::SDOT; + return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1), + Op.getOperand(2), Op.getOperand(3)); } } } @@ -11753,11 +11760,9 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG, SDLoc DL(Op0); SDValue Ones = DAG.getConstant(1, DL, Op0VT); SDValue Zeros = DAG.getConstant(0, DL, MVT::v4i32); - auto DotIntrisic = (ExtOpcode == ISD::ZERO_EXTEND) - ? Intrinsic::aarch64_neon_udot - : Intrinsic::aarch64_neon_sdot; - SDValue Dot = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Zeros.getValueType(), - DAG.getConstant(DotIntrisic, DL, MVT::i32), Zeros, + auto DotOpcode = + (ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT; + SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, Ones, Op0.getOperand(0)); return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot); } diff --git a/lib/Target/AArch64/AArch64ISelLowering.h b/lib/Target/AArch64/AArch64ISelLowering.h index 94aef30b21b..4959c8c3d58 100644 --- a/lib/Target/AArch64/AArch64ISelLowering.h +++ b/lib/Target/AArch64/AArch64ISelLowering.h @@ -231,6 +231,10 @@ enum NodeType : unsigned { UABD, SABD, + // udot/sdot instructions + UDOT, + SDOT, + // Vector across-lanes min/max // Only the lower result lane is defined. SMINV, diff --git a/lib/Target/AArch64/AArch64InstrInfo.td b/lib/Target/AArch64/AArch64InstrInfo.td index 1dd2fb30b23..d4871c31127 100644 --- a/lib/Target/AArch64/AArch64InstrInfo.td +++ b/lib/Target/AArch64/AArch64InstrInfo.td @@ -247,6 +247,8 @@ def SDT_AArch64UnaryVec: SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisSameAs<0,1>]>; def SDT_AArch64ExtVec: SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisSameAs<0,1>, SDTCisSameAs<0,2>, SDTCisInt<3>]>; def SDT_AArch64vshift : SDTypeProfile<1, 2, [SDTCisSameAs<0,1>, SDTCisInt<2>]>; +def SDT_AArch64Dot: SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisSameAs<0,1>, + SDTCisVec<2>, SDTCisSameAs<2,3>]>; def SDT_AArch64vshiftinsert : SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisInt<3>, SDTCisSameAs<0,1>, @@ -561,6 +563,9 @@ def AArch64frecps : SDNode<"AArch64ISD::FRECPS", SDTFPBinOp>; def AArch64frsqrte : SDNode<"AArch64ISD::FRSQRTE", SDTFPUnaryOp>; def AArch64frsqrts : SDNode<"AArch64ISD::FRSQRTS", SDTFPBinOp>; +def AArch64sdot : SDNode<"AArch64ISD::SDOT", SDT_AArch64Dot>; +def AArch64udot : SDNode<"AArch64ISD::UDOT", SDT_AArch64Dot>; + def AArch64saddv : SDNode<"AArch64ISD::SADDV", SDT_AArch64UnaryVec>; def AArch64uaddv : SDNode<"AArch64ISD::UADDV", SDT_AArch64UnaryVec>; def AArch64sminv : SDNode<"AArch64ISD::SMINV", SDT_AArch64UnaryVec>; @@ -831,10 +836,10 @@ def : TokenAlias<"IALL", "iall">; // ARMv8.2-A Dot Product let Predicates = [HasDotProd] in { -defm SDOT : SIMDThreeSameVectorDot<0, 0, "sdot", int_aarch64_neon_sdot>; -defm UDOT : SIMDThreeSameVectorDot<1, 0, "udot", int_aarch64_neon_udot>; -defm SDOTlane : SIMDThreeSameVectorDotIndex<0, 0, 0b10, "sdot", int_aarch64_neon_sdot>; -defm UDOTlane : SIMDThreeSameVectorDotIndex<1, 0, 0b10, "udot", int_aarch64_neon_udot>; +defm SDOT : SIMDThreeSameVectorDot<0, 0, "sdot", AArch64sdot>; +defm UDOT : SIMDThreeSameVectorDot<1, 0, "udot", AArch64udot>; +defm SDOTlane : SIMDThreeSameVectorDotIndex<0, 0, 0b10, "sdot", AArch64sdot>; +defm UDOTlane : SIMDThreeSameVectorDotIndex<1, 0, 0b10, "udot", AArch64udot>; } // ARMv8.6-A BFloat