diff options
Diffstat (limited to 'lib/Transforms')
-rw-r--r-- | lib/Transforms/CMakeLists.txt | 1 | ||||
-rw-r--r-- | lib/Transforms/IPO/ExtractGV.cpp | 18 | ||||
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCalls.cpp | 6 | ||||
-rw-r--r-- | lib/Transforms/LLVMBuild.txt | 2 | ||||
-rw-r--r-- | lib/Transforms/Makefile | 6 | ||||
-rw-r--r-- | lib/Transforms/NaCl/CMakeLists.txt | 5 | ||||
-rw-r--r-- | lib/Transforms/NaCl/ExpandCtors.cpp | 145 | ||||
-rw-r--r-- | lib/Transforms/NaCl/LLVMBuild.txt | 23 | ||||
-rw-r--r-- | lib/Transforms/NaCl/Makefile | 15 | ||||
-rw-r--r-- | lib/Transforms/Scalar/CMakeLists.txt | 1 | ||||
-rw-r--r-- | lib/Transforms/Scalar/NaClCcRewrite.cpp | 1053 |
11 files changed, 1270 insertions, 5 deletions
diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt index de1353e6c1..9fa690971a 100644 --- a/lib/Transforms/CMakeLists.txt +++ b/lib/Transforms/CMakeLists.txt @@ -5,3 +5,4 @@ add_subdirectory(Scalar) add_subdirectory(IPO) add_subdirectory(Vectorize) add_subdirectory(Hello) +add_subdirectory(NaCl) diff --git a/lib/Transforms/IPO/ExtractGV.cpp b/lib/Transforms/IPO/ExtractGV.cpp index 4c7f0ed236..b2748f2e6c 100644 --- a/lib/Transforms/IPO/ExtractGV.cpp +++ b/lib/Transforms/IPO/ExtractGV.cpp @@ -58,6 +58,15 @@ namespace { continue; if (I->getName() == "llvm.global_ctors") continue; + // @LOCALMOD-BEGIN - this is likely upstreamable + // Note: there will likely be more cases once this + // is exercises more thorougly. + if (I->getName() == "llvm.global_dtors") + continue; + // not observed yet + if (I->hasExternalWeakLinkage()) + continue; + // @LOCALMOD-END } if (I->hasLocalLinkage()) @@ -72,8 +81,15 @@ namespace { } else { if (I->hasAvailableExternallyLinkage()) continue; + // @LOCALMOD-BEGIN - this is likely upstreamable + // Note: there will likely be more cases once this + // is exercises more thorougly. + // observed for pthread_cancel + if (I->hasExternalWeakLinkage()) + continue; + // @LOCALMOD-END } - + if (I->hasLocalLinkage()) I->setVisibility(GlobalValue::HiddenVisibility); I->setLinkage(GlobalValue::ExternalLinkage); diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp index 359bc488f3..0958842d08 100644 --- a/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -1148,8 +1148,10 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { // If we are removing arguments to the function, emit an obnoxious warning. if (FT->getNumParams() < NumActualArgs) { if (!FT->isVarArg()) { - errs() << "WARNING: While resolving call to function '" - << Callee->getName() << "' arguments were dropped!\n"; + if (Callee->getName() != "main") { // @LOCALMOD + errs() << "WARNING: While resolving call to function '" + << Callee->getName() << "' arguments were dropped!\n"; + } } else { // Add all of the arguments in their promoted form to the arg list. for (unsigned i = FT->getNumParams(); i != NumActualArgs; ++i, ++AI) { diff --git a/lib/Transforms/LLVMBuild.txt b/lib/Transforms/LLVMBuild.txt index f7bca064c7..001ba5d232 100644 --- a/lib/Transforms/LLVMBuild.txt +++ b/lib/Transforms/LLVMBuild.txt @@ -16,7 +16,7 @@ ;===------------------------------------------------------------------------===; [common] -subdirectories = IPO InstCombine Instrumentation Scalar Utils Vectorize +subdirectories = IPO InstCombine Instrumentation Scalar Utils Vectorize NaCl [component_0] type = Group diff --git a/lib/Transforms/Makefile b/lib/Transforms/Makefile index 8b1df92fa2..ae03ff32c5 100644 --- a/lib/Transforms/Makefile +++ b/lib/Transforms/Makefile @@ -8,7 +8,11 @@ ##===----------------------------------------------------------------------===## LEVEL = ../.. -PARALLEL_DIRS = Utils Instrumentation Scalar InstCombine IPO Vectorize Hello +PARALLEL_DIRS = Utils Instrumentation Scalar InstCombine IPO Vectorize Hello NaCl + +ifeq ($(NACL_SANDBOX),1) + PARALLEL_DIRS := $(filter-out Hello, $(PARALLEL_DIRS)) +endif include $(LEVEL)/Makefile.config diff --git a/lib/Transforms/NaCl/CMakeLists.txt b/lib/Transforms/NaCl/CMakeLists.txt new file mode 100644 index 0000000000..d634ad9655 --- /dev/null +++ b/lib/Transforms/NaCl/CMakeLists.txt @@ -0,0 +1,5 @@ +add_llvm_library(LLVMTransformsNaCl + ExpandCtors.cpp + ) + +add_dependencies(LLVMTransformsNaCl intrinsics_gen) diff --git a/lib/Transforms/NaCl/ExpandCtors.cpp b/lib/Transforms/NaCl/ExpandCtors.cpp new file mode 100644 index 0000000000..6b8130e4fb --- /dev/null +++ b/lib/Transforms/NaCl/ExpandCtors.cpp @@ -0,0 +1,145 @@ +//===- ExpandCtors.cpp - Convert ctors/dtors to concrete arrays -----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass converts LLVM's special symbols llvm.global_ctors and +// llvm.global_dtors to concrete arrays, __init_array_start/end and +// __fini_array_start/end, that are usable by a C library. +// +// This pass sorts the contents of global_ctors/dtors according to the +// priority values they contain and removes the priority values. +// +//===----------------------------------------------------------------------===// + +#include <vector> + +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Instructions.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/NaCl.h" +#include "llvm/TypeBuilder.h" + +using namespace llvm; + +namespace { + struct ExpandCtors : public ModulePass { + static char ID; // Pass identification, replacement for typeid + ExpandCtors() : ModulePass(ID) { + initializeExpandCtorsPass(*PassRegistry::getPassRegistry()); + } + + virtual bool runOnModule(Module &M); + }; +} + +char ExpandCtors::ID = 0; +INITIALIZE_PASS(ExpandCtors, "nacl-expand-ctors", + "Hook up constructor and destructor arrays to libc", + false, false) + +static void setGlobalVariableValue(Module &M, const char *Name, + Constant *Value) { + GlobalVariable *Var = M.getNamedGlobal(Name); + if (!Var) { + // This warning can happen in a program that does not use a libc + // and so does not call the functions in __init_array_start or + // __fini_array_end. Such a program might be linked with + // "-nostdlib". + errs() << "Warning: Variable " << Name << " not referenced\n"; + } else { + if (Var->hasInitializer()) { + report_fatal_error(std::string("Variable ") + Name + + " already has an initializer"); + } + Var->replaceAllUsesWith(ConstantExpr::getBitCast(Value, Var->getType())); + Var->eraseFromParent(); + } +} + +struct FuncArrayEntry { + uint64_t priority; + Constant *func; +}; + +static bool compareEntries(FuncArrayEntry Entry1, FuncArrayEntry Entry2) { + return Entry1.priority < Entry2.priority; +} + +static void defineFuncArray(Module &M, const char *LlvmArrayName, + const char *StartSymbol, + const char *EndSymbol) { + std::vector<Constant*> Funcs; + + GlobalVariable *Array = M.getNamedGlobal(LlvmArrayName); + if (Array) { + if (Array->hasInitializer() && !Array->getInitializer()->isNullValue()) { + ConstantArray *InitList = cast<ConstantArray>(Array->getInitializer()); + std::vector<FuncArrayEntry> FuncsToSort; + for (unsigned Index = 0; Index < InitList->getNumOperands(); ++Index) { + ConstantStruct *CS = cast<ConstantStruct>(InitList->getOperand(Index)); + FuncArrayEntry Entry; + Entry.priority = cast<ConstantInt>(CS->getOperand(0))->getZExtValue(); + Entry.func = CS->getOperand(1); + FuncsToSort.push_back(Entry); + } + + std::sort(FuncsToSort.begin(), FuncsToSort.end(), compareEntries); + for (std::vector<FuncArrayEntry>::iterator Iter = FuncsToSort.begin(); + Iter != FuncsToSort.end(); + ++Iter) { + Funcs.push_back(Iter->func); + } + } + // No code should be referencing global_ctors/global_dtors, + // because this symbol is internal to LLVM. + Array->eraseFromParent(); + } + + Type *FuncTy = FunctionType::get(Type::getVoidTy(M.getContext()), false); + Type *FuncPtrTy = FuncTy->getPointerTo(); + ArrayType *ArrayTy = ArrayType::get(FuncPtrTy, Funcs.size()); + GlobalVariable *NewArray = + new GlobalVariable(M, ArrayTy, /* isConstant= */ true, + GlobalValue::InternalLinkage, + ConstantArray::get(ArrayTy, Funcs)); + setGlobalVariableValue(M, StartSymbol, NewArray); + // We do this last so that LLVM gives NewArray the name + // "__{init,fini}_array_start" without adding any suffixes to + // disambiguate from the original GlobalVariable's name. This is + // not essential -- it just makes the output easier to understand + // when looking at symbols for debugging. + NewArray->setName(StartSymbol); + + // We replace "__{init,fini}_array_end" with the address of the end + // of NewArray. This removes the name "__{init,fini}_array_end" + // from the output, which is not ideal for debugging. Ideally we + // would convert "__{init,fini}_array_end" to being a GlobalAlias + // that points to the end of the array. However, unfortunately LLVM + // does not generate correct code when a GlobalAlias contains a + // GetElementPtr ConstantExpr. + Constant *NewArrayEnd = + ConstantExpr::getGetElementPtr(NewArray, + ConstantInt::get(M.getContext(), + APInt(32, 1))); + setGlobalVariableValue(M, EndSymbol, NewArrayEnd); +} + +bool ExpandCtors::runOnModule(Module &M) { + defineFuncArray(M, "llvm.global_ctors", + "__init_array_start", "__init_array_end"); + defineFuncArray(M, "llvm.global_dtors", + "__fini_array_start", "__fini_array_end"); + return true; +} + +ModulePass *llvm::createExpandCtorsPass() { + return new ExpandCtors(); +} diff --git a/lib/Transforms/NaCl/LLVMBuild.txt b/lib/Transforms/NaCl/LLVMBuild.txt new file mode 100644 index 0000000000..2f1522b3e5 --- /dev/null +++ b/lib/Transforms/NaCl/LLVMBuild.txt @@ -0,0 +1,23 @@ +;===- ./lib/Transforms/NaCl/LLVMBuild.txt ----------------------*- Conf -*--===; +; +; The LLVM Compiler Infrastructure +; +; This file is distributed under the University of Illinois Open Source +; License. See LICENSE.TXT for details. +; +;===------------------------------------------------------------------------===; +; +; This is an LLVMBuild description file for the components in this subdirectory. +; +; For more information on the LLVMBuild system, please see: +; +; http://llvm.org/docs/LLVMBuild.html +; +;===------------------------------------------------------------------------===; + +[component_0] +type = Library +name = NaCl +parent = Transforms +library_name = NaCl +required_libraries = Core diff --git a/lib/Transforms/NaCl/Makefile b/lib/Transforms/NaCl/Makefile new file mode 100644 index 0000000000..ecf8db6eae --- /dev/null +++ b/lib/Transforms/NaCl/Makefile @@ -0,0 +1,15 @@ +##===- lib/Transforms/NaCl/Makefile-------------------------*- Makefile -*-===## +# +# The LLVM Compiler Infrastructure +# +# This file is distributed under the University of Illinois Open Source +# License. See LICENSE.TXT for details. +# +##===----------------------------------------------------------------------===## + +LEVEL = ../../.. +LIBRARYNAME = LLVMTransformsNaCl +BUILD_ARCHIVE = 1 + +include $(LEVEL)/Makefile.common + diff --git a/lib/Transforms/Scalar/CMakeLists.txt b/lib/Transforms/Scalar/CMakeLists.txt index b3fc6e338c..06ef4b4a9b 100644 --- a/lib/Transforms/Scalar/CMakeLists.txt +++ b/lib/Transforms/Scalar/CMakeLists.txt @@ -32,6 +32,7 @@ add_llvm_library(LLVMScalarOpts SimplifyLibCalls.cpp Sink.cpp TailRecursionElimination.cpp + NaClCcRewrite.cpp ) add_dependencies(LLVMScalarOpts intrinsics_gen) diff --git a/lib/Transforms/Scalar/NaClCcRewrite.cpp b/lib/Transforms/Scalar/NaClCcRewrite.cpp new file mode 100644 index 0000000000..5eace7f39d --- /dev/null +++ b/lib/Transforms/Scalar/NaClCcRewrite.cpp @@ -0,0 +1,1053 @@ +//===- ConstantProp.cpp - Code to perform Simple Constant Propagation -----===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements calling convention rewrite for Native Client to ensure +// compatibility between pnacl and gcc generated code when calling +// ppapi interface functions. +//===----------------------------------------------------------------------===// + + +// Major TODOs: +// * dealing with vararg +// (We shoulf exclude all var arg functions and calls to them from rewrites) + +#define DEBUG_TYPE "naclcc" + +#include "llvm/Argument.h" +#include "llvm/Attributes.h" +#include "llvm/Constant.h" +#include "llvm/DataLayout.h" +#include "llvm/Instruction.h" +#include "llvm/Instructions.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Function.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Support/InstIterator.h" +#include "llvm/Target/TargetLibraryInfo.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetLowering.h" +#include "llvm/Target/TargetLoweringObjectFile.h" +#include "llvm/Transforms/Scalar.h" + +#include <vector> + +using namespace llvm; + +namespace llvm { + +cl::opt<bool> FlagEnableCcRewrite( + "nacl-cc-rewrite", + cl::desc("enable NaCl CC rewrite")); +} + +namespace { + +// This represents a rule for rewiriting types +struct TypeRewriteRule { + const char* src; // type pattern we are trying to match + const char* dst; // replacement type + const char* name; // name of the rule for diagnosis +}; + +// Note: all rules must be well-formed +// * parentheses must match +// * TODO: add verification for this + +// Legend: +// s(): struct (also used for unions) +// c: char (= 8 bit int) (only allowed for src) +// i: 32 bit int +// l: 64 bit int +// f: 32 bit float +// d: 64 bit float (= double) +// p: untyped pointer (only allowed for src) +// P(): typed pointer (currently not used, only allowed for src) +// F: generic function type (only allowed for src) + +// The X8664 Rewrite rules are also subject to +// register constraints, c.f.: section 3.2.3 +// http://www.x86-64.org/documentation/abi.pdf +// (roughly) for X8664: up to 2 regs per struct can be used for struct passsing +// and up to 2 regs for struct returns +// The rewrite rules are straight forward except for: s(iis(d)) => ll +// which would be straight forward if the frontend had lowered the union inside +// of PP_Var to s(l) instead of s(d), yielding: s(iis(l)) => ll +TypeRewriteRule ByvalRulesX8664[] = { + {"s(iis(d))", "ll", "PP_Var"}, + {"s(pp)", "l", "PP_ArrayOutput"}, + {"s(ppi)", "li", "PP_CompletionCallback"}, + {0, 0, 0}, +}; + +TypeRewriteRule SretRulesX8664[] = { + // Note: for srets, multireg returns are modeled as struct returns + {"s(iis(d))", "s(ll)", "PP_Var"}, + {"s(ff)", "d", "PP_FloatPoint"}, + {"s(ii)", "l", "PP_Point" }, + {"s(pp)", "l", "PP_ArrayOutput"}, + {0, 0, 0}, +}; + +// for ARM: up to 4 regs can be used for struct passsing +// and up to 2 float regs for struct returns +TypeRewriteRule ByvalRulesARM[] = { + {"s(iis(d))", "ll", "PP_Var"}, + {"s(ppi)", "iii", "PP_CompletionCallback" }, + {"s(pp)", "ii", "PP_ArrayOutput"}, + {0, 0, 0}, +}; + +TypeRewriteRule SretRulesARM[] = { + // Note: for srets, multireg returns are modeled as struct returns + {"s(ff)", "s(ff)", "PP_FloatPoint"}, + {0, 0, 0}, +}; + +// Helper class to model Register Usage as required by +// the x86-64 calling conventions +class RegUse { + uint32_t n_int_; + uint32_t n_float_; + + public: + RegUse(uint32_t n_int=0, uint32_t n_float=0) : + n_int_(n_int), n_float_(n_float) {} + + static RegUse OneIntReg() { return RegUse(1, 0); } + static RegUse OnePointerReg() { return RegUse(1, 0); } + static RegUse OneFloatReg() { return RegUse(0, 1); } + + RegUse operator+(RegUse other) const { + return RegUse(n_int_ + other.n_int_, n_float_ + other.n_float_); } + RegUse operator-(RegUse other) const { + return RegUse(n_int_ - other.n_int_, n_float_ - other.n_float_); } + bool operator==(RegUse other) const { + return n_int_ == other.n_int_ && n_float_ == other.n_float_; } + bool operator!=(RegUse other) const { + return n_int_ != other.n_int_ && n_float_ != other.n_float_; } + bool operator<=(RegUse other) const { + return n_int_ <= other.n_int_ && n_float_ <= other.n_float_; } + bool operator<(RegUse other) const { + return n_int_ < other.n_int_ && n_float_ < other.n_float_; } + bool operator>=(RegUse other) const { + return n_int_ >= other.n_int_ && n_float_ >= other.n_float_; } + bool operator>(RegUse other) const { + return n_int_ > other.n_int_ && n_float_ > other.n_float_; } + RegUse& operator+=(const RegUse& other) { + n_int_ += other.n_int_; n_float_ += other.n_float_; return *this;} + RegUse& operator-=(const RegUse& other) { + n_int_ -= other.n_int_; n_float_ -= other.n_float_; return *this;} + + friend raw_ostream& operator<<(raw_ostream &O, const RegUse& reg); +}; + +raw_ostream& operator<<(raw_ostream &O, const RegUse& reg) { + O << "(" << reg.n_int_ << ", " << reg.n_float_ << ")"; + return O; +} + +// TODO: Find a better way to determine the architecture +const TypeRewriteRule* GetByvalRewriteRulesForTarget( + const TargetLowering* tli) { + if (!FlagEnableCcRewrite) return 0; + + const TargetMachine &m = tli->getTargetMachine(); + const StringRef triple = m.getTargetTriple(); + + if (0 == triple.find("x86_64")) return ByvalRulesX8664; + if (0 == triple.find("i686")) return 0; + if (0 == triple.find("armv7a")) return ByvalRulesARM; + + llvm_unreachable("Unknown arch"); + return 0; +} + +// TODO: Find a better way to determine the architecture +const TypeRewriteRule* GetSretRewriteRulesForTarget( + const TargetLowering* tli) { + if (!FlagEnableCcRewrite) return 0; + + const TargetMachine &m = tli->getTargetMachine(); + const StringRef triple = m.getTargetTriple(); + + if (0 == triple.find("x86_64")) return SretRulesX8664; + if (0 == triple.find("i686")) return 0; + if (0 == triple.find("armv7a")) return SretRulesARM; + + llvm_unreachable("Unknown arch"); + return 0; +} + +// TODO: Find a better way to determine the architecture +// Describes the number of registers available for function +// argument passing which may affect rewrite decisions on +// some platforms. +RegUse GetAvailableRegsForTarget( + const TargetLowering* tli) { + if (!FlagEnableCcRewrite) return RegUse(0, 0); + + const TargetMachine &m = tli->getTargetMachine(); + const StringRef triple = m.getTargetTriple(); + + // integer: RDI, RSI, RDX, RCX, R8, R9 + // float XMM0, ..., XMM7 + if (0 == triple.find("x86_64")) return RegUse(6, 8); + // unused + if (0 == triple.find("i686")) return RegUse(0, 0); + // no constraints enforced here - the backend handles all the details + uint32_t max = std::numeric_limits<uint32_t>::max(); + if (0 == triple.find("armv7a")) return RegUse(max, max); + + llvm_unreachable("Unknown arch"); + return 0; +} + +// This class represents the a bitcode rewrite pass which ensures +// that all ppapi interfaces are calling convention compatible +// with gcc. This pass is archtitecture dependent. +struct NaClCcRewrite : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + const TypeRewriteRule* SretRewriteRules; + const TypeRewriteRule* ByvalRewriteRules; + const RegUse AvailableRegs; + + explicit NaClCcRewrite(const TargetLowering *tli = 0) + : FunctionPass(ID), + SretRewriteRules(GetSretRewriteRulesForTarget(tli)), + ByvalRewriteRules(GetByvalRewriteRulesForTarget(tli)), + AvailableRegs(GetAvailableRegsForTarget(tli)) { + initializeNaClCcRewritePass(*PassRegistry::getPassRegistry()); + } + + // main pass entry point + bool runOnFunction(Function &F); + + private: + void RewriteCallsite(Instruction* call, LLVMContext& C); + void RewriteFunctionPrologAndEpilog(Function& F); +}; + +char NaClCcRewrite::ID = 0; + +// This is only used for dst side of rules +Type* GetElementaryType(char c, LLVMContext& C) { + switch (c) { + case 'i': + return Type::getInt32Ty(C); + case 'l': + return Type::getInt64Ty(C); + case 'd': + return Type::getDoubleTy(C); + case 'f': + return Type::getFloatTy(C); + default: + dbgs() << c << "\n"; + llvm_unreachable("Unknown type specifier"); + return 0; + } +} + +// This is only used for the dst side of a rule +int GetElementaryTypeWidth(char c) { + switch (c) { + case 'i': + case 'f': + return 4; + case 'l': + case 'd': + return 8; + default: + llvm_unreachable("Unknown type specifier"); + return 0; + } +} + +// Check whether a type matches the *src* side pattern of a rewrite rule. +// Note that the pattern parameter is updated during the recursion +bool HasRewriteType(const Type* type, const char*& pattern) { + switch (*pattern++) { + case '\0': + return false; + case ')': + return false; + case 's': // struct and union are currently no distinguished + { + if (*pattern++ != '(') llvm_unreachable("malformed type pattern"); + if (!type->isStructTy()) return false; + // check struct members + const StructType* st = cast<StructType>(type); + for (StructType::element_iterator it = st->element_begin(), + end = st->element_end(); + it != end; + ++it) { + if (!HasRewriteType(*it, pattern)) return false; + } + // ensure we reached the end + int c = *pattern++; + return c == ')'; + } + break; + case 'c': + return type->isIntegerTy(8); + case 'i': + return type->isIntegerTy(32); + case 'l': + return type->isIntegerTy(64); + case 'd': + return type->isDoubleTy(); + case 'f': + return type->isFloatTy(); + case 'F': + return type->isFunctionTy(); + case 'p': // untyped pointer + return type->isPointerTy(); + case 'P': // typed pointer + { + if (*pattern++ != '(') llvm_unreachable("malformed type pattern"); + if (!type->isPointerTy()) return false; + Type* pointee = dyn_cast<PointerType>(type)->getElementType(); + if (!HasRewriteType(pointee, pattern)) return false; + int c = *pattern++; + return c == ')'; + } + default: + llvm_unreachable("Unknown type specifier"); + return false; + } +} + +RegUse RegUseForRewriteRule(const TypeRewriteRule* rule) { + const char* pattern = std::string("C") == rule->dst ? rule->src : rule->dst; + RegUse result(0, 0); + while (char c = *pattern++) { + // Note, we only support a subset here, complex types (s, P) + // would require more work + switch (c) { + case 'i': + case 'l': + result += RegUse::OneIntReg(); + break; + case 'd': + case 'f': + result += RegUse::OneFloatReg(); + break; + default: + dbgs() << c << "\n"; + llvm_unreachable("unexpected return type"); + } + } + return result; +} + +// Note, this only has to be accurate for x86-64 and is intentionally +// quite strict so that we know when to add support for new types. +// Ideally, unexpected types would be flagged by a bitcode checker. +RegUse RegUseForType(const Type* t) { + if (t->isPointerTy()) { + return RegUse::OnePointerReg(); + } else if (t->isFloatTy() || t->isDoubleTy()) { + return RegUse::OneFloatReg(); + } else if (t->isIntegerTy()) { + const IntegerType* it = dyn_cast<const IntegerType>(t); + unsigned width = it->getBitWidth(); + // x86-64 assumption here - use "register info" to make this better + if (width <= 64) return RegUse::OneIntReg(); + } + + dbgs() << *const_cast<Type*>(t) << "\n"; + llvm_unreachable("unexpected type in RegUseForType"); +} + +// Match a type against a set of rewrite rules. +// Return the matching rule, if any. +const TypeRewriteRule* MatchRewriteRules( + const Type* type, const TypeRewriteRule* rules) { + if (rules == 0) return 0; + for (; rules->name != 0; ++rules) { + const char* pattern = rules->src; + if (HasRewriteType(type, pattern)) return rules; + } + return 0; +} + +// Same as MatchRewriteRules but "dereference" type first. +const TypeRewriteRule* MatchRewriteRulesPointee(const Type* t, + const TypeRewriteRule* Rules) { + // sret and byval are both modelled as pointers + const PointerType* pointer = dyn_cast<PointerType>(t); + if (pointer == 0) return 0; + + return MatchRewriteRules(pointer->getElementType(), Rules); +} + +// Note, the attributes are not part of the type but are stored +// with the CallInst and/or the Function (if any) +Type* CreateFunctionPointerType(Type* result_type, + std::vector<Type*>& arguments) { + FunctionType* ft = FunctionType::get(result_type, + arguments, + false); + return PointerType::getUnqual(ft); +} + +// Determines whether a function body needs a rewrite +bool FunctionNeedsRewrite(const Function* fun, + const TypeRewriteRule* ByvalRewriteRules, + const TypeRewriteRule* SretRewriteRules, + RegUse available) { + // TODO: can this be detected on indirect callsites as well. + // if we skip the rewrite for the function body + // we also need to skip it at the callsites + // if (F.isVarArg()) return false; + + // Vectors and Arrays are not supported for compatibility + for (Function::const_arg_iterator AI = fun->arg_begin(), AE = fun->arg_end(); + AI != AE; + ++AI) { + const Type* t = AI->getType(); + if (isa<VectorType>(t) || isa<ArrayType>(t)) return false; + } + + for (Function::const_arg_iterator AI = fun->arg_begin(), AE = fun->arg_end(); + AI != AE; + ++AI) { + const Argument& a = *AI; + const Type* t = a.getType(); + // byval and srets are modelled as pointers (to structs) + if (t->isPointerTy()) { + Type* pointee = dyn_cast<PointerType>(t)->getElementType(); + + if (ByvalRewriteRules && a.hasByValAttr()) { + const TypeRewriteRule* rule = + MatchRewriteRules(pointee, ByvalRewriteRules); + if (rule != 0 && RegUseForRewriteRule(rule) <= available) { + return true; + } + } else if (SretRewriteRules && a.hasStructRetAttr()) { + if (0 != MatchRewriteRules(pointee, SretRewriteRules)) { + return true; + } + } + } + available -= RegUseForType(t); + } + return false; +} + +// Used for sret rewrites to determine the new function result type +Type* GetNewReturnType(Type* type, + const TypeRewriteRule* rule, + LLVMContext& C) { + if (std::string("l") == rule->dst || + std::string("d") == rule->dst) { + return GetElementaryType(rule->dst[0], C); + } else if (rule->dst[0] == 's') { + const char* cp = rule->dst + 2; // skip 's(' + std::vector<Type*> fields; + while (*cp != ')') { + fields.push_back(GetElementaryType(*cp, C)); + ++cp; + } + return StructType::get(C, fields, false /* isPacked */); + } else { + dbgs() << *type << " " << rule->name << "\n"; + llvm_unreachable("unexpected return type"); + return 0; + } +} + +// Rewrite sret parameter while rewriting a function +Type* RewriteFunctionSret(Function& F, + Value* orig_val, + const TypeRewriteRule* rule) { + LLVMContext& C = F.getContext(); + BasicBlock& entry = F.getEntryBlock(); + Instruction* before = &(entry.front()); + Type* old_type = orig_val->getType(); + Type* old_pointee = dyn_cast<PointerType>(old_type)->getElementType(); + Type* new_type = GetNewReturnType(old_type, rule, C); + // create a temporary to hold the return value as we no longer pass + // in the pointer + AllocaInst* tmp_ret = new AllocaInst(old_pointee, "result", before); + orig_val->replaceAllUsesWith(tmp_ret); + CastInst* cast_ret = CastInst::CreatePointerCast( + tmp_ret, + PointerType::getUnqual(new_type), + "byval_cast", + before); + for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) { + for (BasicBlock::iterator II = BI->begin(), IE = BI->end(); + II != IE; + /* see below */) { + Instruction* inst = II; + // we do decontructive magic below, so advance the iterator here + // (this is still a little iffy) + ++II; + ReturnInst* ret = dyn_cast<ReturnInst>(inst); + if (ret) { + if (ret->getReturnValue() != 0) + llvm_unreachable("expected a void return"); + // load the return value from temporary + Value *ret_val = new LoadInst(cast_ret, "load_result", ret); + // return that loaded value and delete the return instruction + ReturnInst::Create(C, ret_val, ret); + ret->eraseFromParent(); + } + } + } + return new_type; +} + +// Rewrite one byval function parameter while rewriting a function +void FixFunctionByvalsParameter(Function& F, + std::vector<Argument*>& new_arguments, + std::vector<Attributes>& new_attributes, + Value* byval, + const TypeRewriteRule* rule) { + LLVMContext& C = F.getContext(); + BasicBlock& entry = F.getEntryBlock(); + Instruction* before = &(entry.front()); + Twine prefix = byval->getName() + "_split"; + Type* t = byval->getType(); + Type* pointee = dyn_cast<PointerType>(t)->getElementType(); + AllocaInst* tmp_param = new AllocaInst(pointee, prefix + "_param", before); + byval->replaceAllUsesWith(tmp_param); + // convert byval poiner to char pointer + Value* base = CastInst::CreatePointerCast( + tmp_param, PointerType::getInt8PtrTy(C), prefix + "_base", before); + + int width = 0; + const char* pattern = rule->dst; + for (int offset = 0; *pattern; ++pattern, offset += width) { + width = GetElementaryTypeWidth(*pattern); + Type* t = GetElementaryType(*pattern, C); + Argument* arg = new Argument(t, prefix, &F); + Type* pt = PointerType::getUnqual(t); + // the code below generates something like: + // <CHAR-PTR> = getelementptr i8* <BASE>, i32 <OFFSET-FROM-BASE> + // <PTR> = bitcast i8* <CHAR-PTR> to <TYPE>* + // store <ARG> <TYPE>* <ELEM-PTR> + ConstantInt* baseOffset = ConstantInt::get(Type::getInt32Ty(C), offset); + Value *v; + v = GetElementPtrInst::Create(base, baseOffset, prefix + "_base_add", before); + v = CastInst::CreatePointerCast(v, pt, prefix + "_cast", before); + v = new StoreInst(arg, v, before); + + new_arguments.push_back(arg); + new_attributes.push_back(Attributes()); + } +} + +// Change function signature to reflect all the rewrites. +// This includes function type/signature and attributes. +void UpdateFunctionSignature(Function &F, + Type* new_result_type, + std::vector<Argument*>& new_arguments, + std::vector<Attributes>& new_attributes) { + DEBUG(dbgs() << "PHASE PROTOTYPE UPDATE\n"); + if (new_result_type) { + DEBUG(dbgs() << "NEW RESULT TYPE: " << *new_result_type << "\n"); + } + // Update function type + FunctionType* old_fun_type = F.getFunctionType(); + std::vector<Type*> new_types; + for (size_t i = 0; i < new_arguments.size(); ++i) { + new_types.push_back(new_arguments[i]->getType()); + } + + FunctionType* new_fun_type = FunctionType::get( + new_result_type ? new_result_type : old_fun_type->getReturnType(), + new_types, + false); + F.setType(PointerType::getUnqual(new_fun_type)); + + Function::ArgumentListType& args = F.getArgumentList(); + DEBUG(dbgs() << "PHASE ARGUMENT DEL " << args.size() << "\n"); + while (args.size()) { + Argument* arg = args.begin(); + DEBUG(dbgs() << "DEL " << arg->getArgNo() << " " << arg->getName() << "\n"); + args.remove(args.begin()); + } + + DEBUG(dbgs() << "PHASE ARGUMENT ADD " << new_arguments.size() << "\n"); + for (size_t i = 0; i < new_arguments.size(); ++i) { + Argument* arg = new_arguments[i]; + DEBUG(dbgs() << "ADD " << i << " " << arg->getName() << "\n"); |