diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index 09a32d1c422..b955b9033fa 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -41659,6 +41659,22 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG, return SDValue(); APInt Mask17 = APInt::getHighBitsSet(32, 17); + if (N0.getOpcode() == ISD::SRA && N1.getOpcode() == ISD::SRA) { + // If both arguments are sign-extended, try to replace sign extends + // with zero extends, which should qualify for the optimization. + // Otherwise just fallback to zero-extension check. + if (isa(N0.getOperand(1).getOperand(0)) && + N0.getOperand(1).getConstantOperandVal(0) == 16 && + isa(N1.getOperand(1).getOperand(0)) && + N1.getOperand(1).getConstantOperandVal(0) == 16) { + // Nullify mask to pass the following check + Mask17 = 0; + N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0), + N0.getOperand(1)); + N1 = DAG.getNode(ISD::SRL, N1.getNode(), VT, N1.getOperand(0), + N1.getOperand(1)); + } + } if (!DAG.MaskedValueIsZero(N1, Mask17) || !DAG.MaskedValueIsZero(N0, Mask17)) return SDValue();