diff options
-rw-r--r-- | include/llvm/InitializePasses.h | 1 | ||||
-rw-r--r-- | include/llvm/Transforms/Scalar.h | 2 | ||||
-rw-r--r-- | include/llvm/Value.h | 6 | ||||
-rw-r--r-- | lib/CodeGen/Passes.cpp | 10 | ||||
-rw-r--r-- | lib/Transforms/Scalar/CMakeLists.txt | 1 | ||||
-rw-r--r-- | lib/Transforms/Scalar/NaClCcRewrite.cpp | 866 |
6 files changed, 885 insertions, 1 deletions
diff --git a/include/llvm/InitializePasses.h b/include/llvm/InitializePasses.h index de97957a84..3c0ab0f33c 100644 --- a/include/llvm/InitializePasses.h +++ b/include/llvm/InitializePasses.h @@ -256,6 +256,7 @@ void initializeUnpackMachineBundlesPass(PassRegistry&); void initializeFinalizeMachineBundlesPass(PassRegistry&); void initializeBBVectorizePass(PassRegistry&); void initializeMachineFunctionPrinterPassPass(PassRegistry&); +void initializeNaClCcRewritePass(PassRegistry&); // @LOCALMOD } #endif diff --git a/include/llvm/Transforms/Scalar.h b/include/llvm/Transforms/Scalar.h index 3dce6fe37f..29b5233e22 100644 --- a/include/llvm/Transforms/Scalar.h +++ b/include/llvm/Transforms/Scalar.h @@ -366,7 +366,7 @@ extern char &InstructionSimplifierID; // "block_weights" metadata. FunctionPass *createLowerExpectIntrinsicPass(); - +FunctionPass *createNaClCcRewritePass(const TargetLowering *TLI = 0); } // End llvm namespace #endif diff --git a/include/llvm/Value.h b/include/llvm/Value.h index a82ac45c49..d7ccd4dccc 100644 --- a/include/llvm/Value.h +++ b/include/llvm/Value.h @@ -104,6 +104,12 @@ public: /// Type *getType() const { return VTy; } + // @LOCALMOD-START + // Currently only used for function type update during + // the NaCl calling convention rewrite pass + void setType(Type* t) { VTy = t; } + // @LOCALMOD-END + /// All values hold a context through their type. LLVMContext &getContext() const; diff --git a/lib/CodeGen/Passes.cpp b/lib/CodeGen/Passes.cpp index 56526f2732..d68c6740f8 100644 --- a/lib/CodeGen/Passes.cpp +++ b/lib/CodeGen/Passes.cpp @@ -352,6 +352,16 @@ void TargetPassConfig::addIRPasses() { addPass(createTypeBasedAliasAnalysisPass()); addPass(createBasicAliasAnalysisPass()); + // @LOCALMOD-START + addPass(createNaClCcRewritePass(TM->getTargetLowering())); + // TODO: consider adding a cleanup pass, e.g. constant propagation + // Note: we run this before the verfier step because it may cause + // a *temporary* inconsistency: + // A function may have been rewritting before we are rewriting + // its callers - which would lead to a parameter mismatch complaint + // from the verifier. + // @LOCALMOD-END + // Before running any passes, run the verifier to determine if the input // coming from the front-end and/or optimizer is valid. if (!DisableVerify) diff --git a/lib/Transforms/Scalar/CMakeLists.txt b/lib/Transforms/Scalar/CMakeLists.txt index a01e0661b1..283758f395 100644 --- a/lib/Transforms/Scalar/CMakeLists.txt +++ b/lib/Transforms/Scalar/CMakeLists.txt @@ -31,6 +31,7 @@ add_llvm_library(LLVMScalarOpts SimplifyLibCalls.cpp Sink.cpp TailRecursionElimination.cpp + NaClCcRewrite.cpp ) add_dependencies(LLVMScalarOpts intrinsics_gen) diff --git a/lib/Transforms/Scalar/NaClCcRewrite.cpp b/lib/Transforms/Scalar/NaClCcRewrite.cpp new file mode 100644 index 0000000000..6309f7e0b2 --- /dev/null +++ b/lib/Transforms/Scalar/NaClCcRewrite.cpp @@ -0,0 +1,866 @@ +//===- ConstantProp.cpp - Code to perform Simple Constant Propagation -----===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements calling convention rewrite for Native Client to ensure +// compatibility between pnacl and gcc generated code when calling +// ppapi interface functions. +//===----------------------------------------------------------------------===// + + +// Major TODOs: +// * add register constraints to x86-64 rewrite decissions +// * dealing with vararg +// (We shoulf exclude all var arg functions and calls to them from rewrites) + +#define DEBUG_TYPE "naclcc" + +#include "llvm/Argument.h" +#include "llvm/Attributes.h" +#include "llvm/Constant.h" +#include "llvm/Instruction.h" +#include "llvm/Instructions.h" +#include "llvm/Function.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Support/InstIterator.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Target/TargetLibraryInfo.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetLowering.h" +#include "llvm/Target/TargetLoweringObjectFile.h" +#include "llvm/Transforms/Scalar.h" + +#include <vector> + +using namespace llvm; + +namespace llvm { + +cl::opt<bool> FlagEnableCcRewrite( + "nacl-cc-rewrite", + cl::desc("enable NaCl CC rewrite")); +} + +namespace { + +// This represents a rule for rewiriting types +struct TypeRewriteRule { + const char* src; // type pattern we are trying to match + const char* dst; // replacement type + const char* name; // name of the rule for diagnosis +}; + +// Note: all rules must be well-formed +// * parentheses must match +// * TODO: add verification for this + +// Legend: +// s(): struct (also used for unions) +// c: char (= 8 bit int) (only allowed for src) +// i: 32 bit int +// l: 64 bit int +// f: 32 bit float +// d: 64 bit float (= double) +// p: untyped pointer (only allowed for src) +// P(): typed pointer (currently not used, only allowed for src) +// C: "copy", use src as dst (only allowed for dst and sret) +// F: generic function type (only allowed for src) + + +// The X8664 Rewrite rules are also subject to +// register constraints, c.f.: section 3.2.3 +// http://www.x86-64.org/documentation/abi.pdf +TypeRewriteRule ByvalRulesX8664[] = { + {"s(iis(d))", "ll", "PP_Var"}, + {"s(pp)", "l", "PP_ArrayOutput"}, + {"s(ppi)", "li", "PP_CompletionCallback"}, + {0, 0, 0}, +}; + +TypeRewriteRule SretRulesX8664[] = { + {"s(iis(d))", "s(ll)", "PP_Var"}, + {"s(ff)", "d", "PP_FloatPoint"}, + {"s(ii)", "l", "PP_Point" }, + {"s(pp)", "l", "PP_ArrayOutput"}, + {0, 0, 0}, +}; + +TypeRewriteRule ByvalRulesARM[] = { + {"s(iis(d))", "ll", "PP_Var"}, + {"s(ppi)", "iii", "PP_CompletionCallback" }, + {"s(pp)", "ii", "PP_ArrayOutput"}, + {0, 0, 0}, +}; + +TypeRewriteRule SretRulesARM[] = { + {"s(ff)", "C", "PP_FloatPoint"}, + {0, 0, 0}, +}; + +// TODO: Find a better way to determine the architecture +const TypeRewriteRule* GetByvalRewriteRulesForTarget( + const TargetLowering* tli) { + if (!FlagEnableCcRewrite) return 0; + + const TargetMachine &m = tli->getTargetMachine(); + const StringRef triple = m.getTargetTriple(); + + if (0 == triple.find("x86_64")) return ByvalRulesX8664; + if (0 == triple.find("i686")) return 0; + if (0 == triple.find("armv7a")) return ByvalRulesARM; + + llvm_unreachable("Unknown arch"); + return 0; +} + +// TODO: Find a better way to determine the architecture +const TypeRewriteRule* GetSretRewriteRulesForTarget( + const TargetLowering* tli) { + if (!FlagEnableCcRewrite) return 0; + + const TargetMachine &m = tli->getTargetMachine(); + const StringRef triple = m.getTargetTriple(); + + if (0 == triple.find("x86_64")) return SretRulesX8664; + if (0 == triple.find("i686")) return 0; + if (0 == triple.find("armv7a")) return SretRulesARM; + + llvm_unreachable("Unknown arch"); + return 0; +} + +// This class represents the a bitcode rewrite pass which ensures +// that all ppapi interfaces are calling convention compatible +// with gcc. This pass is archtitecture dependent. +struct NaClCcRewrite : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + + const TypeRewriteRule* SretRewriteRules; + const TypeRewriteRule* ByvalRewriteRules; + + explicit NaClCcRewrite(const TargetLowering *tli = 0) + : FunctionPass(ID), + SretRewriteRules(GetSretRewriteRulesForTarget(tli)), + ByvalRewriteRules(GetByvalRewriteRulesForTarget(tli)) { + initializeNaClCcRewritePass(*PassRegistry::getPassRegistry()); + } + + // main pass entry point + bool runOnFunction(Function &F); + + private: + void RewriteCallsite(Instruction* call, LLVMContext& C); + void RewriteFunctionPrologAndEpilog(Function& F); +}; + +char NaClCcRewrite::ID = 0; + +// This is only used for dst side of rules +Type* GetElementaryType(char c, LLVMContext& C) { + switch (c) { + case 'i': + return Type::getInt32Ty(C); + case 'l': + return Type::getInt64Ty(C); + case 'd': + return Type::getDoubleTy(C); + case 'f': + return Type::getFloatTy(C); + default: + dbgs() << c << "\n"; + llvm_unreachable("Unknown type specifier"); + return 0; + } +} + +// This is only used for the dst side of a rule +int GetElementaryTypeWidth(char c) { + switch (c) { + case 'i': + case 'f': + return 4; + case 'l': + case 'd': + return 8; + default: + llvm_unreachable("Unknown type specifier"); + return 0; + } +} + +// Check whether a type matches the *src* side pattern of a rewrite rule. +// Note that the pattern parameter is updated during the recursion +bool HasRewriteType(const Type* type, const char*& pattern) { + switch (*pattern++) { + case '\0': + return false; + case ')': + return false; + case 's': // struct and union are currently no distinguished + { + if (*pattern++ != '(') llvm_unreachable("malformed type pattern"); + if (!type->isStructTy()) return false; + // check struct members + const StructType* st = cast<StructType>(type); + for (StructType::element_iterator it = st->element_begin(), + end = st->element_end(); + it != end; + ++it) { + if (!HasRewriteType(*it, pattern)) return false; + } + // ensure we reached the end + int c = *pattern++; + return c == ')'; + } + break; + case 'c': + return type->isIntegerTy(8); + case 'i': + return type->isIntegerTy(32); + case 'l': + return type->isIntegerTy(64); + case 'd': + return type->isDoubleTy(); + case 'f': + return type->isFloatTy(); + case 'F': + return type->isFunctionTy(); + case 'p': // untyped pointer + return type->isPointerTy(); + case 'P': // typed pointer + { + if (*pattern++ != '(') llvm_unreachable("malformed type pattern"); + if (!type->isPointerTy()) return false; + Type* pointee = dyn_cast<PointerType>(type)->getElementType(); + if (!HasRewriteType(pointee, pattern)) return false; + int c = *pattern++; + return c == ')'; + } + default: + llvm_unreachable("Unknown type specifier"); + return false; + } +} + +// Match a type against a set of rewrite rules. +// Return the matching rule, if any. +const TypeRewriteRule* MatchRewriteRules( + const Type* type, const TypeRewriteRule* rules) { + if (rules == 0) return 0; + for (; rules->name != 0; ++rules) { + const char* pattern = rules->src; + if (HasRewriteType(type, pattern)) return rules; + } + return 0; +} + +// Same as MatchRewriteRules but "dereference" type first. +const TypeRewriteRule* MatchRewriteRulesPointee(const Type* t, + const TypeRewriteRule* Rules) { + // sret and byval are both modelled as pointers + const PointerType* pointer = dyn_cast<PointerType>(t); + if (pointer == 0) return 0; + + return MatchRewriteRules(pointer->getElementType(), Rules); +} + +// Note, the attributes are not part of the type but are stored +// with the CallInst and/or the Function (if any) +Type* CreateFunctionPointerType(Type* result_type, + std::vector<Type*>& arguments) { + FunctionType* ft = FunctionType::get(result_type, + arguments, + false); + return PointerType::getUnqual(ft); +} + +// Determines whether a function body needs a rewrite +bool FunctionNeedsRewrite(const Function* fun, + const TypeRewriteRule* ByvalRewriteRules, + const TypeRewriteRule* SretRewriteRules) { + // TODO: can this be detected on indirect callsites as well. + // if we skip the rewrite for the function body + // we also need to skip it at the callsites + // if (F.isVarArg()) return false; + + for (Function::const_arg_iterator AI = fun->arg_begin(), AE = fun->arg_end(); + AI != AE; + ++AI) { + const Argument& a = *AI; + const Type* t = a.getType(); + // byval and srets are modelled as pointers (to structs) + if (!t->isPointerTy()) continue; + Type* pointee = dyn_cast<PointerType>(t)->getElementType(); + + if (ByvalRewriteRules && a.hasByValAttr()) { + if (0 != MatchRewriteRules(pointee, ByvalRewriteRules)) return true; + } + + if (SretRewriteRules && a.hasStructRetAttr()) { + if (0 != MatchRewriteRules(pointee, SretRewriteRules)) return true; + } + } + return false; +} + +// Used for sret rewrites to determine the new function result type +Type* GetNewReturnType(Type* type, + const TypeRewriteRule* rule, + LLVMContext& C) { + if (std::string("C") == rule->dst) { + if (!type->isPointerTy()) { + llvm_unreachable("unexpected return type for C"); + } + Type* pointee = dyn_cast<PointerType>(type)->getElementType(); + return pointee; + } else if (std::string("l") == rule->dst || + std::string("d") == rule->dst) { + return GetElementaryType(rule->dst[0], C); + } else if (rule->dst[0] == 's') { + const char* cp = rule->dst + 2; // skip 's(' + std::vector<Type*> fields; + while (*cp != ')') { + fields.push_back(GetElementaryType(*cp, C)); + ++cp; + } + return StructType::get(C, fields, false /* isPacked */); + } else { + dbgs() << *type << " " << rule->name << "\n"; + llvm_unreachable("unexpected return type"); + return 0; + } +} + +// Rewrite sret parameter while rewriting a function +Type* RewriteFunctionSret(Function& F, + Value* orig_val, + const TypeRewriteRule* rule) { + LLVMContext& C = F.getContext(); + BasicBlock& entry = F.getEntryBlock(); + Instruction* before = &(entry.front()); + Type* old_type = orig_val->getType(); + Type* old_pointee = dyn_cast<PointerType>(old_type)->getElementType(); + Type* new_type = GetNewReturnType(old_type, rule, C); + // create a temporary to hold the return value as we no longer pass + // in the pointer + AllocaInst* tmp_ret = new AllocaInst(old_pointee, "result", before); + orig_val->replaceAllUsesWith(tmp_ret); + CastInst* cast_ret = CastInst::CreatePointerCast( + tmp_ret, + PointerType::getUnqual(new_type), + "byval_cast", + before); + for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) { + for (BasicBlock::iterator II = BI->begin(), IE = BI->end(); + II != IE; + /* see below */) { + Instruction* inst = II; + // we do decontructive magic below, so advance the iterator here + // (this is still a little iffy) + ++II; + ReturnInst* ret = dyn_cast<ReturnInst>(inst); + if (ret) { + if (ret->getReturnValue() != 0) + llvm_unreachable("expected a void return"); + // load the return value from temporary + Value *ret_val = new LoadInst(cast_ret, "load_result", ret); + // return that loaded value and delete the return instruction + ReturnInst::Create(C, ret_val, ret); + ret->eraseFromParent(); + } + } + } + return new_type; +} + +// Rewrite one byval function parameter while rewriting a function +void FixFunctionByvalsParameter(Function& F, + std::vector<Argument*>& new_arguments, + std::vector<Attributes>& new_attributes, + Value* byval, + const TypeRewriteRule* rule) { + LLVMContext& C = F.getContext(); + BasicBlock& entry = F.getEntryBlock(); + Instruction* before = &(entry.front()); + Twine prefix = byval->getName() + "_split"; + Type* t = byval->getType(); + Type* pointee = dyn_cast<PointerType>(t)->getElementType(); + AllocaInst* tmp_param = new AllocaInst(pointee, prefix + "_param", before); + byval->replaceAllUsesWith(tmp_param); + // convert byval poiner to char pointer + Value* base = CastInst::CreatePointerCast( + tmp_param, PointerType::getInt8PtrTy(C), prefix + "_base", before); + + int width = 0; + const char* pattern = rule->dst; + for (int offset = 0; *pattern; ++pattern, offset += width) { + width = GetElementaryTypeWidth(*pattern); + Type* t = GetElementaryType(*pattern, C); + Argument* arg = new Argument(t, prefix, &F); + Type* pt = PointerType::getUnqual(t); + // the code below generates something like: + // <CHAR-PTR> = getelementptr i8* <BASE>, i32 <OFFSET-FROM-BASE> + // <PTR> = bitcast i8* <CHAR-PTR> to <TYPE>* + // store <ARG> <TYPE>* <ELEM-PTR> + ConstantInt* baseOffset = ConstantInt::get(Type::getInt32Ty(C), offset); + Value *v; + v = GetElementPtrInst::Create(base, baseOffset, prefix + "_base_add", before); + v = CastInst::CreatePointerCast(v, pt, prefix + "_cast", before); + v = new StoreInst(arg, v, before); + + new_arguments.push_back(arg); + new_attributes.push_back(Attribute::None); + } +} + +// Change function signature to reflect all the rewrites. +// This includes function type/signature and attributes. +void UpdateFunctionSignature(Function &F, + Type* new_result_type, + std::vector<Argument*>& new_arguments, + std::vector<Attributes>& new_attributes) { + DEBUG(dbgs() << "PHASE PROTOTYPE UPDATE\n"); + if (new_result_type) { + DEBUG(dbgs() << "NEW RESULT TYPE: " << *new_result_type << "\n"); + } + // Update function type + FunctionType* old_fun_type = F.getFunctionType(); + std::vector<Type*> new_types; + for (size_t i = 0; i < new_arguments.size(); ++i) { + new_types.push_back(new_arguments[i]->getType()); + } + + FunctionType* new_fun_type = FunctionType::get( + new_result_type ? new_result_type : old_fun_type->getReturnType(), + new_types, + false); + F.setType(PointerType::getUnqual(new_fun_type)); + + Function::ArgumentListType& args = F.getArgumentList(); + DEBUG(dbgs() << "PHASE ARGUMENT ERASE " << args.size() << "\n"); + while (args.size()) { + Argument* arg = args.remove(args.begin()); + } + + DEBUG(dbgs() << "PHASE ARGUMENT REFILL" << new_arguments.size() << "\n"); + for (size_t i = 0; i < new_arguments.size(); ++i) { + args.push_back(new_arguments[i]); + } + + DEBUG(dbgs() << "PHASE ATTRIBUTES UPDATE\n"); + // Update function attributes + std::vector<AttributeWithIndex> new_attributes_vec; + for (size_t i = 0; i < new_attributes.size(); ++i) { + Attributes attr = new_attributes[i]; + if (attr) { + // index 0 is for the return value + new_attributes_vec.push_back(AttributeWithIndex::get(i + 1, attr)); + } + } + if (Attributes attrs = F.getAttributes().getFnAttributes()) + new_attributes_vec.push_back(AttributeWithIndex::get(~0, attrs)); + + F.setAttributes(AttrListPtr::get(new_attributes_vec)); +} + +// Apply byval or sret rewrites to function body. +void NaClCcRewrite::RewriteFunctionPrologAndEpilog(Function& F) { + + DEBUG(dbgs() << "\nFUNCTION-REWRITE\n"); + + DEBUG(dbgs() << "FUNCTION BEFORE "); + DEBUG(dbgs() << F); + DEBUG(dbgs() << "\n"); + + std::vector<Argument*> new_arguments; + std::vector<Attributes> new_attributes; + std::vector<Argument*> old_arguments; + + // make copy first as create Argument adds them to the list + for (Function::arg_iterator ai = F.arg_begin(), + end = F.arg_end(); + ai != end; + ++ai) { + old_arguments.push_back(ai); + } + + for (size_t i = 0; i < old_arguments.size(); ++i) { + Argument* arg = old_arguments[i]; + Type* t = arg->getType(); + // index zero is for return value attributes + Attributes attr = F.getAttributes().getParamAttributes(i + 1); + const TypeRewriteRule* rule = 0; + if (attr & Attribute::ByVal) { + rule = MatchRewriteRulesPointee(t, ByvalRewriteRules); + } + if (rule == 0) { + new_arguments.push_back(arg); + new_attributes.push_back(attr); + continue; + } + DEBUG(dbgs() << "REWRITING BYVAL " + << *t << " arg " << arg->getName() << " " << rule->name << "\n"); + FixFunctionByvalsParameter(F, + new_arguments, + new_attributes, + arg, + rule); + } + + // A non-zero new_result_type indicates an sret rewrite + Type* new_result_type = 0; + // only the first arg can be "sret" + if (new_attributes[0] & Attribute::StructRet) { + const TypeRewriteRule* sret_rule = MatchRewriteRulesPointee( + new_arguments[0]->getType(), SretRewriteRules); + if (sret_rule) { + Argument* arg = F.getArgumentList().begin(); + DEBUG(dbgs() << "REWRITING SRET " + << " arg " << arg->getName() << " " << sret_rule->name << "\n"); + new_result_type = RewriteFunctionSret(F, arg, sret_rule); + new_arguments.erase(new_arguments.begin()); + new_attributes.erase(new_attributes.begin()); + } + } + + UpdateFunctionSignature(F, new_result_type, new_arguments, new_attributes); + + DEBUG(dbgs() << "FUNCTION AFTER "); + DEBUG(dbgs() << F); + DEBUG(dbgs() << "\n"); +} + +// used for T in {CallInst, InvokeInst} +template<class T> bool CallNeedsRewrite( + const Instruction* inst, + const TypeRewriteRule* ByvalRewriteRules, + const TypeRewriteRule* SretRewriteRules) { + + const T* call = cast<T>(inst); + // skip non parameter operands at the end + size_t num_params = call->getNumOperands() - + (isa<CallInst>(inst) ? 1 : 3); + for (size_t i = 0; i < num_params; ++i) { + Type* t = call->getOperand(i)->getType(); + // byval and srets are modelled as pointers (to structs) + if (!t->isPointerTy()) continue; + Type* pointee = dyn_cast<PointerType>(t)->getElementType(); + + // param zero is for the return value + if (ByvalRewriteRules && call->paramHasAttr(i + 1, Attribute::ByVal)) { + if (0 != MatchRewriteRules(pointee, ByvalRewriteRules)) return true; + } + + if (SretRewriteRules && call->paramHasAttr(i + 1, Attribute::StructRet)) { + if (0 != MatchRewriteRules(pointee, SretRewriteRules)) return true; + } + } + + return false; +} + +// This code will load the fields of the byval ptr into scalar variables +// which will then be used as argument when we rewrite the actual call +// instruction. +void PrependCompensationForByvals(std::vector<Value*>& new_operands, + std::vector<Attributes>& new_attributes, + Instruction* call, + Value* byval, + const TypeRewriteRule* rule, + LLVMContext& C) { + // convert byval poiner to char pointer + Value* base = CastInst::CreatePointerCast( + byval, PointerType::getInt8PtrTy(C), "byval_base", call); + + int width = 0; + const char* pattern = rule->dst; + for (int offset = 0; *pattern; ++pattern, offset += width) { + width = GetElementaryTypeWidth(*pattern); + Type* t = GetElementaryType(*pattern, C); + Type* pt = PointerType::getUnqual(t); + // the code below generates something like: + // <CHAR-PTR> = getelementptr i8* <BASE>, i32 <OFFSET-FROM-BASE> + // <PTR> = bitcast i8* <CHAR-PTR> to i32* + // <SCALAR> = load i32* <ELEM-PTR> + ConstantInt* baseOffset = ConstantInt::get(Type::getInt32Ty(C), offset); + Value* v; + v = GetElementPtrInst::Create(base, baseOffset, "byval_base_add", call); + v = CastInst::CreatePointerCast(v, pt, "byval_cast", call); + v = new LoadInst(v, "byval_extract", call); + + new_operands.push_back(v); + new_attributes.push_back(Attribute::None); + } +} + +// Note: this will only be called if we expect a rewrite to occur +void CallsiteFixupSrets(Instruction* call, + Value* sret, + Type* new_type, + const TypeRewriteRule* rule) { + const char* pattern = rule->dst; + Instruction* next= call->getNextNode(); + if (next == 0) { + llvm_unreachable("unexpected missing next instruction"); + } + + if (std::string("C") == pattern) { + // Note, this may store complex values, e.g. struct values, same code: + // store %struct.PP_FloatPoint <CALL-RESULT>, %struct.PP_FloatPoint* <SRET-PTR> + new StoreInst(call, sret, next); + } else if (pattern[0] == 's' || + std::string("l") == pattern || + std::string("d") == pattern) { + Type* pt = PointerType::getUnqual(new_type); + Value* cast = CastInst::CreatePointerCast(sret, pt, "cast", next); + new StoreInst(call, cast, next); + } else { + dbgs() << rule->name << "\n"; + llvm_unreachable("unexpected return type at fix up"); + } +} + +void ExtractOperandsAndAttributesFromCallInst( + CallInst* call, + std::vector<Value*>& operands, + std::vector<Attributes>& attributes) { + + AttrListPtr PAL = call->getAttributes(); + // last operand is: function + for (size_t i = 0; i < call->getNumOperands() - 1; ++i) { + operands.push_back(call->getArgOperand(i)); + // index zero is for return value attributes + attributes.push_back(PAL.getParamAttributes(i + 1)); + } +} + +// Note: this differs from the one above in the loop bounds +void ExtractOperandsAndAttributesFromeInvokeInst( + InvokeInst* call, + std::vector<Value*>& operands, + std::vector<Attributes>& attributes) { + AttrListPtr PAL = call->getAttributes(); + // last three operands are: function, bb-normal, bb-exception + for (size_t i = 0; i < call->getNumOperands() - 3; ++i) { + operands.push_back(call->getArgOperand(i)); + // index zero is for return value attributes + attributes.push_back(PAL.getParamAttributes(i + 1)); + } +} + + +Instruction* ReplaceCallInst(CallInst* call, + Type* function_pointer, + std::vector<Value*>& new_operands, + std::vector<Attributes>& new_attributes) { + Value* v = CastInst::CreatePointerCast( + call->getCalledValue(), function_pointer, "fp_cast", call); + CallInst* new_call = CallInst::Create(v, new_operands, "", call); + // NOTE: tail calls may be ruled out but byval/sret, should we assert this? + // TODO: did wid forget to clone anything else? + new_call->setTailCall(call->isTailCall()); + new_call->setCallingConv(call->getCallingConv()); + for (size_t i = 0; i < new_attributes.size(); ++i) { + // index zero is for return value attributes + new_call->addAttribute(i + 1, new_attributes[i]); + } + return new_call; +} + +Instruction* ReplaceInvokeInst(InvokeInst* call, + Type* function_pointer, + std::vector<Value*>& new_operands, + std::vector<Attributes>& new_attributes) { + Value* v = CastInst::CreatePointerCast( + call->getCalledValue(), function_pointer, "fp_cast", call); + InvokeInst* new_call = InvokeInst::Create(v, + call->getNormalDest(), + call->getUnwindDest(), + new_operands, + "", + call); + for (size_t i = 0; i < new_attributes.size(); ++i) { + // index zero is for return value attributes + new_call->addAttribute(i + 1, new_attributes[i]); + } + return new_call; +} + + +void NaClCcRewrite::RewriteCallsite(Instruction* call, LLVMContext& C) { + BasicBlock* BB = call->getParent(); + + DEBUG(dbgs() << "\nCALLSITE-REWRITE\n"); + DEBUG(dbgs() << "CALLSITE BB BEFORE " << *BB); + DEBUG(dbgs() << "\n"); + DEBUG(dbgs() << *call << "\n"); + + // new_result(_type) is only relevent if an sret is rewritten + // whish is indicated by sret_rule != 0 + const TypeRewriteRule* sret_rule = 0; + Type* new_result_type = call->getType(); + Value* new_result = 0; + + std::vector<Value*> old_operands; + std::vector<Attributes> old_attributes; + if (isa<CallInst>(call)) { + ExtractOperandsAndAttributesFromCallInst( + cast<CallInst>(call), old_operands, old_attributes); + } else if (isa<InvokeInst>(call)) { + ExtractOperandsAndAttributesFromeInvokeInst( + cast<InvokeInst>(call), old_operands, old_attributes); + } else { + llvm_unreachable("Unexpected instruction type"); + } + + std::vector<Value*> new_operands; + std::vector<Attributes> new_attributes; + + for (size_t i = 0; i < old_operands.size(); ++i) { + Value *operand = old_operands[i]; + Type* t = operand->getType(); + const TypeRewriteRule* rule = 0; + if (old_attributes[i] & Attribute::ByVal) { + rule = MatchRewriteRulesPointee(t, ByvalRewriteRules); + } + if (rule == 0) { + new_operands.push_back(operand); + new_attributes.push_back(old_attributes[i]); + continue; + } + + DEBUG(dbgs() << "REWRITING BYVAL " + << *t << " arg " << i << " " << rule->name << "\n"); + PrependCompensationForByvals(new_operands, + new_attributes, + call, + operand, + rule, + C); + } + + // only the first arg can be "sret" + if (new_attributes[0] & Attribute::StructRet) { + sret_rule = MatchRewriteRulesPointee( + new_operands[0]->getType(), SretRewriteRules); + } + + // we have to patch the call before we can add the sret compensation code + // because otherwise the type checker complains + if (sret_rule) { + new_result_type = GetNewReturnType(new_operands[0]->getType(), sret_rule, C); + new_result = new_operands[0]; + new_operands.erase(new_operands.begin()); + new_attributes.erase(new_attributes.begin()); + } + + // Note, this code is tricky. + // Initially we used a much more elaborate scheme introducing + // new function declarations for direct calls. + // This simpler scheme, however, works for both direct and + // indirect calls + // We transform (here the direct case): + // call void @result_PP_FloatPoint(%struct.PP_FloatPoint* sret %sret) + // into + // %fp_cast = bitcast void (%struct.PP_FloatPoint*)* + // @result_PP_FloatPoint to %struct.PP_FloatPoint ()* + // %result = call %struct.PP_FloatPoint %fp_cast() + // + std::vector<Type*> new_arg_types; + for (size_t i = 0; i < new_operands.size(); ++i) { + new_arg_types.push_back(new_operands[i]->getType()); + } + + DEBUG(dbgs() << "REWRITE CALL INSTRUCTION\n"); + Instruction* new_call = 0; + if (isa<CallInst>(call)) { + new_call = ReplaceCallInst( + cast<CallInst>(call), + CreateFunctionPointerType(new_result_type, new_arg_types), + new_operands, + new_attributes); + } else if (isa<InvokeInst>(call)) { + new_call = ReplaceInvokeInst( + cast<InvokeInst>(call), + CreateFunctionPointerType(new_result_type, new_arg_types), + new_operands, + new_attributes); + } else { + llvm_unreachable("Unexpected instruction type"); + } + + // We prepended the new call, now get rid of the old one. + // If we did not change the return type, there may be consumers + // of the result which must be redirected. + if (!sret_rule) { + call->replaceAllUsesWith(new_call); + } + call->eraseFromParent(); + + // Add compensation codes for srets if necessary + if (sret_rule) { + DEBUG(dbgs() << "REWRITING SRET " << sret_rule->name << "\n"); + CallsiteFixupSrets(new_call, new_result, new_result_type, sret_rule); + } + + DEBUG(dbgs() << "CALLSITE BB AFTER" << *BB); + DEBUG(dbgs() << "\n"); + DEBUG(dbgs() << *new_call << "\n"); +} + +bool NaClCcRewrite::runOnFunction(Function &F) { + // No rules - no action + if (ByvalRewriteRules == 0 && SretRewriteRules == 0) return false; + + bool Changed = false; + + if (FunctionNeedsRewrite(&F, ByvalRewriteRules, SretRewriteRules)) { + DEBUG(dbgs() << "FUNCTION NEEDS REWRITE " << F.getName() << "\n"); + RewriteFunctionPrologAndEpilog(F); + Changed = true; + } + + // Find all the calls and invokes in F and rewrite them if necessary + for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) { + for (BasicBlock::iterator II = BI->begin(), IE = BI->end(); + II != IE; + /* II updated below */) { + Instruction* inst = II; + // we do decontructive magic below, so advance the iterator here + // (this is still a little iffy) + ++II; + + if (isa<InvokeInst>(inst) || isa<CallInst>(inst)) { + if (isa<CallInst>(inst) && + !CallNeedsRewrite<CallInst> + (inst, ByvalRewriteRules, SretRewriteRules)) continue; + + if (isa<InvokeInst>(inst) && + !CallNeedsRewrite<InvokeInst> + (inst, ByvalRewriteRules, SretRewriteRules)) continue; + + RewriteCallsite(inst, F.getContext()); + Changed = true; + } + } + } + return Changed; +} + +} // end anonymous namespace + + +INITIALIZE_PASS(NaClCcRewrite, "naclcc", "NaCl CC Rewriter", false, false) + +FunctionPass *llvm::createNaClCcRewritePass(const TargetLowering *tli) { + return new NaClCcRewrite(tli); +} + |