diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 226b7f9b9df..0387cc2fa5d 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -40276,6 +40276,47 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  // Fold a range-guarded variable vector shift (shl, lshr) into one AVX2+
+  // variable shift. VSHLV/VSRLV already produce zero for any lane whose shift
+  // amount is >= the element width, so the guarding select is redundant.
+  // VT must be simple before getSimpleVT() may be called on it.
+  if (N->getOpcode() == ISD::VSELECT && Cond.getOpcode() == ISD::SETCC &&
+      VT.isSimple() &&
+      SupportedVectorVarShift(VT.getSimpleVT(), Subtarget, ISD::SHL)) {
+    ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
+
+    // Check if one of the arms of the VSELECT is a zero vector. If it is the
+    // true arm (LHS), invert the predicate so that the zero vector is always
+    // treated as the false arm by the logic below.
+    SDValue Other;
+    if (ISD::isBuildVectorAllZeros(LHS.getNode())) {
+      Other = RHS;
+      CC = ISD::getSetCCInverse(CC, VT.getVectorElementType());
+    } else if (ISD::isBuildVectorAllZeros(RHS.getNode())) {
+      Other = LHS;
+    }
+
+    // Look for the following patterns, shown for 32-bit elements (the same
+    // holds for >>, which becomes vsrlv):
+    //   y <  32 ? x << y : 0 --> vshlv(x, y)
+    //   y <= 31 ? x << y : 0 --> vshlv(x, y)
+    // The shift amount must be the very value compared by the SETCC.
+    APInt CondRHS;
+    if (Other && Other.getNumOperands() == 2 &&
+        DAG.isEqualTo(Other.getOperand(1), Cond.getOperand(0)) &&
+        (Other.getOpcode() == ISD::SHL || Other.getOpcode() == ISD::SRL) &&
+        ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), CondRHS)) {
+      // Replace ISD::SHL or ISD::SRL with the matching X86 variable shift.
+      unsigned Opc =
+          Other.getOpcode() == ISD::SHL ? X86ISD::VSHLV : X86ISD::VSRLV;
+      // Only unsigned bounds equal to the element width are a safe match.
+      if ((CC == ISD::SETULT && CondRHS == VT.getScalarSizeInBits()) ||
+          (CC == ISD::SETULE && CondRHS == VT.getScalarSizeInBits() - 1))
+        return DAG.getNode(Opc, DL, VT, Other.getOperand(0),
+                           Other.getOperand(1));
+    }
+  }
+
   // Match VSELECTs into subs with unsigned saturation.
   if (N->getOpcode() == ISD::VSELECT && Cond.getOpcode() == ISD::SETCC &&
       // psubus is available in SSE2 for i8 and i16 vectors.