1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-25 12:12:47 +01:00

[ScalarEvolution] Fix pointer/int type handling when converting select/phi to min/max.

The old version of this code would blindly perform arithmetic without
paying attention to whether the types involved were pointers or
integers.  This could produce ill-formed expressions, such as the
negation of a pointer.

Explicitly handle simple cases involving pointers, like "x < y ? x : y".
In all other cases, coerce the operands of the comparison to integer
types.  This avoids the weird cases, while handling most of the
interesting cases.

Differential Revision: https://reviews.llvm.org/D103660
This commit is contained in:
Eli Friedman 2021-06-16 00:00:13 -07:00
parent 0c385efcae
commit e4884552df
4 changed files with 46 additions and 36 deletions

View File

@ -5546,12 +5546,35 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I,
// a > b ? b+x : a+x -> min(a, b)+x
if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
bool Signed = ICI->isSigned();
const SCEV *LS = Signed ? getNoopOrSignExtend(getSCEV(LHS), I->getType())
: getNoopOrZeroExtend(getSCEV(LHS), I->getType());
const SCEV *RS = Signed ? getNoopOrSignExtend(getSCEV(RHS), I->getType())
: getNoopOrZeroExtend(getSCEV(RHS), I->getType());
const SCEV *LA = getSCEV(TrueVal);
const SCEV *RA = getSCEV(FalseVal);
const SCEV *LS = getSCEV(LHS);
const SCEV *RS = getSCEV(RHS);
if (LA->getType()->isPointerTy()) {
// FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
// Need to make sure we can't produce weird expressions involving
// negated pointers.
if (LA == LS && RA == RS)
return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
if (LA == RS && RA == LS)
return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
}
auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
if (Op->getType()->isPointerTy()) {
Op = getLosslessPtrToIntExpr(Op);
if (isa<SCEVCouldNotCompute>(Op))
return Op;
}
if (Signed)
Op = getNoopOrSignExtend(Op, I->getType());
else
Op = getNoopOrZeroExtend(Op, I->getType());
return Op;
};
LS = CoerceOperand(LS);
RS = CoerceOperand(RS);
if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
break;
const SCEV *LDiff = getMinusSCEV(LA, LS);
const SCEV *RDiff = getMinusSCEV(RA, RS);
if (LDiff == RDiff)

View File

@ -16,11 +16,11 @@ define i8* @FSE_decompress_usingDTable(i8* %arg, i32 %arg1, i32 %arg2, i32 %arg3
; CHECK-NEXT: %i5 = getelementptr inbounds i8, i8* %i, i32 %i4
; CHECK-NEXT: --> ((-1 * %arg1) + %arg2 + %arg) U: full-set S: full-set
; CHECK-NEXT: %i7 = select i1 %i6, i32 %arg2, i32 %arg1
; CHECK-NEXT: --> ((-1 * %arg) + (((-1 * %arg1) + %arg2 + %arg) umin %arg) + %arg1) U: full-set S: full-set
; CHECK-NEXT: --> ((-1 * (ptrtoint i8* %arg to i32)) + (((-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg2) umin (ptrtoint i8* %arg to i32)) + %arg1) U: full-set S: full-set
; CHECK-NEXT: %i8 = sub i32 %arg3, %i7
; CHECK-NEXT: --> ((-1 * (((-1 * %arg1) + %arg2 + %arg) umin %arg)) + (-1 * %arg1) + %arg3 + %arg) U: full-set S: full-set
; CHECK-NEXT: --> ((-1 * (((-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg2) umin (ptrtoint i8* %arg to i32))) + (-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg3) U: full-set S: full-set
; CHECK-NEXT: %i9 = getelementptr inbounds i8, i8* %arg, i32 %i8
; CHECK-NEXT: --> ((2 * %arg) + (-1 * (((-1 * %arg1) + %arg2 + %arg) umin %arg)) + (-1 * %arg1) + %arg3) U: full-set S: full-set
; CHECK-NEXT: --> ((-1 * (((-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg2) umin (ptrtoint i8* %arg to i32))) + (-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg3 + %arg) U: full-set S: full-set
; CHECK-NEXT: Determining loop execution counts for: @FSE_decompress_usingDTable
;
bb:
@ -42,11 +42,11 @@ define i8* @test_01(i8* %p) {
; CHECK-NEXT: %p2 = getelementptr i8, i8* %p, i32 1
; CHECK-NEXT: --> (1 + %p) U: full-set S: full-set
; CHECK-NEXT: %index = select i1 %cmp, i32 2, i32 1
; CHECK-NEXT: --> ((-1 * %p) + ((1 + %p) umax (2 + %p))) U: full-set S: full-set
; CHECK-NEXT: --> ((-1 * (ptrtoint i8* %p to i32)) + ((1 + (ptrtoint i8* %p to i32)) umax (2 + (ptrtoint i8* %p to i32)))) U: full-set S: full-set
; CHECK-NEXT: %neg_index = sub i32 0, %index
; CHECK-NEXT: --> ((-1 * ((1 + %p) umax (2 + %p))) + %p) U: full-set S: full-set
; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) umax (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32)) U: full-set S: full-set
; CHECK-NEXT: %gep = getelementptr i8, i8* %p, i32 %neg_index
; CHECK-NEXT: --> ((2 * %p) + (-1 * ((1 + %p) umax (2 + %p)))) U: full-set S: full-set
; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) umax (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32) + %p) U: full-set S: full-set
; CHECK-NEXT: Determining loop execution counts for: @test_01
;
%p1 = getelementptr i8, i8* %p, i32 2
@ -66,11 +66,11 @@ define i8* @test_02(i8* %p) {
; CHECK-NEXT: %p2 = getelementptr i8, i8* %p, i32 1
; CHECK-NEXT: --> (1 + %p) U: full-set S: full-set
; CHECK-NEXT: %index = select i1 %cmp, i32 2, i32 1
; CHECK-NEXT: --> ((-1 * %p) + ((1 + %p) smax (2 + %p))) U: full-set S: full-set
; CHECK-NEXT: --> ((-1 * (ptrtoint i8* %p to i32)) + ((1 + (ptrtoint i8* %p to i32)) smax (2 + (ptrtoint i8* %p to i32)))) U: full-set S: full-set
; CHECK-NEXT: %neg_index = sub i32 0, %index
; CHECK-NEXT: --> ((-1 * ((1 + %p) smax (2 + %p))) + %p) U: full-set S: full-set
; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) smax (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32)) U: full-set S: full-set
; CHECK-NEXT: %gep = getelementptr i8, i8* %p, i32 %neg_index
; CHECK-NEXT: --> ((2 * %p) + (-1 * ((1 + %p) smax (2 + %p)))) U: full-set S: full-set
; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) smax (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32) + %p) U: full-set S: full-set
; CHECK-NEXT: Determining loop execution counts for: @test_02
;
%p1 = getelementptr i8, i8* %p, i32 2
@ -90,11 +90,11 @@ define i8* @test_03(i8* %p) {
; CHECK-NEXT: %p2 = getelementptr i8, i8* %p, i32 1
; CHECK-NEXT: --> (1 + %p) U: full-set S: full-set
; CHECK-NEXT: %index = select i1 %cmp, i32 2, i32 1
; CHECK-NEXT: --> ((-1 * %p) + ((1 + %p) umin (2 + %p))) U: full-set S: full-set
; CHECK-NEXT: --> ((-1 * (ptrtoint i8* %p to i32)) + ((1 + (ptrtoint i8* %p to i32)) umin (2 + (ptrtoint i8* %p to i32)))) U: full-set S: full-set
; CHECK-NEXT: %neg_index = sub i32 0, %index
; CHECK-NEXT: --> ((-1 * ((1 + %p) umin (2 + %p))) + %p) U: full-set S: full-set
; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) umin (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32)) U: full-set S: full-set
; CHECK-NEXT: %gep = getelementptr i8, i8* %p, i32 %neg_index
; CHECK-NEXT: --> ((2 * %p) + (-1 * ((1 + %p) umin (2 + %p)))) U: full-set S: full-set
; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) umin (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32) + %p) U: full-set S: full-set
; CHECK-NEXT: Determining loop execution counts for: @test_03
;
%p1 = getelementptr i8, i8* %p, i32 2
@ -114,11 +114,11 @@ define i8* @test_04(i8* %p) {
; CHECK-NEXT: %p2 = getelementptr i8, i8* %p, i32 1
; CHECK-NEXT: --> (1 + %p) U: full-set S: full-set
; CHECK-NEXT: %index = select i1 %cmp, i32 2, i32 1
; CHECK-NEXT: --> ((-1 * %p) + ((1 + %p) smin (2 + %p))) U: full-set S: full-set
; CHECK-NEXT: --> ((-1 * (ptrtoint i8* %p to i32)) + ((1 + (ptrtoint i8* %p to i32)) smin (2 + (ptrtoint i8* %p to i32)))) U: full-set S: full-set
; CHECK-NEXT: %neg_index = sub i32 0, %index
; CHECK-NEXT: --> ((-1 * ((1 + %p) smin (2 + %p))) + %p) U: full-set S: full-set
; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) smin (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32)) U: full-set S: full-set
; CHECK-NEXT: %gep = getelementptr i8, i8* %p, i32 %neg_index
; CHECK-NEXT: --> ((2 * %p) + (-1 * ((1 + %p) smin (2 + %p)))) U: full-set S: full-set
; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) smin (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32) + %p) U: full-set S: full-set
; CHECK-NEXT: Determining loop execution counts for: @test_04
;
%p1 = getelementptr i8, i8* %p, i32 2

View File

@ -10,7 +10,7 @@ target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16
define internal fastcc void @d(i8* %c) unnamed_addr #0 {
entry:
%cmp = icmp ule i8* %c, getelementptr inbounds (i8, i8* @a, i64 65535)
%cmp = icmp ule i8* %c, @a
%add.ptr = getelementptr inbounds i8, i8* %c, i64 -65535
br label %while.cond
@ -18,7 +18,7 @@ while.cond:
br i1 icmp ne (i8 0, i8 0), label %cont, label %while.end
cont:
%a.mux = select i1 %cmp, i8* @a, i8* %add.ptr
%a.mux = select i1 %cmp, i8* @a, i8* %c
switch i64 0, label %while.cond [
i64 -1, label %handler.pointer_overflow.i
i64 0, label %handler.pointer_overflow.i
@ -26,7 +26,7 @@ cont:
handler.pointer_overflow.i:
%a.mux.lcssa4 = phi i8* [ %a.mux, %cont ], [ %a.mux, %cont ]
; ALWAYS: [ %scevgep, %cont ], [ %scevgep, %cont ]
; ALWAYS: [ %umax, %cont ], [ %umax, %cont ]
; NEVER: [ %a.mux, %cont ], [ %a.mux, %cont ]
; In cheap mode, use either one as long as it's consistent.
; CHEAP: [ %[[VAL:.*]], %cont ], [ %[[VAL]], %cont ]

View File

@ -118,20 +118,7 @@ TEST_F(ScalarEvolutionExpanderTest, ExpandPtrTypeSCEV) {
ScalarEvolution SE = buildSE(*F);
auto *S = SE.getSCEV(CastB);
SCEVExpander Exp(SE, M.getDataLayout(), "expander");
Value *V =
Exp.expandCodeFor(cast<SCEVAddExpr>(S)->getOperand(1), nullptr, Br);
// Expect the expansion code contains:
// %0 = bitcast i32* %bitcast2 to i8*
// %uglygep = getelementptr i8, i8* %0, i64 -1
// %1 = bitcast i8* %uglygep to i32*
EXPECT_TRUE(isa<BitCastInst>(V));
Instruction *Gep = cast<Instruction>(V)->getPrevNode();
EXPECT_TRUE(isa<GetElementPtrInst>(Gep));
EXPECT_TRUE(isa<ConstantInt>(Gep->getOperand(1)));
EXPECT_EQ(cast<ConstantInt>(Gep->getOperand(1))->getSExtValue(), -1);
EXPECT_TRUE(isa<BitCastInst>(Gep->getPrevNode()));
EXPECT_TRUE(isa<SCEVUnknown>(S));
}
// Make sure that SCEV doesn't introduce illegal ptrtoint/inttoptr instructions