diff options
author | Chris Lattner <sabre@nondot.org> | 2006-03-04 09:31:13 +0000 |
---|---|---|
committer | Chris Lattner <sabre@nondot.org> | 2006-03-04 09:31:13 +0000 |
commit | e5022fe4cd83eef91f5c3a21c943ca9b65507ab8 (patch) | |
tree | e4473c41da8e12b580dc1e60049df94e4b2c2687 /lib/Transforms/Scalar/Reassociate.cpp | |
parent | ad01993194af59c68f8507528a09fee45cde8f24 (diff) |
Add factoring of multiplications, e.g. turning A*A+A*B into A*(A+B).
Testcase here: Transforms/Reassociate/mulfactor.ll
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@26524 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/Transforms/Scalar/Reassociate.cpp')
-rw-r--r-- | lib/Transforms/Scalar/Reassociate.cpp | 235 |
1 file changed, 186 insertions, 49 deletions
diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp index 41faae7496..61c5c4953c 100644 --- a/lib/Transforms/Scalar/Reassociate.cpp +++ b/lib/Transforms/Scalar/Reassociate.cpp @@ -41,6 +41,7 @@ namespace { Statistic<> NumChanged("reassociate","Number of insts reassociated"); Statistic<> NumSwapped("reassociate","Number of insts with operands swapped"); Statistic<> NumAnnihil("reassociate","Number of expr tree annihilated"); + Statistic<> NumFactor ("reassociate","Number of multiplies factored"); struct ValueEntry { unsigned Rank; @@ -50,7 +51,20 @@ namespace { inline bool operator<(const ValueEntry &LHS, const ValueEntry &RHS) { return LHS.Rank > RHS.Rank; // Sort so that highest rank goes to start. } +} +/// PrintOps - Print out the expression identified in the Ops list. +/// +static void PrintOps(Instruction *I, const std::vector<ValueEntry> &Ops) { + Module *M = I->getParent()->getParent()->getParent(); + std::cerr << Instruction::getOpcodeName(I->getOpcode()) << " " + << *Ops[0].Op->getType(); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + WriteAsOperand(std::cerr << " ", Ops[i].Op, false, true, M) + << "," << Ops[i].Rank; +} + +namespace { class Reassociate : public FunctionPass { std::map<BasicBlock*, unsigned> RankMap; std::map<Value*, unsigned> ValueRankMap; @@ -66,10 +80,13 @@ namespace { unsigned getRank(Value *V); void RewriteExprTree(BinaryOperator *I, unsigned Idx, std::vector<ValueEntry> &Ops); - void OptimizeExpression(unsigned Opcode, std::vector<ValueEntry> &Ops); + Value *OptimizeExpression(BinaryOperator *I, std::vector<ValueEntry> &Ops); void LinearizeExprTree(BinaryOperator *I, std::vector<ValueEntry> &Ops); void LinearizeExpr(BinaryOperator *I); + Value *RemoveFactorFromExpression(Value *V, Value *Factor); void ReassociateBB(BasicBlock *BB); + + void RemoveDeadBinaryOp(Value *V); }; RegisterOpt<Reassociate> X("reassociate", "Reassociate expressions"); @@ -78,6 +95,15 @@ namespace { // Public 
interface to the Reassociate pass FunctionPass *llvm::createReassociatePass() { return new Reassociate(); } +void Reassociate::RemoveDeadBinaryOp(Value *V) { + BinaryOperator *BOp = dyn_cast<BinaryOperator>(V); + if (!BOp || !BOp->use_empty()) return; + + Value *LHS = BOp->getOperand(0), *RHS = BOp->getOperand(1); + RemoveDeadBinaryOp(LHS); + RemoveDeadBinaryOp(RHS); +} + static bool isUnmovableInstruction(Instruction *I) { if (I->getOpcode() == Instruction::PHI || @@ -207,9 +233,6 @@ void Reassociate::LinearizeExpr(BinaryOperator *I) { /// form of the the expression (((a+b)+c)+d), and collects information about the /// rank of the non-tree operands. /// -/// This returns the rank of the RHS operand, which is known to be the highest -/// rank value in the expression tree. -/// void Reassociate::LinearizeExprTree(BinaryOperator *I, std::vector<ValueEntry> &Ops) { Value *LHS = I->getOperand(0), *RHS = I->getOperand(1); @@ -279,12 +302,17 @@ void Reassociate::RewriteExprTree(BinaryOperator *I, unsigned i, if (i+2 == Ops.size()) { if (I->getOperand(0) != Ops[i].Op || I->getOperand(1) != Ops[i+1].Op) { + Value *OldLHS = I->getOperand(0); DEBUG(std::cerr << "RA: " << *I); I->setOperand(0, Ops[i].Op); I->setOperand(1, Ops[i+1].Op); DEBUG(std::cerr << "TO: " << *I); MadeChange = true; ++NumChanged; + + // If we reassociated a tree to fewer operands (e.g. (1+a+2) -> (a+3) + // delete the extra, now dead, nodes. + RemoveDeadBinaryOp(OldLHS); } return; } @@ -297,7 +325,15 @@ void Reassociate::RewriteExprTree(BinaryOperator *I, unsigned i, MadeChange = true; ++NumChanged; } - RewriteExprTree(cast<BinaryOperator>(I->getOperand(0)), i+1, Ops); + + BinaryOperator *LHS = cast<BinaryOperator>(I->getOperand(0)); + assert(LHS->getOpcode() == I->getOpcode() && + "Improper expression tree!"); + + // Compactify the tree instructions together with each other to guarantee + // that the expression tree is dominated by all of Ops. 
+ LHS->moveBefore(I); + RewriteExprTree(LHS, i+1, Ops); } @@ -405,19 +441,57 @@ static unsigned FindInOperandList(std::vector<ValueEntry> &Ops, unsigned i, return i; } -void Reassociate::OptimizeExpression(unsigned Opcode, - std::vector<ValueEntry> &Ops) { +/// EmitAddTreeOfValues - Emit a tree of add instructions, summing Ops together +/// and returning the result. Insert the tree before I. +static Value *EmitAddTreeOfValues(Instruction *I, std::vector<Value*> &Ops) { + if (Ops.size() == 1) return Ops.back(); + + Value *V1 = Ops.back(); + Ops.pop_back(); + Value *V2 = EmitAddTreeOfValues(I, Ops); + return BinaryOperator::createAdd(V2, V1, "tmp", I); +} + +/// RemoveFactorFromExpression - If V is an expression tree that is a +/// multiplication sequence, and if this sequence contains a multiply by Factor, +/// remove Factor from the tree and return the new tree. +Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) { + BinaryOperator *BO = isReassociableOp(V, Instruction::Mul); + if (!BO) return 0; + + std::vector<ValueEntry> Factors; + LinearizeExprTree(BO, Factors); + + bool FoundFactor = false; + for (unsigned i = 0, e = Factors.size(); i != e; ++i) + if (Factors[i].Op == Factor) { + FoundFactor = true; + Factors.erase(Factors.begin()+i); + break; + } + if (!FoundFactor) return 0; + + if (Factors.size() == 1) return Factors[0].Op; + + RewriteExprTree(BO, 0, Factors); + return BO; +} + + +Value *Reassociate::OptimizeExpression(BinaryOperator *I, + std::vector<ValueEntry> &Ops) { // Now that we have the linearized expression tree, try to optimize it. // Start by folding any constants that we found. 
bool IterateOptimization = false; - if (Ops.size() == 1) return; + if (Ops.size() == 1) return Ops[0].Op; + unsigned Opcode = I->getOpcode(); + if (Constant *V1 = dyn_cast<Constant>(Ops[Ops.size()-2].Op)) if (Constant *V2 = dyn_cast<Constant>(Ops.back().Op)) { Ops.pop_back(); Ops.back().Op = ConstantExpr::get(Opcode, V1, V2); - OptimizeExpression(Opcode, Ops); - return; + return OptimizeExpression(I, Ops); } // Check for destructive annihilation due to a constant being used. @@ -426,30 +500,24 @@ void Reassociate::OptimizeExpression(unsigned Opcode, default: break; case Instruction::And: if (CstVal->isNullValue()) { // ... & 0 -> 0 - Ops[0].Op = CstVal; - Ops.erase(Ops.begin()+1, Ops.end()); ++NumAnnihil; - return; + return CstVal; } else if (CstVal->isAllOnesValue()) { // ... & -1 -> ... Ops.pop_back(); } break; case Instruction::Mul: if (CstVal->isNullValue()) { // ... * 0 -> 0 - Ops[0].Op = CstVal; - Ops.erase(Ops.begin()+1, Ops.end()); ++NumAnnihil; - return; + return CstVal; } else if (cast<ConstantInt>(CstVal)->getRawValue() == 1) { Ops.pop_back(); // ... * 1 -> ... } break; case Instruction::Or: if (CstVal->isAllOnesValue()) { // ... | -1 -> -1 - Ops[0].Op = CstVal; - Ops.erase(Ops.begin()+1, Ops.end()); ++NumAnnihil; - return; + return CstVal; } // FALLTHROUGH! case Instruction::Add: @@ -458,7 +526,7 @@ void Reassociate::OptimizeExpression(unsigned Opcode, Ops.pop_back(); break; } - if (Ops.size() == 1) return; + if (Ops.size() == 1) return Ops[0].Op; // Handle destructive annihilation do to identities between elements in the // argument list here. 
@@ -477,15 +545,11 @@ void Reassociate::OptimizeExpression(unsigned Opcode, unsigned FoundX = FindInOperandList(Ops, i, X); if (FoundX != i) { if (Opcode == Instruction::And) { // ...&X&~X = 0 - Ops[0].Op = Constant::getNullValue(X->getType()); - Ops.erase(Ops.begin()+1, Ops.end()); ++NumAnnihil; - return; + return Constant::getNullValue(X->getType()); } else if (Opcode == Instruction::Or) { // ...|X|~X = -1 - Ops[0].Op = ConstantIntegral::getAllOnesValue(X->getType()); - Ops.erase(Ops.begin()+1, Ops.end()); ++NumAnnihil; - return; + return ConstantIntegral::getAllOnesValue(X->getType()); } } } @@ -503,10 +567,8 @@ void Reassociate::OptimizeExpression(unsigned Opcode, } else { assert(Opcode == Instruction::Xor); if (e == 2) { - Ops[0].Op = Constant::getNullValue(Ops[0].Op->getType()); - Ops.erase(Ops.begin()+1, Ops.end()); ++NumAnnihil; - return; + return Constant::getNullValue(Ops[0].Op->getType()); } // ... X^X -> ... Ops.erase(Ops.begin()+i, Ops.begin()+i+2); @@ -520,7 +582,7 @@ void Reassociate::OptimizeExpression(unsigned Opcode, case Instruction::Add: // Scan the operand lists looking for X and -X pairs. If we find any, we - // can simplify the expression. X+-X == 0 + // can simplify the expression. X+-X == 0. for (unsigned i = 0, e = Ops.size(); i != e; ++i) { assert(i < Ops.size()); // Check for X and -X in the operand list. @@ -530,10 +592,8 @@ void Reassociate::OptimizeExpression(unsigned Opcode, if (FoundX != i) { // Remove X and -X from the operand list. if (Ops.size() == 2) { - Ops[0].Op = Constant::getNullValue(X->getType()); - Ops.pop_back(); ++NumAnnihil; - return; + return Constant::getNullValue(X->getType()); } else { Ops.erase(Ops.begin()+i); if (i < FoundX) @@ -549,30 +609,99 @@ void Reassociate::OptimizeExpression(unsigned Opcode, } } } + + + // Scan the operand list, checking to see if there are any common factors + // between operands. Consider something like A*A+A*B*C+D. 
We would like to + // reassociate this to A*(A+B*C)+D, which reduces the number of multiplies. + // To efficiently find this, we count the number of times a factor occurs + // for any ADD operands that are MULs. + std::map<Value*, unsigned> FactorOccurrences; + unsigned MaxOcc = 0; + Value *MaxOccVal = 0; + if (!I->getType()->isFloatingPoint()) { + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + if (BinaryOperator *BOp = dyn_cast<BinaryOperator>(Ops[i].Op)) + if (BOp->getOpcode() == Instruction::Mul && BOp->hasOneUse()) { + // Compute all of the factors of this added value. + std::vector<ValueEntry> Factors; + LinearizeExprTree(BOp, Factors); + assert(Factors.size() > 1 && "Bad linearize!"); + + // Add one to FactorOccurrences for each unique factor in this op. + if (Factors.size() == 2) { + unsigned Occ = ++FactorOccurrences[Factors[0].Op]; + if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[0].Op; } + if (Factors[0].Op != Factors[1].Op) { // Don't double count A*A. + Occ = ++FactorOccurrences[Factors[1].Op]; + if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[1].Op; } + } + } else { + std::set<Value*> Duplicates; + for (unsigned i = 0, e = Factors.size(); i != e; ++i) + if (Duplicates.insert(Factors[i].Op).second) { + unsigned Occ = ++FactorOccurrences[Factors[i].Op]; + if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[i].Op; } + } + } + } + } + } + + // If any factor occurred more than one time, we can pull it out. + if (MaxOcc > 1) { + DEBUG(std::cerr << "\nFACTORING [" << MaxOcc << "]: " + << *MaxOccVal << "\n"); + + // Create a new instruction that uses the MaxOccVal twice. If we don't do + // this, we could otherwise run into situations where removing a factor + // from an expression will drop a use of maxocc, and this can cause + // RemoveFactorFromExpression on successive values to behave differently. 
+ Instruction *DummyInst = BinaryOperator::createAdd(MaxOccVal, MaxOccVal); + std::vector<Value*> NewMulOps; + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + if (Value *V = RemoveFactorFromExpression(Ops[i].Op, MaxOccVal)) { + NewMulOps.push_back(V); + Ops.erase(Ops.begin()+i); + --i; --e; + } + } + + // No need for extra uses anymore. + delete DummyInst; + + Value *V = EmitAddTreeOfValues(I, NewMulOps); + // FIXME: Must optimize V now, to handle this case: + // A*A*B + A*A*C -> A*(A*B+A*C) -> A*(A*(B+C)) + V = BinaryOperator::createMul(V, MaxOccVal, "tmp", I); + + ++NumFactor; + + if (Ops.size() == 0) + return V; + + // Add the new value to the list of things being added. + Ops.insert(Ops.begin(), ValueEntry(getRank(V), V)); + + // Rewrite the tree so that there is now a use of V. + RewriteExprTree(I, 0, Ops); + return OptimizeExpression(I, Ops); + } break; //case Instruction::Mul: } if (IterateOptimization) - OptimizeExpression(Opcode, Ops); + return OptimizeExpression(I, Ops); + return 0; } -/// PrintOps - Print out the expression identified in the Ops list. -/// -static void PrintOps(unsigned Opcode, const std::vector<ValueEntry> &Ops, - BasicBlock *BB) { - Module *M = BB->getParent()->getParent(); - std::cerr << Instruction::getOpcodeName(Opcode) << " " - << *Ops[0].Op->getType(); - for (unsigned i = 0, e = Ops.size(); i != e; ++i) - WriteAsOperand(std::cerr << " ", Ops[i].Op, false, true, M) - << "," << Ops[i].Rank; -} /// ReassociateBB - Inspect all of the instructions in this basic block, /// reassociating them as we go. 
void Reassociate::ReassociateBB(BasicBlock *BB) { - for (BasicBlock::iterator BI = BB->begin(); BI != BB->end(); ++BI) { + for (BasicBlock::iterator BBI = BB->begin(); BBI != BB->end(); ) { + Instruction *BI = BBI++; if (BI->getOpcode() == Instruction::Shl && isa<ConstantInt>(BI->getOperand(1))) if (Instruction *NI = ConvertShiftToMul(BI)) { @@ -623,7 +752,7 @@ void Reassociate::ReassociateBB(BasicBlock *BB) { std::vector<ValueEntry> Ops; LinearizeExprTree(I, Ops); - DEBUG(std::cerr << "RAIn:\t"; PrintOps(I->getOpcode(), Ops, BB); + DEBUG(std::cerr << "RAIn:\t"; PrintOps(I, Ops); std::cerr << "\n"); // Now that we have linearized the tree to a list and have gathered all of @@ -636,7 +765,14 @@ void Reassociate::ReassociateBB(BasicBlock *BB) { // OptimizeExpression - Now that we have the expression tree in a convenient // sorted form, optimize it globally if possible. - OptimizeExpression(I->getOpcode(), Ops); + if (Value *V = OptimizeExpression(I, Ops)) { + // This expression tree simplified to something that isn't a tree, + // eliminate it. + DEBUG(std::cerr << "Reassoc to scalar: " << *V << "\n"); + I->replaceAllUsesWith(V); + RemoveDeadBinaryOp(I); + continue; + } // We want to sink immediates as deeply as possible except in the case where // this is a multiply tree used only by an add, and the immediate is a -1. @@ -650,13 +786,14 @@ void Reassociate::ReassociateBB(BasicBlock *BB) { Ops.pop_back(); } - DEBUG(std::cerr << "RAOut:\t"; PrintOps(I->getOpcode(), Ops, BB); + DEBUG(std::cerr << "RAOut:\t"; PrintOps(I, Ops); std::cerr << "\n"); if (Ops.size() == 1) { // This expression tree simplified to something that isn't a tree, // eliminate it. I->replaceAllUsesWith(Ops[0].Op); + RemoveDeadBinaryOp(I); } else { // Now that we ordered and optimized the expressions, splat them back into // the expression tree, removing any unneeded nodes. |