1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2025-02-01 05:01:59 +01:00

[DAGCombiner] Better support for shifting large value type by constants

As detailed in D22726, much of the shift combining code assumes that constant values will fit into a uint64_t and calls ConstantSDNode::getZExtValue where it probably shouldn't (leading to asserts). Using APInt directly avoids this problem, but we encounter other assertions if we attempt to compare/operate on two APInts of different bitwidths.

This patch adds a helper function to ensure that 2 APInt values are zero extended as required so that they can be safely used together. I've only added an initial example use for this to the '(SHIFT (SHIFT x, c1), c2) --> (SHIFT x, (ADD c1, c2))' combines. Further cases can easily be added as required.

Differential Revision: https://reviews.llvm.org/D23007

llvm-svn: 278141
This commit is contained in:
Simon Pilgrim 2016-08-09 17:39:11 +00:00
parent db9853f118
commit 9d3d169916
2 changed files with 66 additions and 17 deletions

View File

@ -726,6 +726,15 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
}
}
// Most APInt operations require both operands to share a bitwidth. This
// helper brings LHS and RHS to a common width by zero extension: the wider
// of the two incoming bitwidths plus Offset. Callers pass a non-zero Offset
// to reserve extra high bits so that a following operation on the pair
// won't overflow.
static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
  const unsigned WidestBits = std::max(LHS.getBitWidth(), RHS.getBitWidth());
  const unsigned TargetBits = WidestBits + Offset;
  RHS = RHS.zextOrSelf(TargetBits);
  LHS = LHS.zextOrSelf(TargetBits);
}
// Return true if this node is a setcc, or is a select_cc
// that selects between the target values used for true and false, making it
// equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
@ -4464,13 +4473,18 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
// fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
if (N1C && N0.getOpcode() == ISD::SHL) {
if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
uint64_t c1 = N0C1->getZExtValue();
uint64_t c2 = N1C->getZExtValue();
SDLoc DL(N);
if (c1 + c2 >= OpSizeInBits)
APInt c1 = N0C1->getAPIntValue();
APInt c2 = N1C->getAPIntValue();
zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
APInt Sum = c1 + c2;
if (Sum.uge(OpSizeInBits))
return DAG.getConstant(0, DL, VT);
return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0),
DAG.getConstant(c1 + c2, DL, N1.getValueType()));
return DAG.getNode(
ISD::SHL, DL, VT, N0.getOperand(0),
DAG.getConstant(Sum.getZExtValue(), DL, N1.getValueType()));
}
}
@ -4656,13 +4670,19 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
// fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
if (N1C && N0.getOpcode() == ISD::SRA) {
if (ConstantSDNode *C1 = isConstOrConstSplat(N0.getOperand(1))) {
unsigned Sum = N1C->getZExtValue() + C1->getZExtValue();
if (Sum >= OpSizeInBits)
Sum = OpSizeInBits - 1;
if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
SDLoc DL(N);
return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0),
DAG.getConstant(Sum, DL, N1.getValueType()));
APInt c1 = N0C1->getAPIntValue();
APInt c2 = N1C->getAPIntValue();
zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
APInt Sum = c1 + c2;
if (Sum.uge(OpSizeInBits))
Sum = APInt(OpSizeInBits, OpSizeInBits - 1);
return DAG.getNode(
ISD::SRA, DL, VT, N0.getOperand(0),
DAG.getConstant(Sum.getZExtValue(), DL, N1.getValueType()));
}
}
@ -4790,14 +4810,19 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
// fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
if (N1C && N0.getOpcode() == ISD::SRL) {
if (ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1))) {
uint64_t c1 = N01C->getZExtValue();
uint64_t c2 = N1C->getZExtValue();
if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
SDLoc DL(N);
if (c1 + c2 >= OpSizeInBits)
APInt c1 = N0C1->getAPIntValue();
APInt c2 = N1C->getAPIntValue();
zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
APInt Sum = c1 + c2;
if (Sum.uge(OpSizeInBits))
return DAG.getConstant(0, DL, VT);
return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0),
DAG.getConstant(c1 + c2, DL, N1.getValueType()));
return DAG.getNode(
ISD::SRL, DL, VT, N0.getOperand(0),
DAG.getConstant(Sum.getZExtValue(), DL, N1.getValueType()));
}
}

View File

@ -92,3 +92,27 @@ entry:
store <2 x i128> %0, <2 x i128>* %r, align 16
ret void
}
; Two chained logical right shifts of a <2 x i128> value by constant splats:
; first by i128 -1 (all bits set, so the amount does not fit in 64 bits),
; then by i128 1. Added together at 128 bits the two amounts wrap to zero.
; The shifted vector is stored through %r.
define void @test_lshr_v2i128_outofrange_sum(<2 x i128> %x, <2 x i128>* nocapture %r) nounwind {
entry:
%0 = lshr <2 x i128> %x, <i128 -1, i128 -1>
%1 = lshr <2 x i128> %0, <i128 1, i128 1>
store <2 x i128> %1, <2 x i128>* %r, align 16
ret void
}
; Two chained arithmetic right shifts of a <2 x i128> value by constant
; splats: first by i128 -1 (all bits set, so the amount does not fit in
; 64 bits), then by i128 1. Added together at 128 bits the two amounts wrap
; to zero. The shifted vector is stored through %r.
define void @test_ashr_v2i128_outofrange_sum(<2 x i128> %x, <2 x i128>* nocapture %r) nounwind {
entry:
%0 = ashr <2 x i128> %x, <i128 -1, i128 -1>
%1 = ashr <2 x i128> %0, <i128 1, i128 1>
store <2 x i128> %1, <2 x i128>* %r, align 16
ret void
}
; Two chained left shifts of a <2 x i128> value by constant splats: first by
; i128 -1 (all bits set, so the amount does not fit in 64 bits), then by
; i128 1. Added together at 128 bits the two amounts wrap to zero.
; The shifted vector is stored through %r.
define void @test_shl_v2i128_outofrange_sum(<2 x i128> %x, <2 x i128>* nocapture %r) nounwind {
entry:
%0 = shl <2 x i128> %x, <i128 -1, i128 -1>
%1 = shl <2 x i128> %0, <i128 1, i128 1>
store <2 x i128> %1, <2 x i128>* %r, align 16
ret void
}