diff options
Diffstat (limited to 'lib/Transforms/Scalar/LoopStrengthReduce.cpp')
-rw-r--r-- | lib/Transforms/Scalar/LoopStrengthReduce.cpp | 206 |
1 files changed, 153 insertions, 53 deletions
diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 15b1ee04b0..e20f4be998 100644 --- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -382,7 +382,8 @@ namespace { // Once we rewrite the code to insert the new IVs we want, update the // operands of Inst to use the new expression 'NewBase', with 'Imm' added // to it. - void RewriteInstructionToUseNewBase(Value *NewBase, SCEVExpander &Rewriter); + void RewriteInstructionToUseNewBase(const SCEVHandle &NewBase, + SCEVExpander &Rewriter); // Sort by the Base field. bool operator<(const BasedUser &BU) const { return Base < BU.Base; } @@ -403,10 +404,10 @@ void BasedUser::dump() const { // Once we rewrite the code to insert the new IVs we want, update the // operands of Inst to use the new expression 'NewBase', with 'Imm' added // to it. -void BasedUser::RewriteInstructionToUseNewBase(Value *NewBase, +void BasedUser::RewriteInstructionToUseNewBase(const SCEVHandle &NewBase, SCEVExpander &Rewriter) { if (!isa<PHINode>(Inst)) { - SCEVHandle NewValSCEV = SCEVAddExpr::get(SCEVUnknown::get(NewBase), Imm); + SCEVHandle NewValSCEV = SCEVAddExpr::get(NewBase, Imm); Value *NewVal = Rewriter.expandCodeFor(NewValSCEV, Inst, OperandValToReplace->getType()); @@ -426,7 +427,7 @@ void BasedUser::RewriteInstructionToUseNewBase(Value *NewBase, // Insert the code into the end of the predecessor block. BasicBlock::iterator InsertPt = PN->getIncomingBlock(i)->getTerminator(); - SCEVHandle NewValSCEV = SCEVAddExpr::get(SCEVUnknown::get(NewBase), Imm); + SCEVHandle NewValSCEV = SCEVAddExpr::get(NewBase, Imm); Value *NewVal = Rewriter.expandCodeFor(NewValSCEV, InsertPt, OperandValToReplace->getType()); @@ -552,6 +553,73 @@ static void MoveImmediateValues(SCEVHandle &Val, SCEVHandle &Imm, // Otherwise, no immediates to move. 
} +/// RemoveCommonExpressionsFromUseBases - Look through all of the uses in Bases, +/// removing any common subexpressions from it. Anything truly common is +/// removed, accumulated, and returned. This looks for things like (a+b+c) and +/// (a+c+d) -> (a+c). The common expression is *removed* from the Bases. +static SCEVHandle +RemoveCommonExpressionsFromUseBases(std::vector<BasedUser> &Uses) { + unsigned NumUses = Uses.size(); + + // Only one use? Use its base, regardless of what it is! + SCEVHandle Zero = SCEVUnknown::getIntegerSCEV(0, Uses[0].Base->getType()); + SCEVHandle Result = Zero; + if (NumUses == 1) { + std::swap(Result, Uses[0].Base); + return Result; + } + + // To find common subexpressions, count how many of Uses use each expression. + // If any subexpressions are used Uses.size() times, they are common. + std::map<SCEVHandle, unsigned> SubExpressionUseCounts; + + for (unsigned i = 0; i != NumUses; ++i) + if (SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(Uses[i].Base)) { + for (unsigned j = 0, e = AE->getNumOperands(); j != e; ++j) + SubExpressionUseCounts[AE->getOperand(j)]++; + } else { + // If the base is zero (which is common), return zero now, there are no + // CSEs we can find. + if (Uses[i].Base == Zero) return Result; + SubExpressionUseCounts[Uses[i].Base]++; + } + + // Now that we know how many times each is used, build Result. + for (std::map<SCEVHandle, unsigned>::iterator I = + SubExpressionUseCounts.begin(), E = SubExpressionUseCounts.end(); + I != E; ) + if (I->second == NumUses) { // Found CSE! + Result = SCEVAddExpr::get(Result, I->first); + ++I; + } else { + // Remove non-cse's from SubExpressionUseCounts. + SubExpressionUseCounts.erase(I++); + } + + // If we found no CSE's, return now. + if (Result == Zero) return Result; + + // Otherwise, remove all of the CSE's we found from each of the base values. 
+ for (unsigned i = 0; i != NumUses; ++i) + if (SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(Uses[i].Base)) { + std::vector<SCEVHandle> NewOps; + + // Remove all of the values that are now in SubExpressionUseCounts. + for (unsigned j = 0, e = AE->getNumOperands(); j != e; ++j) + if (!SubExpressionUseCounts.count(AE->getOperand(j))) + NewOps.push_back(AE->getOperand(j)); + Uses[i].Base = SCEVAddExpr::get(NewOps); + } else { + // A non-add base must be exactly the common expression (asserted below), + // so removing it leaves zero. + assert(Uses[i].Base == Result); + Uses[i].Base = Zero; + } + + return Result; +} + + /// StrengthReduceStridedIVUsers - Strength reduce all of the users of a single /// stride of IV. All of the users may have different starting values, and this /// may not be the only stride (we know it is if isOnlyStride is true). @@ -578,25 +646,19 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(Value *Stride, "Base value is not loop invariant!"); } - SCEVExpander Rewriter(*SE, *LI); - SCEVExpander PreheaderRewriter(*SE, *LI); - - BasicBlock *Preheader = L->getLoopPreheader(); - Instruction *PreInsertPt = Preheader->getTerminator(); - Instruction *PhiInsertBefore = L->getHeader()->begin(); - - assert(isa<PHINode>(PhiInsertBefore) && - "How could this loop have IV's without any phis?"); - PHINode *SomeLoopPHI = cast<PHINode>(PhiInsertBefore); - assert(SomeLoopPHI->getNumIncomingValues() == 2 && - "This loop isn't canonicalized right"); - BasicBlock *LatchBlock = - SomeLoopPHI->getIncomingBlock(SomeLoopPHI->getIncomingBlock(0) == Preheader); - - + // We now have a whole bunch of uses of like-strided induction variables, but + // they might all have different bases. We want to emit one PHI node for this + // stride which we fold as many common expressions (between the IVs) into as + // possible. Start by identifying the common expressions in the base values + // for the strides (e.g. 
if we have "A+C+B" and "A+B+D" as our bases, find + // "A+B"), emit it to the preheader, then remove the expression from the + // UsersToProcess base values. + SCEVHandle CommonExprs = RemoveCommonExpressionsFromUseBases(UsersToProcess); + // Next, figure out what we can represent in the immediate fields of // instructions. If we can represent anything there, move it to the imm - // fields of the BasedUsers. + // fields of the BasedUsers. We do this so that it increases the commonality + // of the remaining uses. for (unsigned i = 0, e = UsersToProcess.size(); i != e; ++i) { // Addressing modes can be folded into loads and stores. Be careful that // the store is through the expression, not of the expression though. @@ -609,59 +671,95 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(Value *Stride, isAddress, L); } + // Now that we know what we need to do, insert the PHI node itself. + // + DEBUG(std::cerr << "INSERTING IV of STRIDE " << *Stride << " and BASE " + << *CommonExprs << " :\n"); + + SCEVExpander Rewriter(*SE, *LI); + SCEVExpander PreheaderRewriter(*SE, *LI); + + BasicBlock *Preheader = L->getLoopPreheader(); + Instruction *PreInsertPt = Preheader->getTerminator(); + Instruction *PhiInsertBefore = L->getHeader()->begin(); + + assert(isa<PHINode>(PhiInsertBefore) && + "How could this loop have IV's without any phis?"); + PHINode *SomeLoopPHI = cast<PHINode>(PhiInsertBefore); + assert(SomeLoopPHI->getNumIncomingValues() == 2 && + "This loop isn't canonicalized right"); + BasicBlock *LatchBlock = + SomeLoopPHI->getIncomingBlock(SomeLoopPHI->getIncomingBlock(0) == Preheader); + // Create a new Phi for this base, and stick it in the loop header. + const Type *ReplacedTy = CommonExprs->getType(); + PHINode *NewPHI = new PHINode(ReplacedTy, "iv.", PhiInsertBefore); + ++NumInserted; - DEBUG(std::cerr << "INSERTING IVs of STRIDE " << *Stride << ":\n"); + // Emit the initial base value into the loop preheader, and add it to the + // Phi node. 
+ Value *PHIBaseV = PreheaderRewriter.expandCodeFor(CommonExprs, PreInsertPt, + ReplacedTy); + NewPHI->addIncoming(PHIBaseV, Preheader); + // Emit the increment of the base value before the terminator of the loop + // latch block, and add it to the Phi node. + SCEVHandle IncExp = SCEVAddExpr::get(SCEVUnknown::get(NewPHI), + SCEVUnknown::get(Stride)); + + Value *IncV = Rewriter.expandCodeFor(IncExp, LatchBlock->getTerminator(), + ReplacedTy); + IncV->setName(NewPHI->getName()+".inc"); + NewPHI->addIncoming(IncV, LatchBlock); + // Sort by the base value, so that all IVs with identical bases are next to - // each other. + // each other. std::sort(UsersToProcess.begin(), UsersToProcess.end()); while (!UsersToProcess.empty()) { SCEVHandle Base = UsersToProcess.front().Base; - DEBUG(std::cerr << " INSERTING PHI with BASE = " << *Base << ":\n"); + DEBUG(std::cerr << " INSERTING code for BASE = " << *Base << ":\n"); - // Create a new Phi for this base, and stick it in the loop header. - const Type *ReplacedTy = Base->getType(); - PHINode *NewPHI = new PHINode(ReplacedTy, "iv.", PhiInsertBefore); - ++NumInserted; - - // Emit the initial base value into the loop preheader, and add it to the - // Phi node. + // Emit the code for Base into the preheader. Value *BaseV = PreheaderRewriter.expandCodeFor(Base, PreInsertPt, ReplacedTy); - NewPHI->addIncoming(BaseV, Preheader); - - // Emit the increment of the base value before the terminator of the loop - // latch block, and add it to the Phi node. - SCEVHandle Inc = SCEVAddExpr::get(SCEVUnknown::get(NewPHI), - SCEVUnknown::get(Stride)); - - Value *IncV = Rewriter.expandCodeFor(Inc, LatchBlock->getTerminator(), - ReplacedTy); - IncV->setName(NewPHI->getName()+".inc"); - NewPHI->addIncoming(IncV, LatchBlock); - + + // If BaseV is a constant other than 0, make sure that it gets inserted into + // the preheader, instead of being forward substituted into the uses. 
We do + // this by forcing a noop cast to be inserted into the preheader in this + // case. + if (Constant *C = dyn_cast<Constant>(BaseV)) + if (!C->isNullValue()) { + // We want this constant emitted into the preheader! + BaseV = new CastInst(BaseV, BaseV->getType(), "preheaderinsert", + PreInsertPt); + } + // Emit the code to add the immediate offset to the Phi value, just before // the instructions that we identified as using this stride and base. while (!UsersToProcess.empty() && UsersToProcess.front().Base == Base) { BasedUser &User = UsersToProcess.front(); - // Clear the SCEVExpander's expression map so that we are guaranteed - // to have the code emitted where we expect it. - Rewriter.clear(); - - // Now that we know what we need to do, insert code before User for the - // immediate and any loop-variant expressions. - Value *NewBase = NewPHI; - // If this instruction wants to use the post-incremented value, move it // after the post-inc and use its value instead of the PHI. + Value *RewriteOp = NewPHI; if (User.isUseOfPostIncrementedValue) { - NewBase = IncV; + RewriteOp = IncV; User.Inst->moveBefore(LatchBlock->getTerminator()); } - User.RewriteInstructionToUseNewBase(NewBase, Rewriter); + SCEVHandle RewriteExpr = SCEVUnknown::get(RewriteOp); + + // Clear the SCEVExpander's expression map so that we are guaranteed + // to have the code emitted where we expect it. + Rewriter.clear(); + + // Now that we know what we need to do, insert code before User for the + // immediate and any loop-variant expressions. + if (!isa<ConstantInt>(BaseV) || !cast<ConstantInt>(BaseV)->isNullValue()) + // Add BaseV to the PHI value if needed. + RewriteExpr = SCEVAddExpr::get(RewriteExpr, SCEVUnknown::get(BaseV)); + + User.RewriteInstructionToUseNewBase(RewriteExpr, Rewriter); // Mark old value we replaced as possibly dead, so that it is eliminated // if we just replaced the last use of that value. 
@@ -782,6 +880,8 @@ void LoopStrengthReduce::runOnLoop(Loop *L) { // If we only have one stride, we can more aggressively eliminate some things. bool HasOneStride = IVUsesByStride.size() == 1; + // Note: this processes each stride/type pair individually. All users passed + // into StrengthReduceStridedIVUsers have the same type AND stride. for (std::map<Value*, IVUsersOfOneStride>::iterator SI = IVUsesByStride.begin(), E = IVUsesByStride.end(); SI != E; ++SI) StrengthReduceStridedIVUsers(SI->first, SI->second, L, HasOneStride); |