aboutsummaryrefslogtreecommitdiff
path: root/lib/Analysis/ScalarEvolution.cpp
diff options
context:
space:
mode:
authorNick Lewycky <nicholas@mxc.ca>2007-11-22 07:59:40 +0000
committerNick Lewycky <nicholas@mxc.ca>2007-11-22 07:59:40 +0000
commit83bb0055fdac3c6234c4178cd429e6a917d06c4e (patch)
treef62fae0e6d4cfbb94636c71a97f94cead9160ff4 /lib/Analysis/ScalarEvolution.cpp
parent4ac0e8da4a819b2db09659262227e8c8f7f1fcc0 (diff)
Instead of calculating constant factors, calculate the number of trailing
bits. Patch from Wojciech Matyjewicz. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@44268 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--lib/Analysis/ScalarEvolution.cpp101
1 files changed, 47 insertions, 54 deletions
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp
index fed57f9d91..cc6cde2ba2 100644
--- a/lib/Analysis/ScalarEvolution.cpp
+++ b/lib/Analysis/ScalarEvolution.cpp
@@ -1410,62 +1410,60 @@ SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN) {
return SE.getUnknown(PN);
}
-/// GetConstantFactor - Determine the largest constant factor that S has. For
-/// example, turn {4,+,8} -> 4. (S umod result) should always equal zero.
-static APInt GetConstantFactor(SCEVHandle S) {
- if (SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
- const APInt& V = C->getValue()->getValue();
- if (!V.isMinValue())
- return V;
- else // Zero is a multiple of everything.
- return APInt::getHighBitsSet(C->getBitWidth(), 1);
- }
+/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
+/// guaranteed to end in (at every loop iteration). It is, at the same time,
+/// the minimum number of times S is divisible by 2. For example, given {4,+,8}
+/// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S.
+static uint32_t GetMinTrailingZeros(SCEVHandle S) {
+ if (SCEVConstant *C = dyn_cast<SCEVConstant>(S))
+ // APInt::countTrailingZeros() returns the number of trailing zeros in its
+ // internal representation, which length may be greater than the represented
+ // value bitwidth. This is why we use a min operation here.
+ return std::min(C->getValue()->getValue().countTrailingZeros(),
+ C->getBitWidth());
if (SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
- return GetConstantFactor(T->getOperand()).trunc(
- cast<IntegerType>(T->getType())->getBitWidth());
- if (SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S))
- return GetConstantFactor(E->getOperand()).zext(
- cast<IntegerType>(E->getType())->getBitWidth());
- if (SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S))
- return GetConstantFactor(E->getOperand()).sext(
- cast<IntegerType>(E->getType())->getBitWidth());
-
+ return std::min(GetMinTrailingZeros(T->getOperand()), T->getBitWidth());
+
+ if (SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
+ uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
+ return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes;
+ }
+
+ if (SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
+ uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
+ return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes;
+ }
+
if (SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
- // The result is the min of all operands.
- APInt Res(GetConstantFactor(A->getOperand(0)));
- for (unsigned i = 1, e = A->getNumOperands();
- i != e && Res.ugt(APInt(Res.getBitWidth(),1)); ++i) {
- APInt Tmp(GetConstantFactor(A->getOperand(i)));
- Res = APIntOps::umin(Res, Tmp);
- }
- return Res;
+ // The result is the min of all operands results.
+ uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
+ for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
+ MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
+ return MinOpRes;
}
if (SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
- // The result is the product of all the operands.
- APInt Res(GetConstantFactor(M->getOperand(0)));
- for (unsigned i = 1, e = M->getNumOperands(); i != e; ++i) {
- APInt Tmp(GetConstantFactor(M->getOperand(i)));
- Res *= Tmp;
- }
- return Res;
+ // The result is the sum of all operands results.
+ uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
+ uint32_t BitWidth = M->getBitWidth();
+ for (unsigned i = 1, e = M->getNumOperands();
+ SumOpRes != BitWidth && i != e; ++i)
+ SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
+ BitWidth);
+ return SumOpRes;
}
-
+
if (SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
- // For now, we just handle linear expressions.
- if (A->getNumOperands() == 2) {
- // We want the GCD between the start and the stride value.
- APInt Start(GetConstantFactor(A->getOperand(0)));
- if (Start == 1)
- return Start;
- APInt Stride(GetConstantFactor(A->getOperand(1)));
- return APIntOps::GreatestCommonDivisor(Start, Stride);
- }
+ // The result is the min of all operands results.
+ uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
+ for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
+ MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
+ return MinOpRes;
}
-
- // SCEVSDivExpr, SCEVUnknown.
- return APInt(S->getBitWidth(), 1);
+
+ // SCEVSDivExpr, SCEVUnknown
+ return 0;
}
/// createSCEV - We know that there is no SCEV for the specified value.
@@ -1493,17 +1491,12 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
//
// In order for this transformation to be safe, the LHS must be of the
// form X*(2^n) and the Or constant must be less than 2^n.
-
if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
SCEVHandle LHS = getSCEV(I->getOperand(0));
- APInt CommonFact(GetConstantFactor(LHS));
- assert(!CommonFact.isMinValue() &&
- "Common factor should at least be 1!");
const APInt &CIVal = CI->getValue();
- if (CommonFact.countTrailingZeros() >=
+ if (GetMinTrailingZeros(LHS) >=
(CIVal.getBitWidth() - CIVal.countLeadingZeros()))
- return SE.getAddExpr(LHS,
- getSCEV(I->getOperand(1)));
+ return SE.getAddExpr(LHS, getSCEV(I->getOperand(1)));
}
break;
case Instruction::Xor: