aboutsummaryrefslogtreecommitdiff
path: root/lib/Analysis/ScalarEvolution.cpp
diff options
context:
space:
mode:
authorDan Gohman <gohman@apple.com>2008-06-22 19:56:46 +0000
committerDan Gohman <gohman@apple.com>2008-06-22 19:56:46 +0000
commit6c459a28ecb8d33e4b59ab2db1f9a58a2d06824b (patch)
treebc8761d2453c0a5ad8b4487ab645e904dbfff880 /lib/Analysis/ScalarEvolution.cpp
parent17f1972c770dc18f5c7c3c95776b4d62ae9e121d (diff)
Generalize createSCEV to be able to form SCEV expressions from
ConstantExprs. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@52615 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--lib/Analysis/ScalarEvolution.cpp227
1 files changed, 117 insertions, 110 deletions
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp
index 36da85bc06..d615c752b0 100644
--- a/lib/Analysis/ScalarEvolution.cpp
+++ b/lib/Analysis/ScalarEvolution.cpp
@@ -1704,118 +1704,125 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
if (!isa<IntegerType>(V->getType()))
return SE.getUnknown(V);
- if (Instruction *I = dyn_cast<Instruction>(V)) {
- switch (I->getOpcode()) {
- case Instruction::Add:
- return SE.getAddExpr(getSCEV(I->getOperand(0)),
- getSCEV(I->getOperand(1)));
- case Instruction::Mul:
- return SE.getMulExpr(getSCEV(I->getOperand(0)),
- getSCEV(I->getOperand(1)));
- case Instruction::UDiv:
- return SE.getUDivExpr(getSCEV(I->getOperand(0)),
- getSCEV(I->getOperand(1)));
- case Instruction::Sub:
- return SE.getMinusSCEV(getSCEV(I->getOperand(0)),
- getSCEV(I->getOperand(1)));
- case Instruction::Or:
- // If the RHS of the Or is a constant, we may have something like:
- // X*4+1 which got turned into X*4|1. Handle this as an Add so loop
- // optimizations will transparently handle this case.
- //
- // In order for this transformation to be safe, the LHS must be of the
- // form X*(2^n) and the Or constant must be less than 2^n.
- if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
- SCEVHandle LHS = getSCEV(I->getOperand(0));
- const APInt &CIVal = CI->getValue();
- if (GetMinTrailingZeros(LHS) >=
- (CIVal.getBitWidth() - CIVal.countLeadingZeros()))
- return SE.getAddExpr(LHS, getSCEV(I->getOperand(1)));
- }
- break;
- case Instruction::Xor:
- // If the RHS of the xor is a signbit, then this is just an add.
- // Instcombine turns add of signbit into xor as a strength reduction step.
- if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
- if (CI->getValue().isSignBit())
- return SE.getAddExpr(getSCEV(I->getOperand(0)),
- getSCEV(I->getOperand(1)));
- else if (CI->isAllOnesValue())
- return SE.getNotSCEV(getSCEV(I->getOperand(0)));
- }
- break;
-
- case Instruction::Shl:
- // Turn shift left of a constant amount into a multiply.
- if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
- uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
- Constant *X = ConstantInt::get(
- APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
- return SE.getMulExpr(getSCEV(I->getOperand(0)), getSCEV(X));
- }
- break;
-
- case Instruction::Trunc:
- return SE.getTruncateExpr(getSCEV(I->getOperand(0)), I->getType());
-
- case Instruction::ZExt:
- return SE.getZeroExtendExpr(getSCEV(I->getOperand(0)), I->getType());
-
- case Instruction::SExt:
- return SE.getSignExtendExpr(getSCEV(I->getOperand(0)), I->getType());
-
- case Instruction::BitCast:
- // BitCasts are no-op casts so we just eliminate the cast.
- if (I->getType()->isInteger() &&
- I->getOperand(0)->getType()->isInteger())
- return getSCEV(I->getOperand(0));
- break;
-
- case Instruction::PHI:
- return createNodeForPHI(cast<PHINode>(I));
-
- case Instruction::Select:
- // This could be a smax or umax that was lowered earlier.
- // Try to recover it.
- if (ICmpInst *ICI = dyn_cast<ICmpInst>(I->getOperand(0))) {
- Value *LHS = ICI->getOperand(0);
- Value *RHS = ICI->getOperand(1);
- switch (ICI->getPredicate()) {
- case ICmpInst::ICMP_SLT:
- case ICmpInst::ICMP_SLE:
- std::swap(LHS, RHS);
- // fall through
- case ICmpInst::ICMP_SGT:
- case ICmpInst::ICMP_SGE:
- if (LHS == I->getOperand(1) && RHS == I->getOperand(2))
- return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
- else if (LHS == I->getOperand(2) && RHS == I->getOperand(1))
- // -smax(-x, -y) == smin(x, y).
- return SE.getNegativeSCEV(SE.getSMaxExpr(
- SE.getNegativeSCEV(getSCEV(LHS)),
- SE.getNegativeSCEV(getSCEV(RHS))));
- break;
- case ICmpInst::ICMP_ULT:
- case ICmpInst::ICMP_ULE:
- std::swap(LHS, RHS);
- // fall through
- case ICmpInst::ICMP_UGT:
- case ICmpInst::ICMP_UGE:
- if (LHS == I->getOperand(1) && RHS == I->getOperand(2))
- return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
- else if (LHS == I->getOperand(2) && RHS == I->getOperand(1))
- // ~umax(~x, ~y) == umin(x, y)
- return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)),
- SE.getNotSCEV(getSCEV(RHS))));
- break;
- default:
- break;
- }
- }
+ unsigned Opcode = Instruction::UserOp1;
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ Opcode = I->getOpcode();
+ else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
+ Opcode = CE->getOpcode();
+ else
+ return SE.getUnknown(V);
- default: // We cannot analyze this expression.
- break;
+ User *U = cast<User>(V);
+ switch (Opcode) {
+ case Instruction::Add:
+ return SE.getAddExpr(getSCEV(U->getOperand(0)),
+ getSCEV(U->getOperand(1)));
+ case Instruction::Mul:
+ return SE.getMulExpr(getSCEV(U->getOperand(0)),
+ getSCEV(U->getOperand(1)));
+ case Instruction::UDiv:
+ return SE.getUDivExpr(getSCEV(U->getOperand(0)),
+ getSCEV(U->getOperand(1)));
+ case Instruction::Sub:
+ return SE.getMinusSCEV(getSCEV(U->getOperand(0)),
+ getSCEV(U->getOperand(1)));
+ case Instruction::Or:
+ // If the RHS of the Or is a constant, we may have something like:
+ // X*4+1 which got turned into X*4|1. Handle this as an Add so loop
+ // optimizations will transparently handle this case.
+ //
+ // In order for this transformation to be safe, the LHS must be of the
+ // form X*(2^n) and the Or constant must be less than 2^n.
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
+ SCEVHandle LHS = getSCEV(U->getOperand(0));
+ const APInt &CIVal = CI->getValue();
+ if (GetMinTrailingZeros(LHS) >=
+ (CIVal.getBitWidth() - CIVal.countLeadingZeros()))
+ return SE.getAddExpr(LHS, getSCEV(U->getOperand(1)));
}
+ break;
+ case Instruction::Xor:
+ // If the RHS of the xor is a signbit, then this is just an add.
+ // Instcombine turns add of signbit into xor as a strength reduction step.
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
+ if (CI->getValue().isSignBit())
+ return SE.getAddExpr(getSCEV(U->getOperand(0)),
+ getSCEV(U->getOperand(1)));
+ else if (CI->isAllOnesValue())
+ return SE.getNotSCEV(getSCEV(U->getOperand(0)));
+ }
+ break;
+
+ case Instruction::Shl:
+ // Turn shift left of a constant amount into a multiply.
+ if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
+ uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
+ Constant *X = ConstantInt::get(
+ APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
+ return SE.getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
+ }
+ break;
+
+ case Instruction::Trunc:
+ return SE.getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
+
+ case Instruction::ZExt:
+ return SE.getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
+
+ case Instruction::SExt:
+ return SE.getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
+
+ case Instruction::BitCast:
+ // BitCasts are no-op casts so we just eliminate the cast.
+ if (U->getType()->isInteger() &&
+ U->getOperand(0)->getType()->isInteger())
+ return getSCEV(U->getOperand(0));
+ break;
+
+ case Instruction::PHI:
+ return createNodeForPHI(cast<PHINode>(U));
+
+ case Instruction::Select:
+ // This could be a smax or umax that was lowered earlier.
+ // Try to recover it.
+ if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
+ Value *LHS = ICI->getOperand(0);
+ Value *RHS = ICI->getOperand(1);
+ switch (ICI->getPredicate()) {
+ case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_SLE:
+ std::swap(LHS, RHS);
+ // fall through
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_SGE:
+ if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
+ return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
+ else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
+ // -smax(-x, -y) == smin(x, y).
+ return SE.getNegativeSCEV(SE.getSMaxExpr(
+ SE.getNegativeSCEV(getSCEV(LHS)),
+ SE.getNegativeSCEV(getSCEV(RHS))));
+ break;
+ case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_ULE:
+ std::swap(LHS, RHS);
+ // fall through
+ case ICmpInst::ICMP_UGT:
+ case ICmpInst::ICMP_UGE:
+ if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
+ return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
+ else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
+ // ~umax(~x, ~y) == umin(x, y)
+ return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)),
+ SE.getNotSCEV(getSCEV(RHS))));
+ break;
+ default:
+ break;
+ }
+ }
+
+ default: // We cannot analyze this expression.
+ break;
}
return SE.getUnknown(V);