1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-23 11:13:28 +01:00

DAGCombiner: Turn divs of vector splats into vectorized multiplications.

Otherwise the legalizer would just scalarize everything. Support for
mulhi in the targets isn't that great yet so on most targets we get
exactly the same scalarized output. Add a test for x86 vector udiv.

I had to disable the mulhi nodes on ARM because there aren't any patterns
for it. As far as I know ARM has instructions for getting the high part of
a multiply so this should be fixed.

llvm-svn: 207315
This commit is contained in:
Benjamin Kramer 2014-04-26 12:06:28 +00:00
parent 46f8aa6183
commit 163df6bc62
6 changed files with 115 additions and 28 deletions

View File

@ -2417,10 +2417,12 @@ public:
//
SDValue BuildExactSDIV(SDValue Op1, SDValue Op2, SDLoc dl,
SelectionDAG &DAG) const;
SDValue BuildSDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
std::vector<SDNode*> *Created) const;
SDValue BuildUDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
std::vector<SDNode*> *Created) const;
SDValue BuildSDIV(SDNode *N, const APInt &Divisor, SelectionDAG &DAG,
bool IsAfterLegalization,
std::vector<SDNode *> *Created) const;
SDValue BuildUDIV(SDNode *N, const APInt &Divisor, SelectionDAG &DAG,
bool IsAfterLegalization,
std::vector<SDNode *> *Created) const;
//===--------------------------------------------------------------------===//
// Legalization utility functions

View File

@ -2024,7 +2024,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
// if integer divide is expensive and we satisfy the requirements, emit an
// alternate sequence.
if (N1C && !N1C->isNullValue() && !TLI.isIntDivCheap()) {
if ((N1C || N1->getOpcode() == ISD::BUILD_VECTOR) && !TLI.isIntDivCheap()) {
SDValue Op = BuildSDIV(N);
if (Op.getNode()) return Op;
}
@ -2076,7 +2076,7 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
}
}
// fold (udiv x, c) -> alternate
if (N1C && !N1C->isNullValue() && !TLI.isIntDivCheap()) {
if ((N1C || N1->getOpcode() == ISD::BUILD_VECTOR) && !TLI.isIntDivCheap()) {
SDValue Op = BuildUDIV(N);
if (Op.getNode()) return Op;
}
@ -11191,8 +11191,24 @@ SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0,
/// multiplying by a magic number. See:
/// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html>
SDValue DAGCombiner::BuildSDIV(SDNode *N) {
const APInt *Divisor;
if (N->getValueType(0).isVector()) {
// Handle splat vectors.
BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N->getOperand(1));
if (ConstantSDNode *C = BV->getConstantSplatValue())
Divisor = &C->getAPIntValue();
else
return SDValue();
} else {
Divisor = &cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue();
}
// Avoid division by zero.
if (!*Divisor)
return SDValue();
std::vector<SDNode*> Built;
SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, &Built);
SDValue S = TLI.BuildSDIV(N, *Divisor, DAG, LegalOperations, &Built);
for (std::vector<SDNode*>::iterator ii = Built.begin(), ee = Built.end();
ii != ee; ++ii)
@ -11200,13 +11216,29 @@ SDValue DAGCombiner::BuildSDIV(SDNode *N) {
return S;
}
/// BuildUDIVSequence - Given an ISD::UDIV node expressing a divide by constant,
/// BuildUDIV - Given an ISD::UDIV node expressing a divide by constant,
/// return a DAG expression to select that will generate the same value by
/// multiplying by a magic number. See:
/// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html>
SDValue DAGCombiner::BuildUDIV(SDNode *N) {
const APInt *Divisor;
if (N->getValueType(0).isVector()) {
// Handle splat vectors.
BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N->getOperand(1));
if (ConstantSDNode *C = BV->getConstantSplatValue())
Divisor = &C->getAPIntValue();
else
return SDValue();
} else {
Divisor = &cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue();
}
// Avoid division by zero.
if (!*Divisor)
return SDValue();
std::vector<SDNode*> Built;
SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, &Built);
SDValue S = TLI.BuildUDIV(N, *Divisor, DAG, LegalOperations, &Built);
for (std::vector<SDNode*>::iterator ii = Built.begin(), ee = Built.end();
ii != ee; ++ii)

View File

@ -2602,9 +2602,9 @@ SDValue TargetLowering::BuildExactSDIV(SDValue Op1, SDValue Op2, SDLoc dl,
/// return a DAG expression to select that will generate the same value by
/// multiplying by a magic number. See:
/// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html>
SDValue TargetLowering::
BuildSDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
std::vector<SDNode*> *Created) const {
SDValue TargetLowering::BuildSDIV(SDNode *N, const APInt &Divisor,
SelectionDAG &DAG, bool IsAfterLegalization,
std::vector<SDNode *> *Created) const {
EVT VT = N->getValueType(0);
SDLoc dl(N);
@ -2613,8 +2613,7 @@ BuildSDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
if (!isTypeLegal(VT))
return SDValue();
APInt d = cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue();
APInt::ms magics = d.magic();
APInt::ms magics = Divisor.magic();
// Multiply the numerator (operand 0) by the magic value
// FIXME: We should support doing a MUL in a wider type
@ -2631,13 +2630,13 @@ BuildSDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
else
return SDValue(); // No mulhs or equvialent
// If d > 0 and m < 0, add the numerator
if (d.isStrictlyPositive() && magics.m.isNegative()) {
if (Divisor.isStrictlyPositive() && magics.m.isNegative()) {
Q = DAG.getNode(ISD::ADD, dl, VT, Q, N->getOperand(0));
if (Created)
Created->push_back(Q.getNode());
}
// If d < 0 and m > 0, subtract the numerator.
if (d.isNegative() && magics.m.isStrictlyPositive()) {
if (Divisor.isNegative() && magics.m.isStrictlyPositive()) {
Q = DAG.getNode(ISD::SUB, dl, VT, Q, N->getOperand(0));
if (Created)
Created->push_back(Q.getNode());
@ -2650,9 +2649,9 @@ BuildSDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
Created->push_back(Q.getNode());
}
// Extract the sign bit and add it to the quotient
SDValue T =
DAG.getNode(ISD::SRL, dl, VT, Q, DAG.getConstant(VT.getSizeInBits()-1,
getShiftAmountTy(Q.getValueType())));
SDValue T = DAG.getNode(ISD::SRL, dl, VT, Q,
DAG.getConstant(VT.getScalarSizeInBits() - 1,
getShiftAmountTy(Q.getValueType())));
if (Created)
Created->push_back(T.getNode());
return DAG.getNode(ISD::ADD, dl, VT, Q, T);
@ -2662,9 +2661,9 @@ BuildSDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
/// return a DAG expression to select that will generate the same value by
/// multiplying by a magic number. See:
/// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html>
SDValue TargetLowering::
BuildUDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
std::vector<SDNode*> *Created) const {
SDValue TargetLowering::BuildUDIV(SDNode *N, const APInt &Divisor,
SelectionDAG &DAG, bool IsAfterLegalization,
std::vector<SDNode *> *Created) const {
EVT VT = N->getValueType(0);
SDLoc dl(N);
@ -2675,22 +2674,21 @@ BuildUDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
// FIXME: We should use a narrower constant when the upper
// bits are known to be zero.
const APInt &N1C = cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue();
APInt::mu magics = N1C.magicu();
APInt::mu magics = Divisor.magicu();
SDValue Q = N->getOperand(0);
// If the divisor is even, we can avoid using the expensive fixup by shifting
// the divided value upfront.
if (magics.a != 0 && !N1C[0]) {
unsigned Shift = N1C.countTrailingZeros();
if (magics.a != 0 && !Divisor[0]) {
unsigned Shift = Divisor.countTrailingZeros();
Q = DAG.getNode(ISD::SRL, dl, VT, Q,
DAG.getConstant(Shift, getShiftAmountTy(Q.getValueType())));
if (Created)
Created->push_back(Q.getNode());
// Get magic number for the shifted divisor.
magics = N1C.lshr(Shift).magicu(Shift);
magics = Divisor.lshr(Shift).magicu(Shift);
assert(magics.a == 0 && "Should use cheap fixup now");
}
@ -2709,7 +2707,7 @@ BuildUDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
Created->push_back(Q.getNode());
if (magics.a == 0) {
assert(magics.s < N1C.getBitWidth() &&
assert(magics.s < Divisor.getBitWidth() &&
"We shouldn't generate an undefined shift!");
return DAG.getNode(ISD::SRL, dl, VT, Q,
DAG.getConstant(magics.s, getShiftAmountTy(Q.getValueType())));

View File

@ -446,6 +446,11 @@ ARMTargetLowering::ARMTargetLowering(TargetMachine &TM)
setLoadExtAction(ISD::SEXTLOAD, (MVT::SimpleValueType)VT, Expand);
setLoadExtAction(ISD::ZEXTLOAD, (MVT::SimpleValueType)VT, Expand);
setLoadExtAction(ISD::EXTLOAD, (MVT::SimpleValueType)VT, Expand);
setOperationAction(ISD::MULHS, (MVT::SimpleValueType)VT, Expand);
setOperationAction(ISD::SMUL_LOHI, (MVT::SimpleValueType)VT, Expand);
setOperationAction(ISD::MULHU, (MVT::SimpleValueType)VT, Expand);
setOperationAction(ISD::UMUL_LOHI, (MVT::SimpleValueType)VT, Expand);
}
setOperationAction(ISD::ConstantFP, MVT::f32, Custom);

View File

@ -435,6 +435,11 @@ ARM64TargetLowering::ARM64TargetLowering(ARM64TargetMachine &TM)
setOperationAction(ISD::SIGN_EXTEND_INREG, (MVT::SimpleValueType)VT,
Expand);
setOperationAction(ISD::MULHS, (MVT::SimpleValueType)VT, Expand);
setOperationAction(ISD::SMUL_LOHI, (MVT::SimpleValueType)VT, Expand);
setOperationAction(ISD::MULHU, (MVT::SimpleValueType)VT, Expand);
setOperationAction(ISD::UMUL_LOHI, (MVT::SimpleValueType)VT, Expand);
for (unsigned InnerVT = (unsigned)MVT::FIRST_VECTOR_VALUETYPE;
InnerVT <= (unsigned)MVT::LAST_VECTOR_VALUETYPE; ++InnerVT)
setTruncStoreAction((MVT::SimpleValueType)VT,

View File

@ -0,0 +1,45 @@
; RUN: llc -march=x86-64 -mcpu=core2 < %s | FileCheck %s -check-prefix=SSE
; RUN: llc -march=x86-64 -mcpu=core-avx2 < %s | FileCheck %s -check-prefix=AVX
define <4 x i32> @test1(<4 x i32> %a) {
%div = udiv <4 x i32> %a, <i32 7, i32 7, i32 7, i32 7>
ret <4 x i32> %div
; SSE-LABEL: test1:
; SSE: pmuludq
; SSE: pshufd $57
; SSE: pmuludq
; SSE: shufps $-35
; SSE: psubd
; SSE: psrld $1
; SSE: padd
; SSE: psrld $2
; AVX-LABEL: test1:
; AVX: vpmuludq
; AVX: vpshufd $57
; AVX: vpmuludq
; AVX: vshufps $-35
; AVX: vpsubd
; AVX: vpsrld $1
; AVX: vpadd
; AVX: vpsrld $2
}
define <8 x i32> @test2(<8 x i32> %a) {
%div = udiv <8 x i32> %a, <i32 7, i32 7, i32 7, i32 7,i32 7, i32 7, i32 7, i32 7>
ret <8 x i32> %div
; AVX-LABEL: test2:
; AVX: vpermd
; AVX: vpmuludq
; AVX: vshufps $-35
; AVX: vpmuludq
; AVX: vshufps $-35
; AVX: vpsubd
; AVX: vpsrld $1
; AVX: vpadd
; AVX: vpsrld $2
}
; TODO: sdiv -> pmuldq