diff --git a/include/llvm/CodeGen/SelectionDAGNodes.h b/include/llvm/CodeGen/SelectionDAGNodes.h
index 5df40671b52..0943fa1aff4 100644
--- a/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1584,10 +1584,10 @@ public:
                        bool isBigEndian = false) const;
 
   /// getConstantSplatValue - Check if this is a constant splat, and if so,
-  /// return the splat value only if it is a ConstantSDNode. Otherwise
-  /// return nullptr. This is a simpler form of isConstantSplat.
-  /// Get the constant splat only if you care about the splat value.
-  ConstantSDNode *getConstantSplatValue() const;
+  /// return the splatted value. Otherwise return a null SDValue. This is
+  /// a simpler form of isConstantSplat. Get the constant splat only if you
+  /// care about the splat value.
+  SDValue getConstantSplatValue() const;
 
   bool isConstant() const;
 
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 7198203036b..9a91dcc341b 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -654,13 +654,12 @@ static ConstantSDNode *isConstOrConstSplat(SDValue N) {
   if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N))
     return CN;
 
-  if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) {
-    ConstantSDNode *CN = BV->getConstantSplatValue();
-
-    // BuildVectors can truncate their operands. Ignore that case here.
-    if (CN && CN->getValueType(0) == N.getValueType().getScalarType())
-      return CN;
-  }
+  if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N))
+    if (SDValue Splat = BV->getConstantSplatValue())
+      if (auto *CN = dyn_cast<ConstantSDNode>(Splat))
+        // BuildVectors can truncate their operands. Ignore that case here.
+        if (CN->getValueType(0) == N.getValueType().getScalarType())
+          return CN;
 
   return nullptr;
 }
diff --git a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3a8a5f9601f..fb7d1b18d8b 100644
--- a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6603,16 +6603,28 @@ bool BuildVectorSDNode::isConstantSplat(APInt &SplatValue,
   return true;
 }
 
-ConstantSDNode *BuildVectorSDNode::getConstantSplatValue() const {
-  SDValue Op0 = getOperand(0);
-  if (Op0.getOpcode() != ISD::Constant)
-    return nullptr;
+SDValue BuildVectorSDNode::getConstantSplatValue() const {
+  SDValue Splatted;
+  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
+    SDValue Op = getOperand(i);
+    if (Op.getOpcode() == ISD::UNDEF)
+      continue;
+    if (Op.getOpcode() != ISD::Constant && Op.getOpcode() != ISD::ConstantFP)
+      return SDValue();
 
-  for (unsigned i = 1, e = getNumOperands(); i != e; ++i)
-    if (getOperand(i) != Op0)
-      return nullptr;
+    if (!Splatted)
+      Splatted = Op;
+    else if (Splatted != Op)
+      return SDValue();
+  }
 
-  return cast<ConstantSDNode>(Op0);
+  if (!Splatted) {
+    assert(getOperand(0).getOpcode() == ISD::UNDEF &&
+           "Can only have a splat without a constant for all undefs.");
+    return getOperand(0);
+  }
+
+  return Splatted;
 }
 
 bool BuildVectorSDNode::isConstant() const {
diff --git a/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index ad91d4a87c1..1b3e42848bf 100644
--- a/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1152,14 +1152,15 @@ bool TargetLowering::isConstTrueVal(const SDNode *N) const {
 
   bool IsVec = false;
   const ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N);
-  if (!CN) {
-    const BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N);
-    if (!BV)
-      return false;
-
-    IsVec = true;
-    CN = BV->getConstantSplatValue();
-  }
+  if (!CN)
+    if (auto *BV = dyn_cast<BuildVectorSDNode>(N))
+      if (SDValue Splat = BV->getConstantSplatValue())
+        if (auto *SplatCN = dyn_cast<ConstantSDNode>(Splat)) {
+          IsVec = true;
+          CN = SplatCN;
+        }
+  if (!CN)
+    return false;
 
   switch (getBooleanContents(IsVec)) {
   case UndefinedBooleanContent:
@@ -1179,14 +1180,15 @@ bool TargetLowering::isConstFalseVal(const SDNode *N) const {
 
   bool IsVec = false;
   const ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N);
-  if (!CN) {
-    const BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N);
-    if (!BV)
-      return false;
-
-    IsVec = true;
-    CN = BV->getConstantSplatValue();
-  }
+  if (!CN)
+    if (auto *BV = dyn_cast<BuildVectorSDNode>(N))
+      if (SDValue Splat = BV->getConstantSplatValue())
+        if (auto *SplatCN = dyn_cast<ConstantSDNode>(Splat)) {
+          IsVec = true;
+          CN = SplatCN;
+        }
+  if (!CN)
+    return false;
 
   if (getBooleanContents(IsVec) == UndefinedBooleanContent)
     return !CN->getAPIntValue()[0];
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 7488cece603..b372950d436 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -4858,19 +4858,6 @@ static bool ShouldXformToMOVLP(SDNode *V1, SDNode *V2,
   return true;
 }
 
-/// isSplatVector - Returns true if N is a BUILD_VECTOR node whose elements are
-/// all the same.
-static bool isSplatVector(SDNode *N) {
-  if (N->getOpcode() != ISD::BUILD_VECTOR)
-    return false;
-
-  SDValue SplatValue = N->getOperand(0);
-  for (unsigned i = 1, e = N->getNumOperands(); i != e; ++i)
-    if (N->getOperand(i) != SplatValue)
-      return false;
-  return true;
-}
-
 /// isZeroShuffle - Returns true if N is a VECTOR_SHUFFLE that can be resolved
 /// to an zero vector.
 /// FIXME: move to dag combiner / method on ShuffleVectorSDNode
@@ -5779,17 +5766,20 @@ static SDValue LowerVectorBroadcast(SDValue Op, const X86Subtarget* Subtarget,
       return SDValue();
 
     case ISD::BUILD_VECTOR: {
+      auto *BVOp = cast<BuildVectorSDNode>(Op.getNode());
       // The BUILD_VECTOR node must be a splat.
-      if (!isSplatVector(Op.getNode()))
+      SDValue Splat = BVOp->getConstantSplatValue();
+      if (!Splat)
        return SDValue();
 
-      Ld = Op.getOperand(0);
+      Ld = Splat;
       ConstSplatVal = (Ld.getOpcode() == ISD::Constant ||
                        Ld.getOpcode() == ISD::ConstantFP);
      // The suspected load node has several users. Make sure that all
      // of its users are from the BUILD_VECTOR node.
      // Constants may have multiple users.
+      // FIXME: This doesn't make sense if the build vector contains undefs.
      if (!ConstSplatVal && !Ld->hasNUsesOfValue(VT.getVectorNumElements(), 0))
        return SDValue();
      break;
    }
@@ -9375,8 +9365,12 @@ X86TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const {
   bool Commuted = false;
   // FIXME: This should also accept a bitcast of a splat? Be careful, not
   // 1,1,1,1 -> v8i16 though.
-  V1IsSplat = isSplatVector(V1.getNode());
-  V2IsSplat = isSplatVector(V2.getNode());
+  if (auto *BVOp = dyn_cast<BuildVectorSDNode>(V1.getNode()))
+    if (BVOp->getConstantSplatValue())
+      V1IsSplat = true;
+  if (auto *BVOp = dyn_cast<BuildVectorSDNode>(V2.getNode()))
+    if (BVOp->getConstantSplatValue())
+      V2IsSplat = true;
 
   // Canonicalize the splat or undef, if present, to be on the RHS.
   if (!V2IsUndef && V1IsSplat && !V2IsSplat) {
@@ -15171,10 +15165,11 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
   SDValue Amt = Op.getOperand(1);
 
   // Optimize shl/srl/sra with constant shift amount.
-  if (isSplatVector(Amt.getNode())) {
-    SDValue SclrAmt = Amt->getOperand(0);
-    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(SclrAmt)) {
-      uint64_t ShiftAmt = C->getZExtValue();
+  if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) {
+    if (SDValue Splat = BVAmt->getConstantSplatValue()) {
+      uint64_t ShiftAmt = Splat.getOpcode() == ISD::UNDEF
+                              ? 0
+                              : cast<ConstantSDNode>(Splat)->getZExtValue();
 
       if (VT == MVT::v2i64 || VT == MVT::v4i32 || VT == MVT::v8i16 ||
           (Subtarget->hasInt256() &&
@@ -19423,27 +19418,35 @@ static SDValue PerformSELECTCombine(SDNode *N, SelectionDAG &DAG,
           Other->getOpcode() == ISD::SUB && DAG.isEqualTo(OpRHS, CondRHS))
         return DAG.getNode(X86ISD::SUBUS, DL, VT, OpLHS, OpRHS);
 
-      // If the RHS is a constant we have to reverse the const canonicalization.
-      // x > C-1 ? x+-C : 0 --> subus x, C
-      if (CC == ISD::SETUGT && Other->getOpcode() == ISD::ADD &&
-          isSplatVector(CondRHS.getNode()) && isSplatVector(OpRHS.getNode())) {
-        APInt A = cast<ConstantSDNode>(OpRHS.getOperand(0))->getAPIntValue();
-        if (CondRHS.getConstantOperandVal(0) == -A-1)
-          return DAG.getNode(X86ISD::SUBUS, DL, VT, OpLHS,
-                             DAG.getConstant(-A, VT));
-      }
+      if (auto *OpRHSBV = dyn_cast<BuildVectorSDNode>(OpRHS)) {
+        SDValue OpRHSSplat = OpRHSBV->getConstantSplatValue();
+        auto *OpRHSSplatConst = dyn_cast<ConstantSDNode>(OpRHSSplat);
+        if (auto *CondRHSBV = dyn_cast<BuildVectorSDNode>(CondRHS)) {
+          // If the RHS is a constant we have to reverse the const
+          // canonicalization.
+          // x > C-1 ? x+-C : 0 --> subus x, C
+          SDValue CondRHSSplat = CondRHSBV->getConstantSplatValue();
+          auto *CondRHSSplatConst = dyn_cast<ConstantSDNode>(CondRHSSplat);
+          if (CC == ISD::SETUGT && Other->getOpcode() == ISD::ADD &&
+              CondRHSSplatConst && OpRHSSplatConst) {
+            APInt A = OpRHSSplatConst->getAPIntValue();
+            if (CondRHSSplatConst->getAPIntValue() == -A - 1)
+              return DAG.getNode(X86ISD::SUBUS, DL, VT, OpLHS,
+                                 DAG.getConstant(-A, VT));
+          }
+        }
 
-      // Another special case: If C was a sign bit, the sub has been
-      // canonicalized into a xor.
-      // FIXME: Would it be better to use computeKnownBits to determine whether
-      // it's safe to decanonicalize the xor?
-      // x s< 0 ? x^C : 0 --> subus x, C
-      if (CC == ISD::SETLT && Other->getOpcode() == ISD::XOR &&
-          ISD::isBuildVectorAllZeros(CondRHS.getNode()) &&
-          isSplatVector(OpRHS.getNode())) {
-        APInt A = cast<ConstantSDNode>(OpRHS.getOperand(0))->getAPIntValue();
-        if (A.isSignBit())
-          return DAG.getNode(X86ISD::SUBUS, DL, VT, OpLHS, OpRHS);
+        // Another special case: If C was a sign bit, the sub has been
+        // canonicalized into a xor.
+        // FIXME: Would it be better to use computeKnownBits to determine
+        // whether it's safe to decanonicalize the xor?
+        // x s< 0 ? x^C : 0 --> subus x, C
+        if (CC == ISD::SETLT && Other->getOpcode() == ISD::XOR &&
+            ISD::isBuildVectorAllZeros(CondRHS.getNode()) && OpRHSSplatConst) {
+          APInt A = OpRHSSplatConst->getAPIntValue();
+          if (A.isSignBit())
+            return DAG.getNode(X86ISD::SUBUS, DL, VT, OpLHS, OpRHS);
+        }
       }
     }
   }
@@ -20152,16 +20155,16 @@ static SDValue PerformSHLCombine(SDNode *N, SelectionDAG &DAG) {
   // vector operations in many cases. Also, on sandybridge ADD is faster than
   // shl.
   // (shl V, 1) -> add V,V
-  if (isSplatVector(N1.getNode())) {
-    assert(N0.getValueType().isVector() && "Invalid vector shift type");
-    ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1->getOperand(0));
-    // We shift all of the values by one. In many cases we do not have
-    // hardware support for this operation. This is better expressed as an ADD
-    // of two values.
-    if (N1C && (1 == N1C->getZExtValue())) {
-      return DAG.getNode(ISD::ADD, SDLoc(N), VT, N0, N0);
+  if (auto *N1BV = dyn_cast<BuildVectorSDNode>(N1))
+    if (SDValue N1Splat = N1BV->getConstantSplatValue()) {
+      assert(N0.getValueType().isVector() && "Invalid vector shift type");
+      // We shift all of the values by one. In many cases we do not have
+      // hardware support for this operation. This is better expressed as an ADD
+      // of two values.
+      if (N1Splat.getOpcode() == ISD::Constant &&
+          cast<ConstantSDNode>(N1Splat)->getZExtValue() == 1)
+        return DAG.getNode(ISD::ADD, SDLoc(N), VT, N0, N0);
     }
-  }
 
   return SDValue();
 }
@@ -20180,20 +20183,19 @@ static SDValue performShiftToAllZeros(SDNode *N, SelectionDAG &DAG,
   SDValue Amt = N->getOperand(1);
   SDLoc DL(N);
 
-  if (isSplatVector(Amt.getNode())) {
-    SDValue SclrAmt = Amt->getOperand(0);
-    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(SclrAmt)) {
-      APInt ShiftAmt = C->getAPIntValue();
-      unsigned MaxAmount = VT.getVectorElementType().getSizeInBits();
+  if (auto *AmtBV = dyn_cast<BuildVectorSDNode>(Amt))
+    if (SDValue AmtSplat = AmtBV->getConstantSplatValue())
+      if (auto *AmtConst = dyn_cast<ConstantSDNode>(AmtSplat)) {
+        APInt ShiftAmt = AmtConst->getAPIntValue();
+        unsigned MaxAmount = VT.getVectorElementType().getSizeInBits();
 
-      // SSE2/AVX2 logical shifts always return a vector of 0s
-      // if the shift amount is bigger than or equal to
-      // the element size. The constant shift amount will be
-      // encoded as a 8-bit immediate.
-      if (ShiftAmt.trunc(8).uge(MaxAmount))
-        return getZeroVector(VT, Subtarget, DAG, DL);
-    }
-  }
+        // SSE2/AVX2 logical shifts always return a vector of 0s
+        // if the shift amount is bigger than or equal to
+        // the element size. The constant shift amount will be
+        // encoded as a 8-bit immediate.
+        if (ShiftAmt.trunc(8).uge(MaxAmount))
+          return getZeroVector(VT, Subtarget, DAG, DL);
+      }
 
   return SDValue();
 }
@@ -20387,9 +20389,10 @@ static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG,
 
   // The right side has to be a 'trunc' or a constant vector.
   bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE;
-  bool RHSConst = (isSplatVector(N1.getNode()) &&
-                   isa<ConstantSDNode>(N1->getOperand(0)));
-  if (!RHSTrunc && !RHSConst)
+  SDValue RHSConstSplat;
+  if (auto *RHSBV = dyn_cast<BuildVectorSDNode>(N1))
+    RHSConstSplat = RHSBV->getConstantSplatValue();
+  if (!RHSTrunc && !RHSConstSplat)
     return SDValue();
 
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -20399,9 +20402,9 @@ static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG,
 
   // Set N0 and N1 to hold the inputs to the new wide operation.
   N0 = N0->getOperand(0);
-  if (RHSConst) {
+  if (RHSConstSplat) {
     N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT.getScalarType(),
-                     N1->getOperand(0));
+                     RHSConstSplat);
     SmallVector<SDValue, 8> C(WideVT.getVectorNumElements(), N1);
     N1 = DAG.getNode(ISD::BUILD_VECTOR, DL, WideVT, C);
   } else if (RHSTrunc) {
@@ -20547,12 +20550,10 @@ static SDValue PerformOrCombine(SDNode *N, SelectionDAG &DAG,
     unsigned EltBits = MaskVT.getVectorElementType().getSizeInBits();
     unsigned SraAmt = ~0;
     if (Mask.getOpcode() == ISD::SRA) {
-      SDValue Amt = Mask.getOperand(1);
-      if (isSplatVector(Amt.getNode())) {
-        SDValue SclrAmt = Amt->getOperand(0);
-        if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(SclrAmt))
-          SraAmt = C->getZExtValue();
-      }
+      if (auto *AmtBV = dyn_cast<BuildVectorSDNode>(Mask.getOperand(1)))
+        if (SDValue AmtSplat = AmtBV->getConstantSplatValue())
+          if (auto *AmtConst = dyn_cast<ConstantSDNode>(AmtSplat))
+            SraAmt = AmtConst->getZExtValue();
     } else if (Mask.getOpcode() == X86ISD::VSRAI) {
       SDValue SraC = Mask.getOperand(1);
       SraAmt = cast<ConstantSDNode>(SraC)->getZExtValue();
diff --git a/test/CodeGen/X86/vector-gep.ll b/test/CodeGen/X86/vector-gep.ll
index 9c68f44dffb..61edb1e7a15 100644
--- a/test/CodeGen/X86/vector-gep.ll
+++ b/test/CodeGen/X86/vector-gep.ll
@@ -5,7 +5,7 @@
 define <4 x i32*> @AGEP0(i32* %ptr) nounwind {
 entry:
 ;CHECK-LABEL: AGEP0
-;CHECK: vbroadcast
+;CHECK: vpshufd {{.*}} # xmm0 = mem[0,0,0,0]
 ;CHECK-NEXT: vpaddd
 ;CHECK-NEXT: ret
   %vecinit.i = insertelement <4 x i32*> undef, i32* %ptr, i32 0
diff --git a/test/CodeGen/X86/widen_cast-4.ll b/test/CodeGen/X86/widen_cast-4.ll
index 1bc06a77cbf..03dffc4842c 100644
--- a/test/CodeGen/X86/widen_cast-4.ll
+++ b/test/CodeGen/X86/widen_cast-4.ll
@@ -1,8 +1,9 @@
 ; RUN: llc < %s -march=x86 -mattr=+sse4.2 | FileCheck %s
-; CHECK: psraw
-; CHECK: psraw
+; RUN: llc < %s -march=x86 -mattr=+sse4.2 -x86-experimental-vector-widening-legalization | FileCheck %s --check-prefix=CHECK-WIDE
 
 define void @update(i64* %dst_i, i64* %src_i, i32 %n) nounwind {
+; CHECK-LABEL: update:
+; CHECK-WIDE-LABEL: update:
 entry:
   %dst_i.addr = alloca i64*               ; <i64**> [#uses=2]
   %src_i.addr = alloca i64*               ; <i64**> [#uses=2]
@@ -23,6 +24,8 @@ forcond:                ; preds = %forinc, %entry
   br i1 %cmp, label %forbody, label %afterfor
 
 forbody:                ; preds = %forcond
+; CHECK: %forbody
+; CHECK-WIDE: %forbody
   %tmp2 = load i32* %i            ; <i32> [#uses=1]
   %tmp3 = load i64** %dst_i.addr          ; <i64*> [#uses=1]
   %arrayidx = getelementptr i64* %tmp3, i32 %tmp2         ; <i64*> [#uses=1]
@@ -44,6 +47,24 @@ forbody:                ; preds = %forcond
   %shr = ashr <8 x i8> %add, < i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2 >         ; <<8 x i8>> [#uses=1]
   store <8 x i8> %shr, <8 x i8>* %arrayidx10
   br label %forinc
+; CHECK: pmovzxbw
+; CHECK-NEXT: paddw
+; CHECK-NEXT: psllw $8
+; CHECK-NEXT: psraw $8
+; CHECK-NEXT: psraw $2
+; CHECK-NEXT: pshufb
+; CHECK-NEXT: movlpd
+;
+; FIXME: We shouldn't require both a movd and an insert.
+; CHECK-WIDE: movd
+; CHECK-WIDE-NEXT: pinsrd
+; CHECK-WIDE-NEXT: paddb
+; CHECK-WIDE-NEXT: psrlw $2
+; CHECK-WIDE-NEXT: pand
+; CHECK-WIDE-NEXT: pxor
+; CHECK-WIDE-NEXT: psubb
+; CHECK-WIDE-NEXT: pextrd
+; CHECK-WIDE-NEXT: movd
 
 forinc:                ; preds = %forbody
   %tmp15 = load i32* %i           ; <i32> [#uses=1]