From 7fc239ff4e4e4c678be41c54cbb3beb8f9260781 Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Mon, 15 Jun 2020 15:54:12 +0100 Subject: [PATCH] [X86][SSE] Add LowerVectorAllZero helper for checking if all bits of a vector are zero. Pull the lowering code out of LowerVectorAllZeroTest (and rename it MatchVectorAllZeroTest). We should be able to reuse this in combineVectorSizedSetCCEquality as well. Another cleanup to simplify D81547. --- lib/Target/X86/X86ISelLowering.cpp | 99 +++++++++++++++++------------- 1 file changed, 57 insertions(+), 42 deletions(-) diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index b80c94661d7..c8576473180 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -21346,55 +21346,26 @@ static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp, return true; } -// Check whether an OR'd tree is PTEST-able, or if we can fallback to -// CMP(MOVMSK(PCMPEQB(X,0))). -static SDValue LowerVectorAllZeroTest(SDValue Op, ISD::CondCode CC, - const X86Subtarget &Subtarget, - SelectionDAG &DAG, SDValue &X86CC) { - assert(Op.getOpcode() == ISD::OR && "Only check OR'd tree."); - - if (!Subtarget.hasSSE2() || !Op->hasOneUse()) - return SDValue(); - - SmallVector<SDValue, 8> VecIns; - if (!matchScalarReduction(Op, ISD::OR, VecIns)) - return SDValue(); - - assert(llvm::all_of(VecIns, - [VecIns](SDValue V) { - return VecIns[0].getValueType() == V.getValueType(); - }) && - "Reduction source vector mismatch"); +// Helper function for comparing all bits of a vector against zero. +static SDValue LowerVectorAllZero(const SDLoc &DL, SDValue V, ISD::CondCode CC, + const X86Subtarget &Subtarget, + SelectionDAG &DAG, X86::CondCode &X86CC) { + EVT VT = V.getValueType(); // Quit if less than 128-bits or not splittable to 128/256-bit vector.
- EVT VT = VecIns[0].getValueType(); if (VT.getSizeInBits() < 128 || !isPowerOf2_32(VT.getSizeInBits())) return SDValue(); - SDLoc DL(Op); - - // If more than one full vector is evaluated, OR them first before PTEST. - for (unsigned Slot = 0, e = VecIns.size(); e - Slot > 1; Slot += 2, e += 1) { - // Each iteration will OR 2 nodes and append the result until there is only - // 1 node left, i.e. the final OR'd value of all vectors. - SDValue LHS = VecIns[Slot]; - SDValue RHS = VecIns[Slot + 1]; - VecIns.push_back(DAG.getNode(ISD::OR, DL, VT, LHS, RHS)); - } - - SDValue V = VecIns.back(); + X86CC = (CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE); // Split down to 128/256-bit vector. - unsigned TestSize = Subtarget.hasAVX()? 256 : 128; + unsigned TestSize = Subtarget.hasAVX() ? 256 : 128; while (VT.getSizeInBits() > TestSize) { auto Split = DAG.SplitVector(V, DL); VT = Split.first.getValueType(); V = DAG.getNode(ISD::OR, DL, VT, Split.first, Split.second); } - X86CC = DAG.getTargetConstant(CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE, - DL, MVT::i8); - bool UsePTEST = Subtarget.hasSSE41(); if (UsePTEST) { MVT TestVT = VT.is128BitVector() ? MVT::v2i64 : MVT::v4i64; @@ -21402,14 +21373,58 @@ static SDValue LowerVectorAllZeroTest(SDValue Op, ISD::CondCode CC, return DAG.getNode(X86ISD::PTEST, DL, MVT::i32, V, V); } - SDValue Result = DAG.getNode(X86ISD::PCMPEQ, DL, MVT::v16i8, - DAG.getBitcast(MVT::v16i8, V), - getZeroVector(MVT::v16i8, Subtarget, DAG, DL)); - Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result); - return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result, + V = DAG.getBitcast(MVT::v16i8, V); + V = DAG.getNode(X86ISD::PCMPEQ, DL, MVT::v16i8, V, + getZeroVector(MVT::v16i8, Subtarget, DAG, DL)); + V = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, V); + return DAG.getNode(X86ISD::CMP, DL, MVT::i32, V, DAG.getConstant(0xFFFF, DL, MVT::i32)); } +// Check whether an OR'd tree is PTEST-able, or if we can fallback to +// CMP(MOVMSK(PCMPEQB(X,0))).
+static SDValue MatchVectorAllZeroTest(SDValue Op, ISD::CondCode CC, + const X86Subtarget &Subtarget, + SelectionDAG &DAG, SDValue &X86CC) { + assert(Op.getOpcode() == ISD::OR && "Only check OR'd tree."); + + if (!Subtarget.hasSSE2() || !Op->hasOneUse()) + return SDValue(); + + SmallVector<SDValue, 8> VecIns; + if (matchScalarReduction(Op, ISD::OR, VecIns)) { + EVT VT = VecIns[0].getValueType(); + assert(llvm::all_of(VecIns, + [VT](SDValue V) { return VT == V.getValueType(); }) && + "Reduction source vector mismatch"); + + // Quit if less than 128-bits or not splittable to 128/256-bit vector. + if (VT.getSizeInBits() < 128 || !isPowerOf2_32(VT.getSizeInBits())) + return SDValue(); + + SDLoc DL(Op); + + // If more than one full vector is evaluated, OR them first before PTEST. + for (unsigned Slot = 0, e = VecIns.size(); e - Slot > 1; + Slot += 2, e += 1) { + // Each iteration will OR 2 nodes and append the result until there is + // only 1 node left, i.e. the final OR'd value of all vectors. + SDValue LHS = VecIns[Slot]; + SDValue RHS = VecIns[Slot + 1]; + VecIns.push_back(DAG.getNode(ISD::OR, DL, VT, LHS, RHS)); + } + + X86::CondCode CCode; + if (SDValue V = + LowerVectorAllZero(DL, VecIns.back(), CC, Subtarget, DAG, CCode)) { + X86CC = DAG.getTargetConstant(CCode, DL, MVT::i8); + return V; + } + } + + return SDValue(); +} + /// return true if \c Op has a use that doesn't just read flags. static bool hasNonFlagsUse(SDValue Op) { for (SDNode::use_iterator UI = Op->use_begin(), UE = Op->use_end(); UI != UE; @@ -22559,7 +22574,7 @@ SDValue X86TargetLowering::emitFlagsForSetcc(SDValue Op0, SDValue Op1, // TODO: We could do AND tree with all 1s as well by using the C flag. if (Op0.getOpcode() == ISD::OR && isNullConstant(Op1) && (CC == ISD::SETEQ || CC == ISD::SETNE)) { - if (SDValue CmpZ = LowerVectorAllZeroTest(Op0, CC, Subtarget, DAG, X86CC)) + if (SDValue CmpZ = MatchVectorAllZeroTest(Op0, CC, Subtarget, DAG, X86CC)) return CmpZ; }