mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2025-02-01 05:01:59 +01:00
[ARM] Add predicated mla reduction patterns
Similar to 8fa824d7a3, but this time for MLA patterns: this selects predicated vmlav/vmlava/vmlalv/vmlalva instructions from vecreduce.add(select(p, mul(x, y), 0)) nodes. Differential Revision: https://reviews.llvm.org/D84102
This commit is contained in:
parent
b1abd6ff52
commit
ab0cc8c927
@ -1730,10 +1730,16 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {
|
||||
case ARMISD::VADDLVApu: return "ARMISD::VADDLVApu";
|
||||
case ARMISD::VMLAVs: return "ARMISD::VMLAVs";
|
||||
case ARMISD::VMLAVu: return "ARMISD::VMLAVu";
|
||||
case ARMISD::VMLAVps: return "ARMISD::VMLAVps";
|
||||
case ARMISD::VMLAVpu: return "ARMISD::VMLAVpu";
|
||||
case ARMISD::VMLALVs: return "ARMISD::VMLALVs";
|
||||
case ARMISD::VMLALVu: return "ARMISD::VMLALVu";
|
||||
case ARMISD::VMLALVps: return "ARMISD::VMLALVps";
|
||||
case ARMISD::VMLALVpu: return "ARMISD::VMLALVpu";
|
||||
case ARMISD::VMLALVAs: return "ARMISD::VMLALVAs";
|
||||
case ARMISD::VMLALVAu: return "ARMISD::VMLALVAu";
|
||||
case ARMISD::VMLALVAps: return "ARMISD::VMLALVAps";
|
||||
case ARMISD::VMLALVApu: return "ARMISD::VMLALVApu";
|
||||
case ARMISD::UMAAL: return "ARMISD::UMAAL";
|
||||
case ARMISD::UMLAL: return "ARMISD::UMLAL";
|
||||
case ARMISD::SMLAL: return "ARMISD::SMLAL";
|
||||
@ -12261,6 +12267,14 @@ static SDValue PerformADDVecReduce(SDNode *N,
|
||||
return M;
|
||||
if (SDValue M = MakeVecReduce(ARMISD::VMLALVu, ARMISD::VMLALVAu, N1, N0))
|
||||
return M;
|
||||
if (SDValue M = MakeVecReduce(ARMISD::VMLALVps, ARMISD::VMLALVAps, N0, N1))
|
||||
return M;
|
||||
if (SDValue M = MakeVecReduce(ARMISD::VMLALVpu, ARMISD::VMLALVApu, N0, N1))
|
||||
return M;
|
||||
if (SDValue M = MakeVecReduce(ARMISD::VMLALVps, ARMISD::VMLALVAps, N1, N0))
|
||||
return M;
|
||||
if (SDValue M = MakeVecReduce(ARMISD::VMLALVpu, ARMISD::VMLALVApu, N1, N0))
|
||||
return M;
|
||||
return SDValue();
|
||||
}
|
||||
|
||||
@ -14760,6 +14774,26 @@ static SDValue PerformVECREDUCE_ADDCombine(SDNode *N, SelectionDAG &DAG,
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
auto IsPredVMLAV = [&](MVT RetTy, unsigned ExtendCode, ArrayRef<MVT> ExtTypes,
|
||||
SDValue &A, SDValue &B, SDValue &Mask) {
|
||||
if (ResVT != RetTy || N0->getOpcode() != ISD::VSELECT ||
|
||||
!ISD::isBuildVectorAllZeros(N0->getOperand(2).getNode()))
|
||||
return false;
|
||||
Mask = N0->getOperand(0);
|
||||
SDValue Mul = N0->getOperand(1);
|
||||
if (Mul->getOpcode() != ISD::MUL)
|
||||
return false;
|
||||
SDValue ExtA = Mul->getOperand(0);
|
||||
SDValue ExtB = Mul->getOperand(1);
|
||||
if (ExtA->getOpcode() != ExtendCode && ExtB->getOpcode() != ExtendCode)
|
||||
return false;
|
||||
A = ExtA->getOperand(0);
|
||||
B = ExtB->getOperand(0);
|
||||
if (A.getValueType() == B.getValueType() &&
|
||||
llvm::any_of(ExtTypes, [&A](MVT Ty) { return A.getValueType() == Ty; }))
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
auto Create64bitNode = [&](unsigned Opcode, ArrayRef<SDValue> Ops) {
|
||||
SDValue Node = DAG.getNode(Opcode, dl, {MVT::i32, MVT::i32}, Ops);
|
||||
return DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Node,
|
||||
@ -14794,6 +14828,15 @@ static SDValue PerformVECREDUCE_ADDCombine(SDNode *N, SelectionDAG &DAG,
|
||||
return Create64bitNode(ARMISD::VMLALVs, {A, B});
|
||||
if (IsVMLAV(MVT::i64, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v4i32}, A, B))
|
||||
return Create64bitNode(ARMISD::VMLALVu, {A, B});
|
||||
|
||||
if (IsPredVMLAV(MVT::i32, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v16i8}, A, B, Mask))
|
||||
return DAG.getNode(ARMISD::VMLAVps, dl, ResVT, A, B, Mask);
|
||||
if (IsPredVMLAV(MVT::i32, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v16i8}, A, B, Mask))
|
||||
return DAG.getNode(ARMISD::VMLAVpu, dl, ResVT, A, B, Mask);
|
||||
if (IsPredVMLAV(MVT::i64, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v4i32}, A, B, Mask))
|
||||
return Create64bitNode(ARMISD::VMLALVps, {A, B, Mask});
|
||||
if (IsPredVMLAV(MVT::i64, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v4i32}, A, B, Mask))
|
||||
return Create64bitNode(ARMISD::VMLALVpu, {A, B, Mask});
|
||||
return SDValue();
|
||||
}
|
||||
|
||||
|
@ -229,12 +229,18 @@ class VectorType;
|
||||
VADDLVpu,
|
||||
VADDLVAps, // Same as VADDLVp[su] but with a v4i1 predicate mask
|
||||
VADDLVApu,
|
||||
VMLAVs,
|
||||
VMLAVu,
|
||||
VMLALVs,
|
||||
VMLALVu,
|
||||
VMLALVAs,
|
||||
VMLALVAu,
|
||||
VMLAVs, // sign- or zero-extend the elements of two vectors to i32, multiply them
|
||||
VMLAVu, // and add the results together, returning an i32 of their sum
|
||||
VMLAVps, // Same as VMLAV[su] with a v8i1 or v16i1 predicate mask
|
||||
VMLAVpu,
|
||||
VMLALVs, // Same as VMLAV but with i64, returning the low and
|
||||
VMLALVu, // high 32-bit halves of the sum
|
||||
VMLALVps, // Same as VMLALV[su] with a v4i1 predicate mask
|
||||
VMLALVpu,
|
||||
VMLALVAs, // Same as VMLALV but also add an input accumulator
|
||||
VMLALVAu, // provided as low and high halves
|
||||
VMLALVAps, // Same as VMLALVA[su] with a v4i1 predicate mask
|
||||
VMLALVApu,
|
||||
|
||||
SMULWB, // Signed multiply word by half word, bottom
|
||||
SMULWT, // Signed multiply word by half word, top
|
||||
|
@ -1105,12 +1105,28 @@ def SDTVecReduce2LA : SDTypeProfile<2, 4, [ // VMLALVA
|
||||
SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>, SDTCisInt<3>,
|
||||
SDTCisVec<4>, SDTCisVec<5>
|
||||
]>;
|
||||
// Type profile for the predicated VMLAV nodes (ARMVMLAVps / ARMVMLAVpu):
// one integer result and three vector operands. The selection patterns pass
// the operands as (val1, val2, pred).
def SDTVecReduce2P : SDTypeProfile<1, 3, [ // VMLAV
|
||||
SDTCisInt<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>
|
||||
]>;
|
||||
// Type profile for the predicated VMLALV nodes (ARMVMLALVps / ARMVMLALVpu):
// two integer results and three vector operands. The selection patterns pass
// the operands as (val1, val2, pred).
def SDTVecReduce2LP : SDTypeProfile<2, 3, [ // VMLALV
|
||||
SDTCisInt<0>, SDTCisInt<1>, SDTCisVec<2>, SDTCisVec<3>, SDTCisVec<4>
|
||||
]>;
|
||||
// Type profile for the predicated VMLALVA nodes (ARMVMLALVAps / ARMVMLALVApu):
// two integer results, two integer operands and three vector operands. The
// selection patterns pass the operands as (Rda, Rdb, val1, val2, pred).
def SDTVecReduce2LAP : SDTypeProfile<2, 5, [ // VMLALVA
|
||||
SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>, SDTCisInt<3>,
|
||||
SDTCisVec<4>, SDTCisVec<5>, SDTCisVec<6>
|
||||
]>;
|
||||
def ARMVMLAVs : SDNode<"ARMISD::VMLAVs", SDTVecReduce2>;
|
||||
def ARMVMLAVu : SDNode<"ARMISD::VMLAVu", SDTVecReduce2>;
|
||||
def ARMVMLALVs : SDNode<"ARMISD::VMLALVs", SDTVecReduce2L>;
|
||||
def ARMVMLALVu : SDNode<"ARMISD::VMLALVu", SDTVecReduce2L>;
|
||||
def ARMVMLALVAs : SDNode<"ARMISD::VMLALVAs", SDTVecReduce2LA>;
|
||||
def ARMVMLALVAu : SDNode<"ARMISD::VMLALVAu", SDTVecReduce2LA>;
|
||||
def ARMVMLALVAs : SDNode<"ARMISD::VMLALVAs", SDTVecReduce2LA>;
|
||||
def ARMVMLALVAu : SDNode<"ARMISD::VMLALVAu", SDTVecReduce2LA>;
|
||||
// Predicated forms of the VMLAV / VMLALV / VMLALVA reduction nodes. Their
// type profiles (the *P variants) carry one more vector operand than the
// unpredicated profiles; the selection patterns use it as the predicate.
def ARMVMLAVps : SDNode<"ARMISD::VMLAVps", SDTVecReduce2P>;
|
||||
def ARMVMLAVpu : SDNode<"ARMISD::VMLAVpu", SDTVecReduce2P>;
|
||||
def ARMVMLALVps : SDNode<"ARMISD::VMLALVps", SDTVecReduce2LP>;
|
||||
def ARMVMLALVpu : SDNode<"ARMISD::VMLALVpu", SDTVecReduce2LP>;
|
||||
def ARMVMLALVAps : SDNode<"ARMISD::VMLALVAps", SDTVecReduce2LAP>;
|
||||
def ARMVMLALVApu : SDNode<"ARMISD::VMLALVApu", SDTVecReduce2LAP>;
|
||||
|
||||
let Predicates = [HasMVEInt] in {
|
||||
def : Pat<(i32 (vecreduce_add (mul (v4i32 MQPR:$src1), (v4i32 MQPR:$src2)))),
|
||||
@ -1129,22 +1145,68 @@ let Predicates = [HasMVEInt] in {
|
||||
(i32 (MVE_VMLADAVu8 (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)))>;
|
||||
|
||||
def : Pat<(i32 (add (i32 (vecreduce_add (mul (v4i32 MQPR:$src1), (v4i32 MQPR:$src2)))),
|
||||
(i32 tGPREven:$src3))),
|
||||
(i32 tGPREven:$src3))),
|
||||
(i32 (MVE_VMLADAVau32 $src3, $src1, $src2))>;
|
||||
def : Pat<(i32 (add (i32 (vecreduce_add (mul (v8i16 MQPR:$src1), (v8i16 MQPR:$src2)))),
|
||||
(i32 tGPREven:$src3))),
|
||||
(i32 tGPREven:$src3))),
|
||||
(i32 (MVE_VMLADAVau16 $src3, $src1, $src2))>;
|
||||
def : Pat<(i32 (add (ARMVMLAVs (v8i16 MQPR:$val1), (v8i16 MQPR:$val2)), tGPREven:$Rd)),
|
||||
(i32 (MVE_VMLADAVas16 tGPREven:$Rd, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2)))>;
|
||||
def : Pat<(i32 (add (ARMVMLAVu (v8i16 MQPR:$val1), (v8i16 MQPR:$val2)), tGPREven:$Rd)),
|
||||
(i32 (MVE_VMLADAVau16 tGPREven:$Rd, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2)))>;
|
||||
def : Pat<(i32 (add (i32 (vecreduce_add (mul (v16i8 MQPR:$src1), (v16i8 MQPR:$src2)))),
|
||||
(i32 tGPREven:$src3))),
|
||||
(i32 tGPREven:$src3))),
|
||||
(i32 (MVE_VMLADAVau8 $src3, $src1, $src2))>;
|
||||
def : Pat<(i32 (add (ARMVMLAVs (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)), tGPREven:$Rd)),
|
||||
(i32 (MVE_VMLADAVas8 tGPREven:$Rd, (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)))>;
|
||||
def : Pat<(i32 (add (ARMVMLAVu (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)), tGPREven:$Rd)),
|
||||
(i32 (MVE_VMLADAVau8 tGPREven:$Rd, (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)))>;
|
||||
|
||||
// Predicated
|
||||
def : Pat<(i32 (vecreduce_add (vselect (v4i1 VCCR:$pred),
|
||||
(mul (v4i32 MQPR:$src1), (v4i32 MQPR:$src2)),
|
||||
(v4i32 ARMimmAllZerosV)))),
|
||||
(i32 (MVE_VMLADAVu32 $src1, $src2, ARMVCCThen, $pred))>;
|
||||
def : Pat<(i32 (vecreduce_add (vselect (v8i1 VCCR:$pred),
|
||||
(mul (v8i16 MQPR:$src1), (v8i16 MQPR:$src2)),
|
||||
(v8i16 ARMimmAllZerosV)))),
|
||||
(i32 (MVE_VMLADAVu16 $src1, $src2, ARMVCCThen, $pred))>;
|
||||
def : Pat<(i32 (ARMVMLAVps (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), (v8i1 VCCR:$pred))),
|
||||
(i32 (MVE_VMLADAVs16 (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), ARMVCCThen, $pred))>;
|
||||
def : Pat<(i32 (ARMVMLAVpu (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), (v8i1 VCCR:$pred))),
|
||||
(i32 (MVE_VMLADAVu16 (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), ARMVCCThen, $pred))>;
|
||||
def : Pat<(i32 (vecreduce_add (vselect (v16i1 VCCR:$pred),
|
||||
(mul (v16i8 MQPR:$src1), (v16i8 MQPR:$src2)),
|
||||
(v16i8 ARMimmAllZerosV)))),
|
||||
(i32 (MVE_VMLADAVu8 $src1, $src2, ARMVCCThen, $pred))>;
|
||||
def : Pat<(i32 (ARMVMLAVps (v16i8 MQPR:$val1), (v16i8 MQPR:$val2), (v16i1 VCCR:$pred))),
|
||||
(i32 (MVE_VMLADAVs8 (v16i8 MQPR:$val1), (v16i8 MQPR:$val2), ARMVCCThen, $pred))>;
|
||||
def : Pat<(i32 (ARMVMLAVpu (v16i8 MQPR:$val1), (v16i8 MQPR:$val2), (v16i1 VCCR:$pred))),
|
||||
(i32 (MVE_VMLADAVu8 (v16i8 MQPR:$val1), (v16i8 MQPR:$val2), ARMVCCThen, $pred))>;
|
||||
|
||||
def : Pat<(i32 (add (i32 (vecreduce_add (vselect (v4i1 VCCR:$pred),
|
||||
(mul (v4i32 MQPR:$src1), (v4i32 MQPR:$src2)),
|
||||
(v4i32 ARMimmAllZerosV)))),
|
||||
(i32 tGPREven:$src3))),
|
||||
(i32 (MVE_VMLADAVau32 $src3, $src1, $src2, ARMVCCThen, $pred))>;
|
||||
def : Pat<(i32 (add (i32 (vecreduce_add (vselect (v8i1 VCCR:$pred),
|
||||
(mul (v8i16 MQPR:$src1), (v8i16 MQPR:$src2)),
|
||||
(v8i16 ARMimmAllZerosV)))),
|
||||
(i32 tGPREven:$src3))),
|
||||
(i32 (MVE_VMLADAVau16 $src3, $src1, $src2, ARMVCCThen, $pred))>;
|
||||
def : Pat<(i32 (add (ARMVMLAVps (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), (v8i1 VCCR:$pred)), tGPREven:$Rd)),
|
||||
(i32 (MVE_VMLADAVas16 tGPREven:$Rd, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), ARMVCCThen, $pred))>;
|
||||
def : Pat<(i32 (add (ARMVMLAVpu (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), (v8i1 VCCR:$pred)), tGPREven:$Rd)),
|
||||
(i32 (MVE_VMLADAVau16 tGPREven:$Rd, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), ARMVCCThen, $pred))>;
|
||||
def : Pat<(i32 (add (i32 (vecreduce_add (vselect (v16i1 VCCR:$pred),
|
||||
(mul (v16i8 MQPR:$src1), (v16i8 MQPR:$src2)),
|
||||
(v16i8 ARMimmAllZerosV)))),
|
||||
(i32 tGPREven:$src3))),
|
||||
(i32 (MVE_VMLADAVau8 $src3, $src1, $src2, ARMVCCThen, $pred))>;
|
||||
def : Pat<(i32 (add (ARMVMLAVps (v16i8 MQPR:$val1), (v16i8 MQPR:$val2), (v16i1 VCCR:$pred)), tGPREven:$Rd)),
|
||||
(i32 (MVE_VMLADAVas8 tGPREven:$Rd, (v16i8 MQPR:$val1), (v16i8 MQPR:$val2), ARMVCCThen, $pred))>;
|
||||
def : Pat<(i32 (add (ARMVMLAVpu (v16i8 MQPR:$val1), (v16i8 MQPR:$val2), (v16i1 VCCR:$pred)), tGPREven:$Rd)),
|
||||
(i32 (MVE_VMLADAVau8 tGPREven:$Rd, (v16i8 MQPR:$val1), (v16i8 MQPR:$val2), ARMVCCThen, $pred))>;
|
||||
}
|
||||
|
||||
// vmlav aliases vmladav
|
||||
@ -1264,6 +1326,25 @@ let Predicates = [HasMVEInt] in {
|
||||
(MVE_VMLALDAVas16 tGPREven:$Rda, tGPROdd:$Rdb, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2))>;
|
||||
def : Pat<(ARMVMLALVAu tGPREven:$Rda, tGPROdd:$Rdb, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2)),
|
||||
(MVE_VMLALDAVau16 tGPREven:$Rda, tGPROdd:$Rdb, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2))>;
|
||||
|
||||
// Predicated
|
||||
def : Pat<(ARMVMLALVps (v4i32 MQPR:$val1), (v4i32 MQPR:$val2), (v4i1 VCCR:$pred)),
|
||||
(MVE_VMLALDAVs32 (v4i32 MQPR:$val1), (v4i32 MQPR:$val2), ARMVCCThen, $pred)>;
|
||||
def : Pat<(ARMVMLALVpu (v4i32 MQPR:$val1), (v4i32 MQPR:$val2), (v4i1 VCCR:$pred)),
|
||||
(MVE_VMLALDAVu32 (v4i32 MQPR:$val1), (v4i32 MQPR:$val2), ARMVCCThen, $pred)>;
|
||||
def : Pat<(ARMVMLALVps (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), (v8i1 VCCR:$pred)),
|
||||
(MVE_VMLALDAVs16 (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), ARMVCCThen, $pred)>;
|
||||
def : Pat<(ARMVMLALVpu (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), (v8i1 VCCR:$pred)),
|
||||
(MVE_VMLALDAVu16 (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), ARMVCCThen, $pred)>;
|
||||
|
||||
def : Pat<(ARMVMLALVAps tGPREven:$Rda, tGPROdd:$Rdb, (v4i32 MQPR:$val1), (v4i32 MQPR:$val2), (v4i1 VCCR:$pred)),
|
||||
(MVE_VMLALDAVas32 tGPREven:$Rda, tGPROdd:$Rdb, (v4i32 MQPR:$val1), (v4i32 MQPR:$val2), ARMVCCThen, $pred)>;
|
||||
def : Pat<(ARMVMLALVApu tGPREven:$Rda, tGPROdd:$Rdb, (v4i32 MQPR:$val1), (v4i32 MQPR:$val2), (v4i1 VCCR:$pred)),
|
||||
(MVE_VMLALDAVau32 tGPREven:$Rda, tGPROdd:$Rdb, (v4i32 MQPR:$val1), (v4i32 MQPR:$val2), ARMVCCThen, $pred)>;
|
||||
def : Pat<(ARMVMLALVAps tGPREven:$Rda, tGPROdd:$Rdb, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), (v8i1 VCCR:$pred)),
|
||||
(MVE_VMLALDAVas16 tGPREven:$Rda, tGPROdd:$Rdb, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), ARMVCCThen, $pred)>;
|
||||
def : Pat<(ARMVMLALVApu tGPREven:$Rda, tGPROdd:$Rdb, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), (v8i1 VCCR:$pred)),
|
||||
(MVE_VMLALDAVau16 tGPREven:$Rda, tGPROdd:$Rdb, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), ARMVCCThen, $pred)>;
|
||||
}
|
||||
|
||||
// vmlalv aliases vmlaldav
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user