From 99650c9088c5dd4b6788a99b63c82d13e0518961 Mon Sep 17 00:00:00 2001 From: Chandler Carruth Date: Fri, 4 May 2012 10:18:49 +0000 Subject: Move the CodeExtractor utility to a dedicated header file / source file, and expose it as a utility class rather than as free function wrappers. The simple free-function interface works well for the bugpoint-specific pass's uses of code extraction, but in an upcoming patch for more advanced code extraction, they simply don't expose a rich enough interface. I need to expose various stages of the process of doing the code extraction and query information to decide whether or not to actually complete the extraction or give up. Rather than build up a new predicate model and pass that into these functions, just take the class that was actually implementing the functions and lift it up into a proper interface that can be used to perform code extraction. The interface is cleaned up and re-documented to work better in a header. It also is now setup to accept the blocks to be extracted in the constructor rather than in a method. In passing this essentially reverts my previous commit here exposing a block-level query for eligibility of extraction. That is no longer necessary with the more rich interface as clients can query the extraction object for eligibility directly. This will reduce the number of walks of the input basic block sequence by quite a bit which is useful if this enters the normal optimization pipeline. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@156163 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/Utils/CodeExtractor.cpp | 268 +++++++++++++-------------------- 1 file changed, 107 insertions(+), 161 deletions(-) (limited to 'lib/Transforms/Utils/CodeExtractor.cpp') diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp index b8cea45178..50eb8a27e0 100644 --- a/lib/Transforms/Utils/CodeExtractor.cpp +++ b/lib/Transforms/Utils/CodeExtractor.cpp @@ -13,7 +13,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Utils/FunctionUtils.h" +#include "llvm/Transforms/Utils/CodeExtractor.h" #include "llvm/Constants.h" #include "llvm/DerivedTypes.h" #include "llvm/Instructions.h" @@ -43,61 +43,78 @@ static cl::opt AggregateArgsOpt("aggregate-extracted-args", cl::Hidden, cl::desc("Aggregate arguments to code-extracted functions")); -namespace { - class CodeExtractor { - typedef SetVector Values; - SetVector BlocksToExtract; - DominatorTree* DT; - bool AggregateArgs; - unsigned NumExitBlocks; - Type *RetTy; - public: - CodeExtractor(DominatorTree* dt = 0, bool AggArgs = false) - : DT(dt), AggregateArgs(AggArgs||AggregateArgsOpt), NumExitBlocks(~0U) {} - - Function *ExtractCodeRegion(ArrayRef code); - - bool isEligible(ArrayRef code); - - private: - /// definedInRegion - Return true if the specified value is defined in the - /// extracted region. - bool definedInRegion(Value *V) const { - if (Instruction *I = dyn_cast(V)) - if (BlocksToExtract.count(I->getParent())) - return true; - return false; - } +/// \brief Test whether a block is valid for extraction. +static bool isBlockValidForExtraction(const BasicBlock &BB) { + // Landing pads must be in the function where they were inserted for cleanup. + if (BB.isLandingPad()) + return false; - /// definedInCaller - Return true if the specified value is defined in the - /// function being code extracted, but not in the region being extracted. - /// These values must be passed in as live-ins to the function. - bool definedInCaller(Value *V) const { - if (isa(V)) return true; - if (Instruction *I = dyn_cast(V)) - if (!BlocksToExtract.count(I->getParent())) - return true; + // Don't hoist code containing allocas, invokes, or vastarts. + for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) { + if (isa(I) || isa(I)) return false; + if (const CallInst *CI = dyn_cast(I)) + if (const Function *F = CI->getCalledFunction()) + if (F->getIntrinsicID() == Intrinsic::vastart) + return false; + } + + return true; +} + +/// \brief Build a set of blocks to extract if the input blocks are viable. +static SetVector +buildExtractionBlockSet(ArrayRef BBs) { + SetVector Result; + + // Loop over the blocks, adding them to our set-vector, and aborting with an + // empty set if we encounter invalid blocks. + for (ArrayRef::iterator I = BBs.begin(), E = BBs.end(); + I != E; ++I) { + if (!Result.insert(*I)) + continue; + + if (!isBlockValidForExtraction(**I)) { + Result.clear(); + break; } + } - void severSplitPHINodes(BasicBlock *&Header); - void splitReturnBlocks(); - void findInputsOutputs(Values &inputs, Values &outputs); + return Result; +} + +CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs) + : DT(0), AggregateArgs(AggregateArgs||AggregateArgsOpt), + Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {} - Function *constructFunction(const Values &inputs, - const Values &outputs, - BasicBlock *header, - BasicBlock *newRootNode, BasicBlock *newHeader, - Function *oldFunction, Module *M); +CodeExtractor::CodeExtractor(ArrayRef BBs, DominatorTree *DT, + bool AggregateArgs) + : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), + Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {} - void moveCodeToFunction(Function *newFunction); +CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs) + : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), + Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {} - void emitCallAndSwitchStatement(Function *newFunction, - BasicBlock *newHeader, - Values &inputs, - Values &outputs); - }; +/// definedInRegion - Return true if the specified value is defined in the +/// extracted region. +static bool definedInRegion(const SetVector &Blocks, Value *V) { + if (Instruction *I = dyn_cast(V)) + if (Blocks.count(I->getParent())) + return true; + return false; +} + +/// definedInCaller - Return true if the specified value is defined in the +/// function being code extracted, but not in the region being extracted. +/// These values must be passed in as live-ins to the function. +static bool definedInCaller(const SetVector &Blocks, Value *V) { + if (isa(V)) return true; + if (Instruction *I = dyn_cast(V)) + if (!Blocks.count(I->getParent())) + return true; + return false; } /// severSplitPHINodes - If a PHI node has multiple inputs from outside of the @@ -115,7 +132,7 @@ void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { // than one entry from outside the region. If so, we need to sever the // header block into two. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (BlocksToExtract.count(PN->getIncomingBlock(i))) + if (Blocks.count(PN->getIncomingBlock(i))) ++NumPredsFromRegion; else ++NumPredsOutsideRegion; @@ -136,8 +153,8 @@ void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { // We only want to code extract the second block now, and it becomes the new // header of the region. BasicBlock *OldPred = Header; - BlocksToExtract.remove(OldPred); - BlocksToExtract.insert(NewBB); + Blocks.remove(OldPred); + Blocks.insert(NewBB); Header = NewBB; // Okay, update dominator sets. The blocks that dominate the new one are the @@ -152,7 +169,7 @@ void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { // Loop over all of the predecessors of OldPred that are in the region, // changing them to branch to NewBB instead. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (BlocksToExtract.count(PN->getIncomingBlock(i))) { + if (Blocks.count(PN->getIncomingBlock(i))) { TerminatorInst *TI = PN->getIncomingBlock(i)->getTerminator(); TI->replaceUsesOfWith(OldPred, NewBB); } @@ -170,7 +187,7 @@ void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { // Loop over all of the incoming value in PN, moving them to NewPN if they // are from the extracted region. for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) { - if (BlocksToExtract.count(PN->getIncomingBlock(i))) { + if (Blocks.count(PN->getIncomingBlock(i))) { NewPN->addIncoming(PN->getIncomingValue(i), PN->getIncomingBlock(i)); PN->removeIncomingValue(i); --i; @@ -181,8 +198,8 @@ void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { } void CodeExtractor::splitReturnBlocks() { - for (SetVector::iterator I = BlocksToExtract.begin(), - E = BlocksToExtract.end(); I != E; ++I) + for (SetVector::iterator I = Blocks.begin(), E = Blocks.end(); + I != E; ++I) if (ReturnInst *RI = dyn_cast((*I)->getTerminator())) { BasicBlock *New = (*I)->splitBasicBlock(RI, (*I)->getName()+".ret"); if (DT) { @@ -205,23 +222,23 @@ void CodeExtractor::splitReturnBlocks() { // findInputsOutputs - Find inputs to, outputs from the code region. // -void CodeExtractor::findInputsOutputs(Values &inputs, Values &outputs) { +void CodeExtractor::findInputsOutputs(ValueSet &inputs, ValueSet &outputs) { std::set ExitBlocks; - for (SetVector::const_iterator ci = BlocksToExtract.begin(), - ce = BlocksToExtract.end(); ci != ce; ++ci) { + for (SetVector::const_iterator ci = Blocks.begin(), + ce = Blocks.end(); ci != ce; ++ci) { BasicBlock *BB = *ci; for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { // If a used value is defined outside the region, it's an input. If an // instruction is used outside the region, it's an output. for (User::op_iterator O = I->op_begin(), E = I->op_end(); O != E; ++O) - if (definedInCaller(*O)) + if (definedInCaller(Blocks, *O)) inputs.insert(*O); // Consider uses of this instruction (outputs). for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); UI != E; ++UI) - if (!definedInRegion(*UI)) { + if (!definedInRegion(Blocks, *UI)) { outputs.insert(I); break; } @@ -230,7 +247,7 @@ void CodeExtractor::findInputsOutputs(Values &inputs, Values &outputs) { // Keep track of the exit blocks from the region. TerminatorInst *TI = BB->getTerminator(); for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) - if (!BlocksToExtract.count(TI->getSuccessor(i))) + if (!Blocks.count(TI->getSuccessor(i))) ExitBlocks.insert(TI->getSuccessor(i)); } // for: basic blocks @@ -240,8 +257,8 @@ void CodeExtractor::findInputsOutputs(Values &inputs, Values &outputs) { /// constructFunction - make a function based on inputs and outputs, as follows: /// f(in0, ..., inN, out0, ..., outN) /// -Function *CodeExtractor::constructFunction(const Values &inputs, - const Values &outputs, +Function *CodeExtractor::constructFunction(const ValueSet &inputs, + const ValueSet &outputs, BasicBlock *header, BasicBlock *newRootNode, BasicBlock *newHeader, @@ -261,15 +278,15 @@ Function *CodeExtractor::constructFunction(const Values &inputs, std::vector paramTy; // Add the types of the input values to the function's argument list - for (Values::const_iterator i = inputs.begin(), - e = inputs.end(); i != e; ++i) { + for (ValueSet::const_iterator i = inputs.begin(), e = inputs.end(); + i != e; ++i) { const Value *value = *i; DEBUG(dbgs() << "value used in func: " << *value << "\n"); paramTy.push_back(value->getType()); } // Add the types of the output values to the function's argument list. - for (Values::const_iterator I = outputs.begin(), E = outputs.end(); + for (ValueSet::const_iterator I = outputs.begin(), E = outputs.end(); I != E; ++I) { DEBUG(dbgs() << "instr used in func: " << **I << "\n"); if (AggregateArgs) @@ -326,7 +343,7 @@ Function *CodeExtractor::constructFunction(const Values &inputs, for (std::vector::iterator use = Users.begin(), useE = Users.end(); use != useE; ++use) if (Instruction* inst = dyn_cast(*use)) - if (BlocksToExtract.count(inst->getParent())) + if (Blocks.count(inst->getParent())) inst->replaceUsesOfWith(inputs[i], RewriteVal); } @@ -347,7 +364,7 @@ Function *CodeExtractor::constructFunction(const Values &inputs, // The BasicBlock which contains the branch is not in the region // modify the branch target to a new block if (TerminatorInst *TI = dyn_cast(Users[i])) - if (!BlocksToExtract.count(TI->getParent()) && + if (!Blocks.count(TI->getParent()) && TI->getParent()->getParent() == oldFunction) TI->replaceUsesOfWith(header, newHeader); @@ -373,7 +390,7 @@ static BasicBlock* FindPhiPredForUseInBlock(Value* Used, BasicBlock* BB) { /// necessary. void CodeExtractor:: emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, - Values &inputs, Values &outputs) { + ValueSet &inputs, ValueSet &outputs) { // Emit a call to the new function, passing in: *pointer to struct (if // aggregating parameters), or plan inputs and allocated memory for outputs std::vector params, StructValues, ReloadOutputs, Reloads; @@ -381,14 +398,14 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, LLVMContext &Context = newFunction->getContext(); // Add inputs as params, or to be filled into the struct - for (Values::iterator i = inputs.begin(), e = inputs.end(); i != e; ++i) + for (ValueSet::iterator i = inputs.begin(), e = inputs.end(); i != e; ++i) if (AggregateArgs) StructValues.push_back(*i); else params.push_back(*i); // Create allocas for the outputs - for (Values::iterator i = outputs.begin(), e = outputs.end(); i != e; ++i) { + for (ValueSet::iterator i = outputs.begin(), e = outputs.end(); i != e; ++i) { if (AggregateArgs) { StructValues.push_back(*i); } else { @@ -403,7 +420,7 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, AllocaInst *Struct = 0; if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { std::vector ArgTypes; - for (Values::iterator v = StructValues.begin(), + for (ValueSet::iterator v = StructValues.begin(), ve = StructValues.end(); v != ve; ++v) ArgTypes.push_back((*v)->getType()); @@ -458,7 +475,7 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, std::vector Users(outputs[i]->use_begin(), outputs[i]->use_end()); for (unsigned u = 0, e = Users.size(); u != e; ++u) { Instruction *inst = cast(Users[u]); - if (!BlocksToExtract.count(inst->getParent())) + if (!Blocks.count(inst->getParent())) inst->replaceUsesOfWith(outputs[i], load); } } @@ -476,11 +493,11 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, std::map ExitBlockMap; unsigned switchVal = 0; - for (SetVector::const_iterator i = BlocksToExtract.begin(), - e = BlocksToExtract.end(); i != e; ++i) { + for (SetVector::const_iterator i = Blocks.begin(), + e = Blocks.end(); i != e; ++i) { TerminatorInst *TI = (*i)->getTerminator(); for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) - if (!BlocksToExtract.count(TI->getSuccessor(i))) { + if (!Blocks.count(TI->getSuccessor(i))) { BasicBlock *OldTarget = TI->getSuccessor(i); // add a new basic block which returns the appropriate value BasicBlock *&NewTarget = ExitBlockMap[OldTarget]; @@ -624,12 +641,12 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, } void CodeExtractor::moveCodeToFunction(Function *newFunction) { - Function *oldFunc = (*BlocksToExtract.begin())->getParent(); + Function *oldFunc = (*Blocks.begin())->getParent(); Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList(); Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList(); - for (SetVector::const_iterator i = BlocksToExtract.begin(), - e = BlocksToExtract.end(); i != e; ++i) { + for (SetVector::const_iterator i = Blocks.begin(), + e = Blocks.end(); i != e; ++i) { // Delete the basic block from the old function, and the list of blocks oldBlocks.remove(*i); @@ -638,45 +655,22 @@ void CodeExtractor::moveCodeToFunction(Function *newFunction) { } } -/// ExtractRegion - Removes a loop from a function, replaces it with a call to -/// new function. Returns pointer to the new function. -/// -/// algorithm: -/// -/// find inputs and outputs for the region -/// -/// for inputs: add to function as args, map input instr* to arg# -/// for outputs: add allocas for scalars, -/// add to func as args, map output instr* to arg# -/// -/// rewrite func to use argument #s instead of instr* -/// -/// for each scalar output in the function: at every exit, store intermediate -/// computed result back into memory. -/// -Function *CodeExtractor:: -ExtractCodeRegion(ArrayRef code) { - if (!isEligible(code)) +Function *CodeExtractor::extractCodeRegion() { + if (!isEligible()) return 0; - // 1) Find inputs, outputs - // 2) Construct new function - // * Add allocas for defs, pass as args by reference - // * Pass in uses as args - // 3) Move code region, add call instr to func - // - BlocksToExtract.insert(code.begin(), code.end()); - - Values inputs, outputs; + ValueSet inputs, outputs; // Assumption: this is a single-entry code region, and the header is the first // block in the region. - BasicBlock *header = code[0]; + BasicBlock *header = *Blocks.begin(); - for (unsigned i = 1, e = code.size(); i != e; ++i) - for (pred_iterator PI = pred_begin(code[i]), E = pred_end(code[i]); + for (SetVector::iterator BI = llvm::next(Blocks.begin()), + BE = Blocks.end(); + BI != BE; ++BI) + for (pred_iterator PI = pred_begin(*BI), E = pred_end(*BI); PI != E; ++PI) - assert(BlocksToExtract.count(*PI) && + assert(Blocks.count(*PI) && "No blocks in this region may have entries from outside the region" " except for the first block!"); @@ -718,7 +712,7 @@ ExtractCodeRegion(ArrayRef code) { for (BasicBlock::iterator I = header->begin(); isa(I); ++I) { PHINode *PN = cast(I); for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (!BlocksToExtract.count(PN->getIncomingBlock(i))) + if (!Blocks.count(PN->getIncomingBlock(i))) PN->setIncomingBlock(i, newFuncRoot); } @@ -732,7 +726,7 @@ ExtractCodeRegion(ArrayRef code) { PHINode *PN = cast(I); std::set ProcessedPreds; for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (BlocksToExtract.count(PN->getIncomingBlock(i))) { + if (Blocks.count(PN->getIncomingBlock(i))) { if (ProcessedPreds.insert(PN->getIncomingBlock(i)).second) PN->setIncomingBlock(i, codeReplacer); else { @@ -754,51 +748,3 @@ ExtractCodeRegion(ArrayRef code) { report_fatal_error("verifyFunction failed!")); return newFunction; } - -bool CodeExtractor::isEligible(ArrayRef code) { - for (ArrayRef::iterator I = code.begin(), E = code.end(); - I != E; ++I) - if (!isBlockViableForExtraction(**I)) - return false; - - return true; -} - -bool llvm::isBlockViableForExtraction(const BasicBlock &BB) { - // Landing pads must be in the function where they were inserted for cleanup. - if (BB.isLandingPad()) - return false; - - // Don't hoist code containing allocas, invokes, or vastarts. - for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) { - if (isa(I) || isa(I)) - return false; - if (const CallInst *CI = dyn_cast(I)) - if (const Function *F = CI->getCalledFunction()) - if (F->getIntrinsicID() == Intrinsic::vastart) - return false; - } - - return true; -} - -/// ExtractCodeRegion - Slurp a sequence of basic blocks into a brand new -/// function. -/// -Function* llvm::ExtractCodeRegion(DominatorTree &DT, - ArrayRef code, - bool AggregateArgs) { - return CodeExtractor(&DT, AggregateArgs).ExtractCodeRegion(code); -} - -/// ExtractLoop - Slurp a natural loop into a brand new function. -/// -Function* llvm::ExtractLoop(DominatorTree &DT, Loop *L, bool AggregateArgs) { - return CodeExtractor(&DT, AggregateArgs).ExtractCodeRegion(L->getBlocks()); -} - -/// ExtractBasicBlock - Slurp a basic block into a brand new function. -/// -Function* llvm::ExtractBasicBlock(ArrayRef BBs, bool AggregateArgs){ - return CodeExtractor(0, AggregateArgs).ExtractCodeRegion(BBs); -} -- cgit v1.2.3-18-g5258