1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-24 03:33:20 +01:00

Generalize createSCEV to be able to form SCEV expressions from

ConstantExprs.

llvm-svn: 52615
This commit is contained in:
Dan Gohman 2008-06-22 19:56:46 +00:00
parent 81c83d9a1d
commit 90894ac18b

View File

@ -1704,118 +1704,125 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
if (!isa<IntegerType>(V->getType())) if (!isa<IntegerType>(V->getType()))
return SE.getUnknown(V); return SE.getUnknown(V);
if (Instruction *I = dyn_cast<Instruction>(V)) { unsigned Opcode = Instruction::UserOp1;
switch (I->getOpcode()) { if (Instruction *I = dyn_cast<Instruction>(V))
case Instruction::Add: Opcode = I->getOpcode();
return SE.getAddExpr(getSCEV(I->getOperand(0)), else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
getSCEV(I->getOperand(1))); Opcode = CE->getOpcode();
case Instruction::Mul: else
return SE.getMulExpr(getSCEV(I->getOperand(0)), return SE.getUnknown(V);
getSCEV(I->getOperand(1)));
case Instruction::UDiv:
return SE.getUDivExpr(getSCEV(I->getOperand(0)),
getSCEV(I->getOperand(1)));
case Instruction::Sub:
return SE.getMinusSCEV(getSCEV(I->getOperand(0)),
getSCEV(I->getOperand(1)));
case Instruction::Or:
// If the RHS of the Or is a constant, we may have something like:
// X*4+1 which got turned into X*4|1. Handle this as an Add so loop
// optimizations will transparently handle this case.
//
// In order for this transformation to be safe, the LHS must be of the
// form X*(2^n) and the Or constant must be less than 2^n.
if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
SCEVHandle LHS = getSCEV(I->getOperand(0));
const APInt &CIVal = CI->getValue();
if (GetMinTrailingZeros(LHS) >=
(CIVal.getBitWidth() - CIVal.countLeadingZeros()))
return SE.getAddExpr(LHS, getSCEV(I->getOperand(1)));
}
break;
case Instruction::Xor:
// If the RHS of the xor is a signbit, then this is just an add.
// Instcombine turns add of signbit into xor as a strength reduction step.
if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
if (CI->getValue().isSignBit())
return SE.getAddExpr(getSCEV(I->getOperand(0)),
getSCEV(I->getOperand(1)));
else if (CI->isAllOnesValue())
return SE.getNotSCEV(getSCEV(I->getOperand(0)));
}
break;
case Instruction::Shl: User *U = cast<User>(V);
// Turn shift left of a constant amount into a multiply. switch (Opcode) {
if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) { case Instruction::Add:
uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth(); return SE.getAddExpr(getSCEV(U->getOperand(0)),
Constant *X = ConstantInt::get( getSCEV(U->getOperand(1)));
APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth))); case Instruction::Mul:
return SE.getMulExpr(getSCEV(I->getOperand(0)), getSCEV(X)); return SE.getMulExpr(getSCEV(U->getOperand(0)),
} getSCEV(U->getOperand(1)));
break; case Instruction::UDiv:
return SE.getUDivExpr(getSCEV(U->getOperand(0)),
case Instruction::Trunc: getSCEV(U->getOperand(1)));
return SE.getTruncateExpr(getSCEV(I->getOperand(0)), I->getType()); case Instruction::Sub:
return SE.getMinusSCEV(getSCEV(U->getOperand(0)),
case Instruction::ZExt: getSCEV(U->getOperand(1)));
return SE.getZeroExtendExpr(getSCEV(I->getOperand(0)), I->getType()); case Instruction::Or:
// If the RHS of the Or is a constant, we may have something like:
case Instruction::SExt: // X*4+1 which got turned into X*4|1. Handle this as an Add so loop
return SE.getSignExtendExpr(getSCEV(I->getOperand(0)), I->getType()); // optimizations will transparently handle this case.
//
case Instruction::BitCast: // In order for this transformation to be safe, the LHS must be of the
// BitCasts are no-op casts so we just eliminate the cast. // form X*(2^n) and the Or constant must be less than 2^n.
if (I->getType()->isInteger() && if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
I->getOperand(0)->getType()->isInteger()) SCEVHandle LHS = getSCEV(U->getOperand(0));
return getSCEV(I->getOperand(0)); const APInt &CIVal = CI->getValue();
break; if (GetMinTrailingZeros(LHS) >=
(CIVal.getBitWidth() - CIVal.countLeadingZeros()))
case Instruction::PHI: return SE.getAddExpr(LHS, getSCEV(U->getOperand(1)));
return createNodeForPHI(cast<PHINode>(I));
case Instruction::Select:
// This could be a smax or umax that was lowered earlier.
// Try to recover it.
if (ICmpInst *ICI = dyn_cast<ICmpInst>(I->getOperand(0))) {
Value *LHS = ICI->getOperand(0);
Value *RHS = ICI->getOperand(1);
switch (ICI->getPredicate()) {
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE:
std::swap(LHS, RHS);
// fall through
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE:
if (LHS == I->getOperand(1) && RHS == I->getOperand(2))
return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
else if (LHS == I->getOperand(2) && RHS == I->getOperand(1))
// -smax(-x, -y) == smin(x, y).
return SE.getNegativeSCEV(SE.getSMaxExpr(
SE.getNegativeSCEV(getSCEV(LHS)),
SE.getNegativeSCEV(getSCEV(RHS))));
break;
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
std::swap(LHS, RHS);
// fall through
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
if (LHS == I->getOperand(1) && RHS == I->getOperand(2))
return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
else if (LHS == I->getOperand(2) && RHS == I->getOperand(1))
// ~umax(~x, ~y) == umin(x, y)
return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)),
SE.getNotSCEV(getSCEV(RHS))));
break;
default:
break;
}
}
default: // We cannot analyze this expression.
break;
} }
break;
case Instruction::Xor:
// If the RHS of the xor is a signbit, then this is just an add.
// Instcombine turns add of signbit into xor as a strength reduction step.
if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
if (CI->getValue().isSignBit())
return SE.getAddExpr(getSCEV(U->getOperand(0)),
getSCEV(U->getOperand(1)));
else if (CI->isAllOnesValue())
return SE.getNotSCEV(getSCEV(U->getOperand(0)));
}
break;
case Instruction::Shl:
// Turn shift left of a constant amount into a multiply.
if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
Constant *X = ConstantInt::get(
APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
return SE.getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
}
break;
case Instruction::Trunc:
return SE.getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
case Instruction::ZExt:
return SE.getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
case Instruction::SExt:
return SE.getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
case Instruction::BitCast:
// BitCasts are no-op casts so we just eliminate the cast.
if (U->getType()->isInteger() &&
U->getOperand(0)->getType()->isInteger())
return getSCEV(U->getOperand(0));
break;
case Instruction::PHI:
return createNodeForPHI(cast<PHINode>(U));
case Instruction::Select:
// This could be a smax or umax that was lowered earlier.
// Try to recover it.
if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
Value *LHS = ICI->getOperand(0);
Value *RHS = ICI->getOperand(1);
switch (ICI->getPredicate()) {
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE:
std::swap(LHS, RHS);
// fall through
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE:
if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
// -smax(-x, -y) == smin(x, y).
return SE.getNegativeSCEV(SE.getSMaxExpr(
SE.getNegativeSCEV(getSCEV(LHS)),
SE.getNegativeSCEV(getSCEV(RHS))));
break;
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
std::swap(LHS, RHS);
// fall through
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
// ~umax(~x, ~y) == umin(x, y)
return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)),
SE.getNotSCEV(getSCEV(RHS))));
break;
default:
break;
}
}
default: // We cannot analyze this expression.
break;
} }
return SE.getUnknown(V); return SE.getUnknown(V);