aboutsummaryrefslogtreecommitdiff
path: root/lib/Transforms/Scalar/LoopStrengthReduce.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/Scalar/LoopStrengthReduce.cpp')
-rw-r--r--lib/Transforms/Scalar/LoopStrengthReduce.cpp206
1 files changed, 153 insertions, 53 deletions
diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 15b1ee04b0..e20f4be998 100644
--- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -382,7 +382,8 @@ namespace {
// Once we rewrite the code to insert the new IVs we want, update the
// operands of Inst to use the new expression 'NewBase', with 'Imm' added
// to it.
- void RewriteInstructionToUseNewBase(Value *NewBase, SCEVExpander &Rewriter);
+ void RewriteInstructionToUseNewBase(const SCEVHandle &NewBase,
+ SCEVExpander &Rewriter);
// Sort by the Base field.
bool operator<(const BasedUser &BU) const { return Base < BU.Base; }
@@ -403,10 +404,10 @@ void BasedUser::dump() const {
// Once we rewrite the code to insert the new IVs we want, update the
// operands of Inst to use the new expression 'NewBase', with 'Imm' added
// to it.
-void BasedUser::RewriteInstructionToUseNewBase(Value *NewBase,
+void BasedUser::RewriteInstructionToUseNewBase(const SCEVHandle &NewBase,
SCEVExpander &Rewriter) {
if (!isa<PHINode>(Inst)) {
- SCEVHandle NewValSCEV = SCEVAddExpr::get(SCEVUnknown::get(NewBase), Imm);
+ SCEVHandle NewValSCEV = SCEVAddExpr::get(NewBase, Imm);
Value *NewVal = Rewriter.expandCodeFor(NewValSCEV, Inst,
OperandValToReplace->getType());
@@ -426,7 +427,7 @@ void BasedUser::RewriteInstructionToUseNewBase(Value *NewBase,
// Insert the code into the end of the predecessor block.
BasicBlock::iterator InsertPt = PN->getIncomingBlock(i)->getTerminator();
- SCEVHandle NewValSCEV = SCEVAddExpr::get(SCEVUnknown::get(NewBase), Imm);
+ SCEVHandle NewValSCEV = SCEVAddExpr::get(NewBase, Imm);
Value *NewVal = Rewriter.expandCodeFor(NewValSCEV, InsertPt,
OperandValToReplace->getType());
@@ -552,6 +553,73 @@ static void MoveImmediateValues(SCEVHandle &Val, SCEVHandle &Imm,
// Otherwise, no immediates to move.
}
+/// RemoveCommonExpressionsFromUseBases - Look through all of the uses in Bases,
+/// removing any common subexpressions from it. Anything truly common is
+/// removed, accumulated, and returned. This looks for things like (a+b+c) and
+/// (a+c+d) -> (a+c). The common expression is *removed* from the Bases.
+static SCEVHandle
+RemoveCommonExpressionsFromUseBases(std::vector<BasedUser> &Uses) {
+ unsigned NumUses = Uses.size();
+
+ // Only one use? Use its base, regardless of what it is!
+ SCEVHandle Zero = SCEVUnknown::getIntegerSCEV(0, Uses[0].Base->getType());
+ SCEVHandle Result = Zero;
+ if (NumUses == 1) {
+ std::swap(Result, Uses[0].Base);
+ return Result;
+ }
+
+ // To find common subexpressions, count how many of Uses use each expression.
+ // If any subexpressions are used Uses.size() times, they are common.
+ std::map<SCEVHandle, unsigned> SubExpressionUseCounts;
+
+ for (unsigned i = 0; i != NumUses; ++i)
+ if (SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(Uses[i].Base)) {
+ for (unsigned j = 0, e = AE->getNumOperands(); j != e; ++j)
+ SubExpressionUseCounts[AE->getOperand(j)]++;
+ } else {
+ // If the base is zero (which is common), return zero now, there are no
+ // CSEs we can find.
+ if (Uses[i].Base == Zero) return Result;
+ SubExpressionUseCounts[Uses[i].Base]++;
+ }
+
+ // Now that we know how many times each is used, build Result.
+ for (std::map<SCEVHandle, unsigned>::iterator I =
+ SubExpressionUseCounts.begin(), E = SubExpressionUseCounts.end();
+ I != E; )
+ if (I->second == NumUses) { // Found CSE!
+ Result = SCEVAddExpr::get(Result, I->first);
+ ++I;
+ } else {
+ // Remove non-cse's from SubExpressionUseCounts.
+ SubExpressionUseCounts.erase(I++);
+ }
+
+ // If we found no CSE's, return now.
+ if (Result == Zero) return Result;
+
+ // Otherwise, remove all of the CSE's we found from each of the base values.
+ for (unsigned i = 0; i != NumUses; ++i)
+ if (SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(Uses[i].Base)) {
+ std::vector<SCEVHandle> NewOps;
+
+ // Remove all of the values that are now in SubExpressionUseCounts.
+ for (unsigned j = 0, e = AE->getNumOperands(); j != e; ++j)
+ if (!SubExpressionUseCounts.count(AE->getOperand(j)))
+ NewOps.push_back(AE->getOperand(j));
+ Uses[i].Base = SCEVAddExpr::get(NewOps);
+ } else {
+ // If the base is zero (which is common), return zero now, there are no
+ // CSEs we can find.
+ assert(Uses[i].Base == Result);
+ Uses[i].Base = Zero;
+ }
+
+ return Result;
+}
+
+
/// StrengthReduceStridedIVUsers - Strength reduce all of the users of a single
/// stride of IV. All of the users may have different starting values, and this
/// may not be the only stride (we know it is if isOnlyStride is true).
@@ -578,25 +646,19 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(Value *Stride,
"Base value is not loop invariant!");
}
- SCEVExpander Rewriter(*SE, *LI);
- SCEVExpander PreheaderRewriter(*SE, *LI);
-
- BasicBlock *Preheader = L->getLoopPreheader();
- Instruction *PreInsertPt = Preheader->getTerminator();
- Instruction *PhiInsertBefore = L->getHeader()->begin();
-
- assert(isa<PHINode>(PhiInsertBefore) &&
- "How could this loop have IV's without any phis?");
- PHINode *SomeLoopPHI = cast<PHINode>(PhiInsertBefore);
- assert(SomeLoopPHI->getNumIncomingValues() == 2 &&
- "This loop isn't canonicalized right");
- BasicBlock *LatchBlock =
- SomeLoopPHI->getIncomingBlock(SomeLoopPHI->getIncomingBlock(0) == Preheader);
-
-
+ // We now have a whole bunch of uses of like-strided induction variables, but
+ // they might all have different bases. We want to emit one PHI node for this
+ // stride which we fold as many common expressions (between the IVs) into as
+ // possible. Start by identifying the common expressions in the base values
+ // for the strides (e.g. if we have "A+C+B" and "A+B+D" as our bases, find
+ // "A+B"), emit it to the preheader, then remove the expression from the
+ // UsersToProcess base values.
+ SCEVHandle CommonExprs = RemoveCommonExpressionsFromUseBases(UsersToProcess);
+
// Next, figure out what we can represent in the immediate fields of
// instructions. If we can represent anything there, move it to the imm
- // fields of the BasedUsers.
+ // fields of the BasedUsers. We do this so that it increases the commonality
+ // of the remaining uses.
for (unsigned i = 0, e = UsersToProcess.size(); i != e; ++i) {
// Addressing modes can be folded into loads and stores. Be careful that
// the store is through the expression, not of the expression though.
@@ -609,59 +671,95 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(Value *Stride,
isAddress, L);
}
+ // Now that we know what we need to do, insert the PHI node itself.
+ //
+ DEBUG(std::cerr << "INSERTING IV of STRIDE " << *Stride << " and BASE "
+ << *CommonExprs << " :\n");
+
+ SCEVExpander Rewriter(*SE, *LI);
+ SCEVExpander PreheaderRewriter(*SE, *LI);
+
+ BasicBlock *Preheader = L->getLoopPreheader();
+ Instruction *PreInsertPt = Preheader->getTerminator();
+ Instruction *PhiInsertBefore = L->getHeader()->begin();
+
+ assert(isa<PHINode>(PhiInsertBefore) &&
+ "How could this loop have IV's without any phis?");
+ PHINode *SomeLoopPHI = cast<PHINode>(PhiInsertBefore);
+ assert(SomeLoopPHI->getNumIncomingValues() == 2 &&
+ "This loop isn't canonicalized right");
+ BasicBlock *LatchBlock =
+ SomeLoopPHI->getIncomingBlock(SomeLoopPHI->getIncomingBlock(0) == Preheader);
+ // Create a new Phi for this base, and stick it in the loop header.
+ const Type *ReplacedTy = CommonExprs->getType();
+ PHINode *NewPHI = new PHINode(ReplacedTy, "iv.", PhiInsertBefore);
+ ++NumInserted;
- DEBUG(std::cerr << "INSERTING IVs of STRIDE " << *Stride << ":\n");
+ // Emit the initial base value into the loop preheader, and add it to the
+ // Phi node.
+ Value *PHIBaseV = PreheaderRewriter.expandCodeFor(CommonExprs, PreInsertPt,
+ ReplacedTy);
+ NewPHI->addIncoming(PHIBaseV, Preheader);
+ // Emit the increment of the base value before the terminator of the loop
+ // latch block, and add it to the Phi node.
+ SCEVHandle IncExp = SCEVAddExpr::get(SCEVUnknown::get(NewPHI),
+ SCEVUnknown::get(Stride));
+
+ Value *IncV = Rewriter.expandCodeFor(IncExp, LatchBlock->getTerminator(),
+ ReplacedTy);
+ IncV->setName(NewPHI->getName()+".inc");
+ NewPHI->addIncoming(IncV, LatchBlock);
+
// Sort by the base value, so that all IVs with identical bases are next to
- // each other.
+ // each other.
std::sort(UsersToProcess.begin(), UsersToProcess.end());
while (!UsersToProcess.empty()) {
SCEVHandle Base = UsersToProcess.front().Base;
- DEBUG(std::cerr << " INSERTING PHI with BASE = " << *Base << ":\n");
+ DEBUG(std::cerr << " INSERTING code for BASE = " << *Base << ":\n");
- // Create a new Phi for this base, and stick it in the loop header.
- const Type *ReplacedTy = Base->getType();
- PHINode *NewPHI = new PHINode(ReplacedTy, "iv.", PhiInsertBefore);
- ++NumInserted;
-
- // Emit the initial base value into the loop preheader, and add it to the
- // Phi node.
+ // Emit the code for Base into the preheader.
Value *BaseV = PreheaderRewriter.expandCodeFor(Base, PreInsertPt,
ReplacedTy);
- NewPHI->addIncoming(BaseV, Preheader);
-
- // Emit the increment of the base value before the terminator of the loop
- // latch block, and add it to the Phi node.
- SCEVHandle Inc = SCEVAddExpr::get(SCEVUnknown::get(NewPHI),
- SCEVUnknown::get(Stride));
-
- Value *IncV = Rewriter.expandCodeFor(Inc, LatchBlock->getTerminator(),
- ReplacedTy);
- IncV->setName(NewPHI->getName()+".inc");
- NewPHI->addIncoming(IncV, LatchBlock);
-
+
+ // If BaseV is a constant other than 0, make sure that it gets inserted into
+ // the preheader, instead of being forward substituted into the uses. We do
+ // this by forcing a noop cast to be inserted into the preheader in this
+ // case.
+ if (Constant *C = dyn_cast<Constant>(BaseV))
+ if (!C->isNullValue()) {
+ // We want this constant emitted into the preheader!
+ BaseV = new CastInst(BaseV, BaseV->getType(), "preheaderinsert",
+ PreInsertPt);
+ }
+
// Emit the code to add the immediate offset to the Phi value, just before
// the instructions that we identified as using this stride and base.
while (!UsersToProcess.empty() && UsersToProcess.front().Base == Base) {
BasedUser &User = UsersToProcess.front();
- // Clear the SCEVExpander's expression map so that we are guaranteed
- // to have the code emitted where we expect it.
- Rewriter.clear();
-
- // Now that we know what we need to do, insert code before User for the
- // immediate and any loop-variant expressions.
- Value *NewBase = NewPHI;
-
// If this instruction wants to use the post-incremented value, move it
// after the post-inc and use its value instead of the PHI.
+ Value *RewriteOp = NewPHI;
if (User.isUseOfPostIncrementedValue) {
- NewBase = IncV;
+ RewriteOp = IncV;
User.Inst->moveBefore(LatchBlock->getTerminator());
}
- User.RewriteInstructionToUseNewBase(NewBase, Rewriter);
+ SCEVHandle RewriteExpr = SCEVUnknown::get(RewriteOp);
+
+ // Clear the SCEVExpander's expression map so that we are guaranteed
+ // to have the code emitted where we expect it.
+ Rewriter.clear();
+
+ // Now that we know what we need to do, insert code before User for the
+ // immediate and any loop-variant expressions.
+ if (!isa<ConstantInt>(BaseV) || !cast<ConstantInt>(BaseV)->isNullValue())
+ // Add BaseV to the PHI value if needed.
+ RewriteExpr = SCEVAddExpr::get(RewriteExpr, SCEVUnknown::get(BaseV));
+
+ User.RewriteInstructionToUseNewBase(RewriteExpr, Rewriter);
// Mark old value we replaced as possibly dead, so that it is elminated
// if we just replaced the last use of that value.
@@ -782,6 +880,8 @@ void LoopStrengthReduce::runOnLoop(Loop *L) {
// If we only have one stride, we can more aggressively eliminate some things.
bool HasOneStride = IVUsesByStride.size() == 1;
+ // Note: this processes each stride/type pair individually. All users passed
+ // into StrengthReduceStridedIVUsers have the same type AND stride.
for (std::map<Value*, IVUsersOfOneStride>::iterator SI
= IVUsesByStride.begin(), E = IVUsesByStride.end(); SI != E; ++SI)
StrengthReduceStridedIVUsers(SI->first, SI->second, L, HasOneStride);