diff --git a/include/llvm/Analysis/ScalarEvolutionExpander.h b/include/llvm/Analysis/ScalarEvolutionExpander.h index c3ea383e322..80e0a9d8cf1 100644 --- a/include/llvm/Analysis/ScalarEvolutionExpander.h +++ b/include/llvm/Analysis/ScalarEvolutionExpander.h @@ -88,7 +88,8 @@ namespace llvm { /// InsertCastOfTo - Insert a cast of V to the specified type, doing what /// we can to share the casts. - static Value *InsertCastOfTo(Value *V, const Type *Ty); + static Value *InsertCastOfTo(Instruction::CastOps opcode, Value *V, + const Type *Ty); protected: Value *expand(SCEV *S) { @@ -104,8 +105,20 @@ namespace llvm { Value *expandInTy(SCEV *S, const Type *Ty) { Value *V = expand(S); - if (Ty && V->getType() != Ty) - return InsertCastOfTo(V, Ty); + if (Ty && V->getType() != Ty) { + if (isa(Ty) && V->getType()->isInteger()) + return InsertCastOfTo(Instruction::IntToPtr, V, Ty); + else if (Ty->isInteger() && isa(V->getType())) + return InsertCastOfTo(Instruction::PtrToInt, V, Ty); + else if (Ty->getPrimitiveSizeInBits() == + V->getType()->getPrimitiveSizeInBits()) + return InsertCastOfTo(Instruction::BitCast, V, Ty); + else if (Ty->getPrimitiveSizeInBits() > + V->getType()->getPrimitiveSizeInBits()) + return InsertCastOfTo(Instruction::ZExt, V, Ty); + else + return InsertCastOfTo(Instruction::Trunc, V, Ty); + } return V; } @@ -119,7 +132,7 @@ namespace llvm { } Value *visitZeroExtendExpr(SCEVZeroExtendExpr *S) { - Value *V = expandInTy(S->getOperand(),S->getType()->getUnsignedVersion()); + Value *V = expandInTy(S->getOperand(), S->getType()); return CastInst::createZExtOrBitCast(V, S->getType(), "tmp.", InsertPt); } diff --git a/lib/Analysis/ScalarEvolutionExpander.cpp b/lib/Analysis/ScalarEvolutionExpander.cpp index db23a24d606..5e395db5e23 100644 --- a/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/lib/Analysis/ScalarEvolutionExpander.cpp @@ -19,25 +19,8 @@ using namespace llvm; /// InsertCastOfTo - Insert a cast of V to the specified type, doing what /// we can to share the casts. -Value *SCEVExpander::InsertCastOfTo(Value *V, const Type *Ty) { - // Compute the Cast opcode to use - Instruction::CastOps opcode = Instruction::BitCast; - if (Ty->isIntegral()) { - if (V->getType()->getTypeID() == Type::PointerTyID) - opcode = Instruction::PtrToInt; - else { - unsigned SrcBits = V->getType()->getPrimitiveSizeInBits(); - unsigned DstBits = Ty->getPrimitiveSizeInBits(); - opcode = (SrcBits > DstBits ? Instruction::Trunc : - (SrcBits == DstBits ? Instruction::BitCast : - (V->getType()->isSigned() ? Instruction::SExt : - Instruction::ZExt))); - } - } else if (Ty->isFloatingPoint()) - opcode = Instruction::UIToFP; - else if (Ty->getTypeID() == Type::PointerTyID && V->getType()->isIntegral()) - opcode = Instruction::IntToPtr; - +Value *SCEVExpander::InsertCastOfTo(Instruction::CastOps opcode, Value *V, + const Type *Ty) { // FIXME: keep track of the cast instruction. if (Constant *C = dyn_cast(V)) return ConstantExpr::getCast(opcode, C, Ty); diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp index c132aaef927..9426aa1a69d 100644 --- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -177,7 +177,7 @@ namespace { /// getCastedVersionOf - Return the specified value casted to uintptr_t. /// - Value *getCastedVersionOf(Value *V); + Value *getCastedVersionOf(Instruction::CastOps opcode, Value *V); private: void runOnLoop(Loop *L); bool AddUsersIfInteresting(Instruction *I, Loop *L, @@ -203,19 +203,16 @@ FunctionPass *llvm::createLoopStrengthReducePass(const TargetLowering *TLI) { /// getCastedVersionOf - Return the specified value casted to uintptr_t. This /// assumes that the Value* V is of integer or pointer type only. /// -Value *LoopStrengthReduce::getCastedVersionOf(Value *V) { +Value *LoopStrengthReduce::getCastedVersionOf(Instruction::CastOps opcode, + Value *V) { if (V->getType() == UIntPtrTy) return V; if (Constant *CB = dyn_cast(V)) - if (CB->getType()->isInteger()) - return ConstantExpr::getIntegerCast(CB, UIntPtrTy, - CB->getType()->isSigned()); - else - return ConstantExpr::getPtrToInt(CB, UIntPtrTy); + return ConstantExpr::getCast(opcode, CB, UIntPtrTy); Value *&New = CastedPointers[V]; if (New) return New; - New = SCEVExpander::InsertCastOfTo(V, UIntPtrTy); + New = SCEVExpander::InsertCastOfTo(opcode, V, UIntPtrTy); DeadInsts.insert(cast(New)); return New; } @@ -258,7 +255,8 @@ SCEVHandle LoopStrengthReduce::GetExpressionSCEV(Instruction *Exp, Loop *L) { // Build up the base expression. Insert an LLVM cast of the pointer to // uintptr_t first. - SCEVHandle GEPVal = SCEVUnknown::get(getCastedVersionOf(GEP->getOperand(0))); + SCEVHandle GEPVal = SCEVUnknown::get( + getCastedVersionOf(Instruction::PtrToInt, GEP->getOperand(0))); gep_type_iterator GTI = gep_type_begin(GEP); @@ -273,7 +271,13 @@ SCEVHandle LoopStrengthReduce::GetExpressionSCEV(Instruction *Exp, Loop *L) { GEPVal = SCEVAddExpr::get(GEPVal, SCEVUnknown::getIntegerSCEV(Offset, UIntPtrTy)); } else { - Value *OpVal = getCastedVersionOf(GEP->getOperand(i)); + unsigned GEPOpiBits = + GEP->getOperand(i)->getType()->getPrimitiveSizeInBits(); + unsigned IntPtrBits = UIntPtrTy->getPrimitiveSizeInBits(); + Instruction::CastOps opcode = (GEPOpiBits < IntPtrBits ? + Instruction::SExt : (GEPOpiBits > IntPtrBits ? Instruction::Trunc : + Instruction::BitCast)); + Value *OpVal = getCastedVersionOf(opcode, GEP->getOperand(i)); SCEVHandle Idx = SE->getSCEV(OpVal); uint64_t TypeSize = TD->getTypeSize(GTI.getIndexedType()); @@ -1125,8 +1129,13 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEVHandle &Stride, if (L->contains(User.Inst->getParent())) User.Inst->moveBefore(LatchBlock->getTerminator()); } - if (RewriteOp->getType() != ReplacedTy) - RewriteOp = SCEVExpander::InsertCastOfTo(RewriteOp, ReplacedTy); + if (RewriteOp->getType() != ReplacedTy) { + Instruction::CastOps opcode = Instruction::Trunc; + if (ReplacedTy->getPrimitiveSizeInBits() == + RewriteOp->getType()->getPrimitiveSizeInBits()) + opcode = Instruction::BitCast; + RewriteOp = SCEVExpander::InsertCastOfTo(opcode, RewriteOp, ReplacedTy); + } SCEVHandle RewriteExpr = SCEVUnknown::get(RewriteOp);