From 0b73344d8a6d05842ed352aa54614ad79ab4d407 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Wed, 6 Jan 2010 01:56:21 +0000 Subject: [PATCH] Teach instcombine's sext elimination logic to be more aggressive. Previously, instcombine would only promote an expression tree to the larger type if doing so eliminated two casts. This is because we need to manually do the sign extend after the promoted expression tree with two shifts. Now, we keep track of whether the result of the computation is going to be properly sign extended already. If so, we can unconditionally promote the expression, which allows us to zap more sext's. This implements rdar://6598839 (aka gcc pr38751) llvm-svn: 92815 --- .../InstCombine/InstCombineCasts.cpp | 186 ++++++++++++++++-- test/Transforms/InstCombine/cast-sext-zext.ll | 21 -- test/Transforms/InstCombine/cast.ll | 11 ++ 3 files changed, 179 insertions(+), 39 deletions(-) delete mode 100644 test/Transforms/InstCombine/cast-sext-zext.ll diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index 336c32940be..acd78d6bf9f 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -153,7 +153,7 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, /// whether promoting or shrinking integer operations to wider or smaller types /// will allow us to eliminate a truncate or extend. /// -/// This is a truncation operation if Ty is smaller than V->getType(), or an +/// This is a truncation operation if Ty is smaller than V->getType(), or a zero /// extension operation if Ty is larger. /// /// If CastOpc is a truncation, then Ty will be a type smaller than V. We @@ -162,11 +162,13 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, /// inst(trunc(x),trunc(y)), which only makes sense if x and y can be /// efficiently truncated. 
/// -/// If CastOpc is a sext or zext, we are asking if the low bits of the value can -/// bit computed in a larger type, which is then and'd or sext_in_reg'd to get -/// the final result. +/// If CastOpc is zext, we are asking if the low bits of the value can bit +/// computed in a larger type, which is then and'd to get the final result. static bool CanEvaluateInDifferentType(Value *V, const Type *Ty, - unsigned CastOpc, int &NumCastsRemoved) { + unsigned CastOpc, + unsigned &NumCastsRemoved) { + assert(CastOpc == Instruction::ZExt || CastOpc == Instruction::Trunc); + // We can always evaluate constants in another type. if (isa(V)) return true; @@ -291,9 +293,124 @@ static bool CanEvaluateInDifferentType(Value *V, const Type *Ty, return false; } +/// CanEvaluateSExtd - Return true if we can take the specified value +/// and return it as type Ty without inserting any new casts and without +/// changing the computed value of the common low bits. This is used by code +/// that tries to promote integer operations to a wider types will allow us to +/// eliminate the extension. +/// +/// This returns 0 if we can't do this or the number of sign bits that would be +/// set if we can. For example, CanEvaluateSExtd(i16 1, i64) would return 63, +/// because the computation can be extended (to "i64 1") and the resulting +/// computation has 63 equal sign bits. +/// +/// This function works on both vectors and scalars. For vectors, the result is +/// the number of bits known sign extended in each element. +/// +static unsigned CanEvaluateSExtd(Value *V, const Type *Ty, + unsigned &NumCastsRemoved, TargetData *TD) { + assert(V->getType()->getScalarSizeInBits() < Ty->getScalarSizeInBits() && + "Can't sign extend type to a smaller type"); + // If this is a constant, return the number of sign bits the extended version + // of it would have. 
+ if (Constant *C = dyn_cast(V)) + return ComputeNumSignBits(ConstantExpr::getSExt(C, Ty), TD); + + Instruction *I = dyn_cast(V); + if (!I) return 0; + + // If this is a truncate from the destination type, we can trivially eliminate + // it, and this will remove a cast overall. + if (isa(I) && I->getOperand(0)->getType() == Ty) { + // If the operand of the truncate is itself a cast, and is eliminable, do + // not count this as an eliminable cast. We would prefer to eliminate those + // two casts first. + if (!isa(I->getOperand(0)) && I->hasOneUse()) + ++NumCastsRemoved; + return ComputeNumSignBits(I->getOperand(0), TD); + } + + // We can't extend or shrink something that has multiple uses: doing so would + // require duplicating the instruction in general, which isn't profitable. + if (!I->hasOneUse()) return 0; + + const Type *OrigTy = V->getType(); + + unsigned Opc = I->getOpcode(); + unsigned Tmp1, Tmp2; + switch (Opc) { + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + // These operators can all arbitrarily be extended or truncated. + Tmp1 = CanEvaluateSExtd(I->getOperand(0), Ty, NumCastsRemoved, TD); + if (Tmp1 == 0) return 0; + Tmp2 = CanEvaluateSExtd(I->getOperand(1), Ty, NumCastsRemoved, TD); + return std::min(Tmp1, Tmp2); + case Instruction::Add: + case Instruction::Sub: + // Add/Sub can have at most one carry/borrow bit. + Tmp1 = CanEvaluateSExtd(I->getOperand(0), Ty, NumCastsRemoved, TD); + if (Tmp1 == 0) return 0; + Tmp2 = CanEvaluateSExtd(I->getOperand(1), Ty, NumCastsRemoved, TD); + if (Tmp2 == 0) return 0; + return std::min(Tmp1, Tmp2)-1; + case Instruction::Mul: + // These operators can all arbitrarily be extended or truncated. + if (!CanEvaluateSExtd(I->getOperand(0), Ty, NumCastsRemoved, TD)) + return 0; + if (!CanEvaluateSExtd(I->getOperand(1), Ty, NumCastsRemoved, TD)) + return 0; + return 1; // IMPROVE? 
+ + //case Instruction::Shl: TODO + //case Instruction::LShr: TODO + //case Instruction::Trunc: TODO + + case Instruction::SExt: + case Instruction::ZExt: { + // sext(sext(x)) -> sext(x) + // sext(zext(x)) -> zext(x) + // Note that replacing a cast does not reduce the number of casts in the + // input. + unsigned InSignBits = ComputeNumSignBits(I, TD); + unsigned ExtBits = Ty->getScalarSizeInBits()-OrigTy->getScalarSizeInBits(); + // We'll end up extending it all the way out. + return InSignBits+ExtBits; + } + case Instruction::Select: { + SelectInst *SI = cast(I); + Tmp1 = CanEvaluateSExtd(SI->getTrueValue(), Ty, NumCastsRemoved, TD); + if (Tmp1 == 0) return 0; + Tmp2 = CanEvaluateSExtd(SI->getFalseValue(), Ty, NumCastsRemoved,TD); + return std::min(Tmp1, Tmp2); + } + case Instruction::PHI: { + // We can change a phi if we can change all operands. Note that we never + // get into trouble with cyclic PHIs here because we only consider + // instructions with a single use. + PHINode *PN = cast(I); + unsigned Result = ~0U; + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + Result = std::min(Result, + CanEvaluateSExtd(PN->getIncomingValue(i), Ty, + NumCastsRemoved, TD)); + if (Result == 0) return 0; + } + return Result; + } + default: + // TODO: Can handle more cases here. + break; + } + + return 0; +} + + /// EvaluateInDifferentType - Given an expression that -/// CanEvaluateInDifferentType returns true for, actually insert the code to -/// evaluate the expression. +/// CanEvaluateInDifferentType or CanEvaluateSExtd returns true for, actually +/// insert the code to evaluate the expression. Value *InstCombiner::EvaluateInDifferentType(Value *V, const Type *Ty, bool isSigned) { if (Constant *C = dyn_cast(V)) @@ -469,35 +586,68 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) { return 0; // Attempt to propagate the cast into the instruction for int->int casts. 
- int NumCastsRemoved = 0; - if (!CanEvaluateInDifferentType(Src, DestTy, CI.getOpcode(), NumCastsRemoved)) - return 0; - + unsigned NumCastsRemoved = 0; switch (CI.getOpcode()) { default: assert(0 && "not an integer cast"); case Instruction::Trunc: + if (!CanEvaluateInDifferentType(Src, DestTy, + Instruction::Trunc, NumCastsRemoved)) + return 0; + // If this cast is a truncate, evaluting in a different type always // eliminates the cast, so it is always a win. break; case Instruction::ZExt: + if (!CanEvaluateInDifferentType(Src, DestTy, + Instruction::ZExt, NumCastsRemoved)) + return 0; + // If this is a zero-extension, we need to do an AND to maintain the clear // top-part of the computation, so we require that the input have eliminated // at least one cast. if (NumCastsRemoved < 1) return 0; break; - case Instruction::SExt: - // If this is a sign extension, we insert two new shifts (to do the - // extension) so we require that two casts have been eliminated. - if (NumCastsRemoved < 2) + case Instruction::SExt: { + // Check to see if we can do this transformation, and if so, how many bits + // of the promoted expression will be known copies of the sign bit in the + // result. + unsigned NumBitsSExt = CanEvaluateSExtd(Src, DestTy, NumCastsRemoved, TD); + if (NumBitsSExt == 0) return 0; - break; + + uint32_t SrcBitSize = SrcTy->getScalarSizeInBits(); + uint32_t DestBitSize = DestTy->getScalarSizeInBits(); + + // Because this is a sign extension, we can always transform it by inserting + // two new shifts (to do the extension). However, this is only profitable + // if we've eliminated two or more casts from the input. If we know the + // result will be sign-extendy enough to not require these shifts, we can + // always do the transformation. + if (NumCastsRemoved < 2 && + NumBitsSExt <= DestBitSize-SrcBitSize) + return 0; + + // Okay, we can transform this! Insert the new expression now. 
+ DEBUG(errs() << "ICE: EvaluateInDifferentType converting expression type" + " to avoid sign extend: " << CI); + Value *Res = EvaluateInDifferentType(Src, DestTy, true); + assert(Res->getType() == DestTy); + + // If the high bits are already filled with sign bit, just replace this + // cast with the result. + if (NumBitsSExt > DestBitSize - SrcBitSize || + ComputeNumSignBits(Res) > DestBitSize - SrcBitSize) + return ReplaceInstUsesWith(CI, Res); + + // We need to emit a cast to truncate, then a cast to sext. + return new SExtInst(Builder->CreateTrunc(Res, Src->getType()), DestTy); + } } DEBUG(errs() << "ICE: EvaluateInDifferentType converting expression type" " to avoid cast: " << CI); - Value *Res = EvaluateInDifferentType(Src, DestTy, - CI.getOpcode() == Instruction::SExt); + Value *Res = EvaluateInDifferentType(Src, DestTy, false); assert(Res->getType() == DestTy); uint32_t SrcBitSize = SrcTy->getScalarSizeInBits(); diff --git a/test/Transforms/InstCombine/cast-sext-zext.ll b/test/Transforms/InstCombine/cast-sext-zext.ll deleted file mode 100644 index 678874a2794..00000000000 --- a/test/Transforms/InstCombine/cast-sext-zext.ll +++ /dev/null @@ -1,21 +0,0 @@ -; RUN: opt < %s -instcombine -S | not grep sext -; XFAIL: * -; rdar://6598839 - -define zeroext i16 @t(i8 zeroext %on_off, i16* nocapture %puls) nounwind readonly { -entry: - %0 = zext i8 %on_off to i32 - %1 = add i32 %0, -1 - %2 = sext i32 %1 to i64 - %3 = getelementptr i16* %puls, i64 %2 - %4 = load i16* %3, align 2 - ret i16 %4 -} - -define zeroext i64 @t2(i8 zeroext %on_off) nounwind readonly { -entry: - %0 = zext i8 %on_off to i32 - %1 = add i32 %0, -1 - %2 = sext i32 %1 to i64 - ret i64 %2 ;; Should be (add (zext i8 -> i64), -1) -} diff --git a/test/Transforms/InstCombine/cast.ll b/test/Transforms/InstCombine/cast.ll index a6c6795e844..10e5050125d 100644 --- a/test/Transforms/InstCombine/cast.ll +++ b/test/Transforms/InstCombine/cast.ll @@ -381,3 +381,14 @@ define i32 @test42(i32 %X) { ; CHECK: %Z = 
and i32 %X, 255 } +; rdar://6598839 +define zeroext i64 @test43(i8 zeroext %on_off) nounwind readonly { + %A = zext i8 %on_off to i32 + %B = add i32 %A, -1 + %C = sext i32 %B to i64 + ret i64 %C ;; Should be (add (zext i8 -> i64), -1) +; CHECK: @test43 +; CHECK-NEXT: %A = zext i8 %on_off to i64 +; CHECK-NEXT: %B = add i64 %A, -1 +; CHECK-NEXT: ret i64 %B +}