diff options
Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | lib/Analysis/ScalarEvolution.cpp | 149 |
1 files changed, 100 insertions, 49 deletions
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 10f05bc8dd..cbfc56373a 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -328,21 +328,21 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, } -// SCEVSDivs - Only allow the creation of one SCEVSDivExpr for any particular +// SCEVUDivs - Only allow the creation of one SCEVUDivExpr for any particular // input. Don't use a SCEVHandle here, or else the object will never be // deleted! static ManagedStatic<std::map<std::pair<SCEV*, SCEV*>, - SCEVSDivExpr*> > SCEVSDivs; + SCEVUDivExpr*> > SCEVUDivs; -SCEVSDivExpr::~SCEVSDivExpr() { - SCEVSDivs->erase(std::make_pair(LHS, RHS)); +SCEVUDivExpr::~SCEVUDivExpr() { + SCEVUDivs->erase(std::make_pair(LHS, RHS)); } -void SCEVSDivExpr::print(std::ostream &OS) const { - OS << "(" << *LHS << " /s " << *RHS << ")"; +void SCEVUDivExpr::print(std::ostream &OS) const { + OS << "(" << *LHS << " /u " << *RHS << ")"; } -const Type *SCEVSDivExpr::getType() const { +const Type *SCEVUDivExpr::getType() const { return LHS->getType(); } @@ -532,57 +532,110 @@ SCEVHandle ScalarEvolution::getMinusSCEV(const SCEVHandle &LHS, } -/// PartialFact - Compute V!/(V-NumSteps)! -static SCEVHandle PartialFact(SCEVHandle V, unsigned NumSteps, - ScalarEvolution &SE) { +/// BinomialCoefficient - Compute BC(It, K). The result is of the same type as +/// It. Assume, K > 0. +static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K, + ScalarEvolution &SE) { + // We are using the following formula for BC(It, K): + // + // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K! + // + // Suppose, W is the bitwidth of It (and of the return value as well). We + // must be prepared for overflow. Hence, we must assure that the result of + // our computation is equal to the accurate one modulo 2^W. Unfortunately, + // division isn't safe in modular arithmetic. 
This means we must perform the + // whole computation accurately and then truncate the result to W bits. + // + // The dividend of the formula is a multiplication of K integers of bitwidth + // W. K*W bits suffice to compute it accurately. + // + // FIXME: We assume the divisor can be accurately computed using a 16-bit + // unsigned integer type. It is true up to K = 8 (AddRecs of length 9). In + // future we may use APInt to use the minimum number of bits necessary to + // compute it accurately. + // + // It is safe to use unsigned division here: the dividend is nonnegative and + // the divisor is positive. + + // Handle the simplest case efficiently. + if (K == 1) + return It; + + assert(K < 9 && "We cannot handle such long AddRecs yet."); + + // FIXME: A temporary hack to remove in future. Arbitrary precision integers + // aren't supported by the code generator yet. For the dividend, the bitwidth + // we use is the smallest power of 2 greater than or equal to K*W, capped at + // 64. Note that setting the upper bound for bitwidth may still lead to + // miscompilation in some cases. + unsigned DividendBits = 1U << Log2_32_Ceil(K * It->getBitWidth()); + if (DividendBits > 64) + DividendBits = 64; +#if 0 // Waiting for the APInt support in the code generator... + unsigned DividendBits = K * It->getBitWidth(); +#endif + + const IntegerType *DividendTy = IntegerType::get(DividendBits); + const SCEVHandle ExIt = SE.getZeroExtendExpr(It, DividendTy); + + // The final number of bits we need to perform the division is the maximum of + // dividend and divisor bitwidths. + const IntegerType *DivisionTy = + IntegerType::get(std::max(DividendBits, 16U)); + + // Compute K! We know K >= 2 here. + unsigned F = 2; + for (unsigned i = 3; i <= K; ++i) + F *= i; + APInt Divisor(DivisionTy->getBitWidth(), F); + // Handle this case efficiently, it is common to have constant iteration // counts while computing loop exit values. 
- if (SCEVConstant *SC = dyn_cast<SCEVConstant>(V)) { - const APInt& Val = SC->getValue()->getValue(); - APInt Result(Val.getBitWidth(), 1); - for (; NumSteps; --NumSteps) - Result *= Val-(NumSteps-1); - return SE.getConstant(Result); + if (SCEVConstant *SC = dyn_cast<SCEVConstant>(ExIt)) { + const APInt& N = SC->getValue()->getValue(); + APInt Dividend(N.getBitWidth(), 1); + for (; K; --K) + Dividend *= N-(K-1); + if (DividendTy != DivisionTy) + Dividend = Dividend.zext(DivisionTy->getBitWidth()); + return SE.getConstant(Dividend.udiv(Divisor).trunc(It->getBitWidth())); } - - const Type *Ty = V->getType(); - if (NumSteps == 0) - return SE.getIntegerSCEV(1, Ty); - - SCEVHandle Result = V; - for (unsigned i = 1; i != NumSteps; ++i) - Result = SE.getMulExpr(Result, SE.getMinusSCEV(V, - SE.getIntegerSCEV(i, Ty))); - return Result; + + SCEVHandle Dividend = ExIt; + for (unsigned i = 1; i != K; ++i) + Dividend = + SE.getMulExpr(Dividend, + SE.getMinusSCEV(ExIt, SE.getIntegerSCEV(i, DividendTy))); + if (DividendTy != DivisionTy) + Dividend = SE.getZeroExtendExpr(Dividend, DivisionTy); + return + SE.getTruncateExpr(SE.getUDivExpr(Dividend, SE.getConstant(Divisor)), + It->getType()); } - /// evaluateAtIteration - Return the value of this chain of recurrences at /// the specified iteration number. We can evaluate this recurrence by /// multiplying each element in the chain by the binomial coefficient /// corresponding to it. In other words, we can evaluate {A,+,B,+,C,+,D} as: /// -/// A*choose(It, 0) + B*choose(It, 1) + C*choose(It, 2) + D*choose(It, 3) +/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3) /// -/// FIXME/VERIFY: I don't trust that this is correct in the face of overflow. -/// Is the binomial equation safe using modular arithmetic?? +/// where BC(It, k) stands for binomial coefficient. 
/// SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It, ScalarEvolution &SE) const { SCEVHandle Result = getStart(); - int Divisor = 1; - const Type *Ty = It->getType(); for (unsigned i = 1, e = getNumOperands(); i != e; ++i) { - SCEVHandle BC = PartialFact(It, i, SE); - Divisor *= i; - SCEVHandle Val = SE.getSDivExpr(SE.getMulExpr(BC, getOperand(i)), - SE.getIntegerSCEV(Divisor,Ty)); + // The computation is correct in the face of overflow provided that the + // multiplication is performed _after_ the evaluation of the binomial + // coefficient. + SCEVHandle Val = SE.getMulExpr(getOperand(i), + BinomialCoefficient(It, i, SE)); Result = SE.getAddExpr(Result, Val); } return Result; } - //===----------------------------------------------------------------------===// // SCEV Expression folder implementations //===----------------------------------------------------------------------===// @@ -1039,24 +1092,22 @@ SCEVHandle ScalarEvolution::getMulExpr(std::vector<SCEVHandle> &Ops) { return Result; } -SCEVHandle ScalarEvolution::getSDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS) { +SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS) { if (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) { if (RHSC->getValue()->equalsInt(1)) - return LHS; // X sdiv 1 --> x - if (RHSC->getValue()->isAllOnesValue()) - return getNegativeSCEV(LHS); // X sdiv -1 --> -x + return LHS; // X udiv 1 --> x if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) { Constant *LHSCV = LHSC->getValue(); Constant *RHSCV = RHSC->getValue(); - return getUnknown(ConstantExpr::getSDiv(LHSCV, RHSCV)); + return getUnknown(ConstantExpr::getUDiv(LHSCV, RHSCV)); } } // FIXME: implement folding of (X*4)/4 when we know X*4 doesn't overflow. 
- SCEVSDivExpr *&Result = (*SCEVSDivs)[std::make_pair(LHS, RHS)]; - if (Result == 0) Result = new SCEVSDivExpr(LHS, RHS); + SCEVUDivExpr *&Result = (*SCEVUDivs)[std::make_pair(LHS, RHS)]; + if (Result == 0) Result = new SCEVUDivExpr(LHS, RHS); return Result; } @@ -1555,7 +1606,7 @@ static uint32_t GetMinTrailingZeros(SCEVHandle S) { return MinOpRes; } - // SCEVSDivExpr, SCEVUnknown + // SCEVUDivExpr, SCEVUnknown return 0; } @@ -1574,8 +1625,8 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { case Instruction::Mul: return SE.getMulExpr(getSCEV(I->getOperand(0)), getSCEV(I->getOperand(1))); - case Instruction::SDiv: - return SE.getSDivExpr(getSCEV(I->getOperand(0)), + case Instruction::UDiv: + return SE.getUDivExpr(getSCEV(I->getOperand(0)), getSCEV(I->getOperand(1))); case Instruction::Sub: return SE.getMinusSCEV(getSCEV(I->getOperand(0)), @@ -2264,14 +2315,14 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { return Comm; } - if (SCEVSDivExpr *Div = dyn_cast<SCEVSDivExpr>(V)) { + if (SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) { SCEVHandle LHS = getSCEVAtScope(Div->getLHS(), L); if (LHS == UnknownValue) return LHS; SCEVHandle RHS = getSCEVAtScope(Div->getRHS(), L); if (RHS == UnknownValue) return RHS; if (LHS == Div->getLHS() && RHS == Div->getRHS()) return Div; // must be loop invariant - return SE.getSDivExpr(LHS, RHS); + return SE.getUDivExpr(LHS, RHS); } // If this is a loop recurrence for a loop that does not contain L, then we |