aboutsummaryrefslogtreecommitdiff
path: root/lib/Transforms
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms')
-rw-r--r--lib/Transforms/CMakeLists.txt1
-rw-r--r--lib/Transforms/IPO/ExtractGV.cpp18
-rw-r--r--lib/Transforms/InstCombine/InstCombineCalls.cpp6
-rw-r--r--lib/Transforms/LLVMBuild.txt2
-rw-r--r--lib/Transforms/Makefile6
-rw-r--r--lib/Transforms/NaCl/CMakeLists.txt5
-rw-r--r--lib/Transforms/NaCl/ExpandCtors.cpp145
-rw-r--r--lib/Transforms/NaCl/LLVMBuild.txt23
-rw-r--r--lib/Transforms/NaCl/Makefile15
-rw-r--r--lib/Transforms/Scalar/CMakeLists.txt1
-rw-r--r--lib/Transforms/Scalar/NaClCcRewrite.cpp1053
11 files changed, 1270 insertions, 5 deletions
diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt
index de1353e6c1..9fa690971a 100644
--- a/lib/Transforms/CMakeLists.txt
+++ b/lib/Transforms/CMakeLists.txt
@@ -5,3 +5,4 @@ add_subdirectory(Scalar)
add_subdirectory(IPO)
add_subdirectory(Vectorize)
add_subdirectory(Hello)
+add_subdirectory(NaCl)
diff --git a/lib/Transforms/IPO/ExtractGV.cpp b/lib/Transforms/IPO/ExtractGV.cpp
index 4c7f0ed236..b2748f2e6c 100644
--- a/lib/Transforms/IPO/ExtractGV.cpp
+++ b/lib/Transforms/IPO/ExtractGV.cpp
@@ -58,6 +58,15 @@ namespace {
continue;
if (I->getName() == "llvm.global_ctors")
continue;
+ // @LOCALMOD-BEGIN - this is likely upstreamable
+ // Note: there will likely be more cases once this
+ // is exercised more thoroughly.
+ if (I->getName() == "llvm.global_dtors")
+ continue;
+ // not observed yet
+ if (I->hasExternalWeakLinkage())
+ continue;
+ // @LOCALMOD-END
}
if (I->hasLocalLinkage())
@@ -72,8 +81,15 @@ namespace {
} else {
if (I->hasAvailableExternallyLinkage())
continue;
+ // @LOCALMOD-BEGIN - this is likely upstreamable
+ // Note: there will likely be more cases once this
+ // is exercised more thoroughly.
+ // observed for pthread_cancel
+ if (I->hasExternalWeakLinkage())
+ continue;
+ // @LOCALMOD-END
}
-
+
if (I->hasLocalLinkage())
I->setVisibility(GlobalValue::HiddenVisibility);
I->setLinkage(GlobalValue::ExternalLinkage);
diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 359bc488f3..0958842d08 100644
--- a/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1148,8 +1148,10 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
// If we are removing arguments to the function, emit an obnoxious warning.
if (FT->getNumParams() < NumActualArgs) {
if (!FT->isVarArg()) {
- errs() << "WARNING: While resolving call to function '"
- << Callee->getName() << "' arguments were dropped!\n";
+ if (Callee->getName() != "main") { // @LOCALMOD
+ errs() << "WARNING: While resolving call to function '"
+ << Callee->getName() << "' arguments were dropped!\n";
+ }
} else {
// Add all of the arguments in their promoted form to the arg list.
for (unsigned i = FT->getNumParams(); i != NumActualArgs; ++i, ++AI) {
diff --git a/lib/Transforms/LLVMBuild.txt b/lib/Transforms/LLVMBuild.txt
index f7bca064c7..001ba5d232 100644
--- a/lib/Transforms/LLVMBuild.txt
+++ b/lib/Transforms/LLVMBuild.txt
@@ -16,7 +16,7 @@
;===------------------------------------------------------------------------===;
[common]
-subdirectories = IPO InstCombine Instrumentation Scalar Utils Vectorize
+subdirectories = IPO InstCombine Instrumentation Scalar Utils Vectorize NaCl
[component_0]
type = Group
diff --git a/lib/Transforms/Makefile b/lib/Transforms/Makefile
index 8b1df92fa2..ae03ff32c5 100644
--- a/lib/Transforms/Makefile
+++ b/lib/Transforms/Makefile
@@ -8,7 +8,11 @@
##===----------------------------------------------------------------------===##
LEVEL = ../..
-PARALLEL_DIRS = Utils Instrumentation Scalar InstCombine IPO Vectorize Hello
+PARALLEL_DIRS = Utils Instrumentation Scalar InstCombine IPO Vectorize Hello NaCl
+
+ifeq ($(NACL_SANDBOX),1)
+ PARALLEL_DIRS := $(filter-out Hello, $(PARALLEL_DIRS))
+endif
include $(LEVEL)/Makefile.config
diff --git a/lib/Transforms/NaCl/CMakeLists.txt b/lib/Transforms/NaCl/CMakeLists.txt
new file mode 100644
index 0000000000..d634ad9655
--- /dev/null
+++ b/lib/Transforms/NaCl/CMakeLists.txt
@@ -0,0 +1,5 @@
+add_llvm_library(LLVMTransformsNaCl
+ ExpandCtors.cpp
+ )
+
+add_dependencies(LLVMTransformsNaCl intrinsics_gen)
diff --git a/lib/Transforms/NaCl/ExpandCtors.cpp b/lib/Transforms/NaCl/ExpandCtors.cpp
new file mode 100644
index 0000000000..6b8130e4fb
--- /dev/null
+++ b/lib/Transforms/NaCl/ExpandCtors.cpp
@@ -0,0 +1,145 @@
+//===- ExpandCtors.cpp - Convert ctors/dtors to concrete arrays -----------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass converts LLVM's special symbols llvm.global_ctors and
+// llvm.global_dtors to concrete arrays, __init_array_start/end and
+// __fini_array_start/end, that are usable by a C library.
+//
+// This pass sorts the contents of global_ctors/dtors according to the
+// priority values they contain and removes the priority values.
+//
+//===----------------------------------------------------------------------===//
+
+#include <algorithm>
+#include <vector>
+
+#include "llvm/Constants.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/Instructions.h"
+#include "llvm/Module.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/NaCl.h"
+#include "llvm/TypeBuilder.h"
+
+using namespace llvm;
+
+namespace {
+  // Module-level pass; the actual transformation lives in runOnModule().
+  struct ExpandCtors : public ModulePass {
+    // Pass identification, replacement for typeid.
+    static char ID;
+
+    // Constructing the pass also registers it with the global PassRegistry.
+    ExpandCtors() : ModulePass(ID) {
+      PassRegistry &Registry = *PassRegistry::getPassRegistry();
+      initializeExpandCtorsPass(Registry);
+    }
+
+    virtual bool runOnModule(Module &M);
+  };
+}
+
+// Definition of the static pass ID declared inside ExpandCtors.
+char ExpandCtors::ID = 0;
+// Registers the pass under the name "nacl-expand-ctors" with the description
+// below.
+// NOTE(review): the two trailing 'false' arguments are presumably the
+// cfg-only / is-analysis flags of INITIALIZE_PASS -- confirm against
+// PassSupport.h.
+INITIALIZE_PASS(ExpandCtors, "nacl-expand-ctors",
+ "Hook up constructor and destructor arrays to libc",
+ false, false)
+
+// Replaces every use of the global variable called Name with Value (bitcast
+// to the variable's type) and then deletes the variable from the module.
+// A missing variable only produces a warning; a variable that already has
+// an initializer is a fatal error.
+static void setGlobalVariableValue(Module &M, const char *Name,
+                                   Constant *Value) {
+  GlobalVariable *Target = M.getNamedGlobal(Name);
+  if (!Target) {
+    // This warning can happen in a program that does not use a libc
+    // and so does not call the functions in __init_array_start or
+    // __fini_array_end.  Such a program might be linked with
+    // "-nostdlib".
+    errs() << "Warning: Variable " << Name << " not referenced\n";
+    return;
+  }
+
+  if (Target->hasInitializer())
+    report_fatal_error(std::string("Variable ") + Name +
+                       " already has an initializer");
+
+  Constant *Replacement = ConstantExpr::getBitCast(Value, Target->getType());
+  Target->replaceAllUsesWith(Replacement);
+  Target->eraseFromParent();
+}
+
+// One (priority, function) pair read out of an llvm.global_ctors /
+// llvm.global_dtors initializer by defineFuncArray().
+struct FuncArrayEntry {
+ // Operand 0 of the entry, zero-extended; used as the sort key.
+ uint64_t priority;
+ // Operand 1 of the entry: the constructor/destructor to call.
+ Constant *func;
+};
+
+// Ordering predicate for sorting: true when Entry1 has a strictly smaller
+// priority value than Entry2.
+static bool compareEntries(FuncArrayEntry Entry1, FuncArrayEntry Entry2) {
+  const uint64_t First = Entry1.priority;
+  const uint64_t Second = Entry2.priority;
+  return First < Second;
+}
+
+// Converts the LLVM-internal array LlvmArrayName (llvm.global_ctors or
+// llvm.global_dtors) into a concrete, internal array of function pointers.
+//
+//  * The entries are ordered by ascending priority value and the priority
+//    values themselves are dropped.  Entries with equal priority keep the
+//    relative order they had in the original array.
+//  * All uses of StartSymbol are redirected to the new array and all uses
+//    of EndSymbol to the address one past the new array.
+//  * The original LLVM array, if present, is erased from the module.
+//
+// If the LLVM array is absent, has no initializer, or has a null
+// initializer, an empty array is emitted.
+static void defineFuncArray(Module &M, const char *LlvmArrayName,
+                            const char *StartSymbol,
+                            const char *EndSymbol) {
+  std::vector<Constant*> Funcs;
+
+  GlobalVariable *Array = M.getNamedGlobal(LlvmArrayName);
+  if (Array) {
+    if (Array->hasInitializer() && !Array->getInitializer()->isNullValue()) {
+      ConstantArray *InitList = cast<ConstantArray>(Array->getInitializer());
+      const unsigned NumEntries = InitList->getNumOperands();
+      std::vector<FuncArrayEntry> FuncsToSort;
+      FuncsToSort.reserve(NumEntries);
+      for (unsigned Index = 0; Index < NumEntries; ++Index) {
+        ConstantStruct *CS = cast<ConstantStruct>(InitList->getOperand(Index));
+        FuncArrayEntry Entry;
+        Entry.priority = cast<ConstantInt>(CS->getOperand(0))->getZExtValue();
+        Entry.func = CS->getOperand(1);
+        FuncsToSort.push_back(Entry);
+      }
+
+      // Use a stable sort: std::sort may reorder entries that share a
+      // priority, which would make the emitted order depend on the
+      // standard library implementation.
+      std::stable_sort(FuncsToSort.begin(), FuncsToSort.end(),
+                       compareEntries);
+      Funcs.reserve(FuncsToSort.size());
+      for (std::vector<FuncArrayEntry>::iterator Iter = FuncsToSort.begin();
+           Iter != FuncsToSort.end();
+           ++Iter) {
+        Funcs.push_back(Iter->func);
+      }
+    }
+    // No code should be referencing global_ctors/global_dtors,
+    // because this symbol is internal to LLVM.
+    Array->eraseFromParent();
+  }
+
+  Type *FuncTy = FunctionType::get(Type::getVoidTy(M.getContext()), false);
+  Type *FuncPtrTy = FuncTy->getPointerTo();
+  ArrayType *ArrayTy = ArrayType::get(FuncPtrTy, Funcs.size());
+  GlobalVariable *NewArray =
+    new GlobalVariable(M, ArrayTy, /* isConstant= */ true,
+                       GlobalValue::InternalLinkage,
+                       ConstantArray::get(ArrayTy, Funcs));
+  setGlobalVariableValue(M, StartSymbol, NewArray);
+  // We do this last so that LLVM gives NewArray the name
+  // "__{init,fini}_array_start" without adding any suffixes to
+  // disambiguate from the original GlobalVariable's name.  This is
+  // not essential -- it just makes the output easier to understand
+  // when looking at symbols for debugging.
+  NewArray->setName(StartSymbol);
+
+  // We replace "__{init,fini}_array_end" with the address of the end
+  // of NewArray.  This removes the name "__{init,fini}_array_end"
+  // from the output, which is not ideal for debugging.  Ideally we
+  // would convert "__{init,fini}_array_end" to being a GlobalAlias
+  // that points to the end of the array.  However, unfortunately LLVM
+  // does not generate correct code when a GlobalAlias contains a
+  // GetElementPtr ConstantExpr.
+  Constant *NewArrayEnd =
+    ConstantExpr::getGetElementPtr(NewArray,
+                                   ConstantInt::get(M.getContext(),
+                                                    APInt(32, 1)));
+  setGlobalVariableValue(M, EndSymbol, NewArrayEnd);
+}
+
+// Pass entry point: expands the constructor array first, then the destructor
+// array.  Always reports the module as modified.
+bool ExpandCtors::runOnModule(Module &M) {
+  // Each row: LLVM array name, start symbol, end symbol.
+  static const char *const Arrays[2][3] = {
+    { "llvm.global_ctors", "__init_array_start", "__init_array_end" },
+    { "llvm.global_dtors", "__fini_array_start", "__fini_array_end" },
+  };
+  for (unsigned Row = 0; Row != 2; ++Row)
+    defineFuncArray(M, Arrays[Row][0], Arrays[Row][1], Arrays[Row][2]);
+  return true;
+}
+
+// Factory for the pass; returns a newly allocated ExpandCtors instance.
+ModulePass *llvm::createExpandCtorsPass() {
+  ModulePass *Pass = new ExpandCtors();
+  return Pass;
+}
diff --git a/lib/Transforms/NaCl/LLVMBuild.txt b/lib/Transforms/NaCl/LLVMBuild.txt
new file mode 100644
index 0000000000..2f1522b3e5
--- /dev/null
+++ b/lib/Transforms/NaCl/LLVMBuild.txt
@@ -0,0 +1,23 @@
+;===- ./lib/Transforms/NaCl/LLVMBuild.txt ----------------------*- Conf -*--===;
+;
+; The LLVM Compiler Infrastructure
+;
+; This file is distributed under the University of Illinois Open Source
+; License. See LICENSE.TXT for details.
+;
+;===------------------------------------------------------------------------===;
+;
+; This is an LLVMBuild description file for the components in this subdirectory.
+;
+; For more information on the LLVMBuild system, please see:
+;
+; http://llvm.org/docs/LLVMBuild.html
+;
+;===------------------------------------------------------------------------===;
+
+[component_0]
+type = Library
+name = NaCl
+parent = Transforms
+library_name = NaCl
+required_libraries = Core
diff --git a/lib/Transforms/NaCl/Makefile b/lib/Transforms/NaCl/Makefile
new file mode 100644
index 0000000000..ecf8db6eae
--- /dev/null
+++ b/lib/Transforms/NaCl/Makefile
@@ -0,0 +1,15 @@
+##===- lib/Transforms/NaCl/Makefile-------------------------*- Makefile -*-===##
+#
+# The LLVM Compiler Infrastructure
+#
+# This file is distributed under the University of Illinois Open Source
+# License. See LICENSE.TXT for details.
+#
+##===----------------------------------------------------------------------===##
+
+LEVEL = ../../..
+LIBRARYNAME = LLVMTransformsNaCl
+BUILD_ARCHIVE = 1
+
+include $(LEVEL)/Makefile.common
+
diff --git a/lib/Transforms/Scalar/CMakeLists.txt b/lib/Transforms/Scalar/CMakeLists.txt
index b3fc6e338c..06ef4b4a9b 100644
--- a/lib/Transforms/Scalar/CMakeLists.txt
+++ b/lib/Transforms/Scalar/CMakeLists.txt
@@ -32,6 +32,7 @@ add_llvm_library(LLVMScalarOpts
SimplifyLibCalls.cpp
Sink.cpp
TailRecursionElimination.cpp
+ NaClCcRewrite.cpp
)
add_dependencies(LLVMScalarOpts intrinsics_gen)
diff --git a/lib/Transforms/Scalar/NaClCcRewrite.cpp b/lib/Transforms/Scalar/NaClCcRewrite.cpp
new file mode 100644
index 0000000000..5eace7f39d
--- /dev/null
+++ b/lib/Transforms/Scalar/NaClCcRewrite.cpp
@@ -0,0 +1,1053 @@
+//===- NaClCcRewrite.cpp - NaCl calling convention rewrite ----------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements calling convention rewrite for Native Client to ensure
+// compatibility between pnacl and gcc generated code when calling
+// ppapi interface functions.
+//===----------------------------------------------------------------------===//
+
+
+// Major TODOs:
+// * dealing with vararg
+// (We should exclude all var arg functions and calls to them from rewrites)
+
+#define DEBUG_TYPE "naclcc"
+
+#include "llvm/Argument.h"
+#include "llvm/Attributes.h"
+#include "llvm/Constant.h"
+#include "llvm/DataLayout.h"
+#include "llvm/Instruction.h"
+#include "llvm/Instructions.h"
+#include "llvm/IntrinsicInst.h"
+#include "llvm/Function.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Support/InstIterator.h"
+#include "llvm/Target/TargetLibraryInfo.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Target/TargetLowering.h"
+#include "llvm/Target/TargetLoweringObjectFile.h"
+#include "llvm/Transforms/Scalar.h"
+
+#include <vector>
+
+using namespace llvm;
+
+namespace llvm {
+
+// Command-line switch "nacl-cc-rewrite".  The rule-table and register-budget
+// lookups in this file return empty results unless this flag is set.
+cl::opt<bool> FlagEnableCcRewrite(
+ "nacl-cc-rewrite",
+ cl::desc("enable NaCl CC rewrite"));
+}
+
+namespace {
+
+// This represents a rule for rewriting types.
+// Tables of rules are terminated by an entry whose fields are all 0.
+struct TypeRewriteRule {
+ const char* src; // type pattern we are trying to match
+ const char* dst; // replacement type
+ const char* name; // name of the rule for diagnosis
+};
+
+// Note: all rules must be well-formed
+// * parentheses must match
+// * TODO: add verification for this
+
+// Legend:
+// s(): struct (also used for unions)
+// c: char (= 8 bit int) (only allowed for src)
+// i: 32 bit int
+// l: 64 bit int
+// f: 32 bit float
+// d: 64 bit float (= double)
+// p: untyped pointer (only allowed for src)
+// P(): typed pointer (currently not used, only allowed for src)
+// F: generic function type (only allowed for src)
+
+// The X8664 Rewrite rules are also subject to
+// register constraints, c.f.: section 3.2.3
+// http://www.x86-64.org/documentation/abi.pdf
+// (roughly) for X8664: up to 2 regs per struct can be used for struct passing
+// and up to 2 regs for struct returns
+// The rewrite rules are straightforward except for: s(iis(d)) => ll
+// which would be straightforward if the frontend had lowered the union inside
+// of PP_Var to s(l) instead of s(d), yielding: s(iis(l)) => ll
+//
+// Byval argument rules for x86-64; the {0, 0, 0} entry terminates the table.
+TypeRewriteRule ByvalRulesX8664[] = {
+ {"s(iis(d))", "ll", "PP_Var"},
+ {"s(pp)", "l", "PP_ArrayOutput"},
+ {"s(ppi)", "li", "PP_CompletionCallback"},
+ {0, 0, 0},
+};
+
+// Struct-return (sret) rules for x86-64; the {0, 0, 0} entry terminates the
+// table.
+TypeRewriteRule SretRulesX8664[] = {
+ // Note: for srets, multireg returns are modeled as struct returns
+ {"s(iis(d))", "s(ll)", "PP_Var"},
+ {"s(ff)", "d", "PP_FloatPoint"},
+ {"s(ii)", "l", "PP_Point" },
+ {"s(pp)", "l", "PP_ArrayOutput"},
+ {0, 0, 0},
+};
+
+// for ARM: up to 4 regs can be used for struct passing
+// and up to 2 float regs for struct returns
+//
+// Byval argument rules for ARM; the {0, 0, 0} entry terminates the table.
+TypeRewriteRule ByvalRulesARM[] = {
+ {"s(iis(d))", "ll", "PP_Var"},
+ {"s(ppi)", "iii", "PP_CompletionCallback" },
+ {"s(pp)", "ii", "PP_ArrayOutput"},
+ {0, 0, 0},
+};
+
+// Struct-return (sret) rules for ARM; the {0, 0, 0} entry terminates the
+// table.
+TypeRewriteRule SretRulesARM[] = {
+ // Note: for srets, multireg returns are modeled as struct returns
+ {"s(ff)", "s(ff)", "PP_FloatPoint"},
+ {0, 0, 0},
+};
+
+// Helper class to model Register Usage as required by
+// the x86-64 calling conventions.
+//
+// A RegUse is a pair of counters: (integer registers, float registers).
+//  * operator== / operator!= compare both counters; != is the exact
+//    negation of ==.
+//  * The ordering operators (<, <=, >, >=) are component-wise: they hold
+//    only if the relation holds for BOTH counters.  They therefore form a
+//    partial order; e.g. (1, 0) and (0, 1) are neither <= nor >= each other.
+//  * The counters are unsigned, so subtraction below zero wraps around.
+class RegUse {
+  uint32_t n_int_;    // number of integer/pointer registers
+  uint32_t n_float_;  // number of floating point registers
+
+ public:
+  // Intentionally not explicit: callers rely on the implicit conversion
+  // from an integer (e.g. "return 0;" in a function returning RegUse).
+  RegUse(uint32_t n_int=0, uint32_t n_float=0) :
+    n_int_(n_int), n_float_(n_float) {}
+
+  static RegUse OneIntReg() { return RegUse(1, 0); }
+  static RegUse OnePointerReg() { return RegUse(1, 0); }
+  static RegUse OneFloatReg() { return RegUse(0, 1); }
+
+  RegUse operator+(RegUse other) const {
+    return RegUse(n_int_ + other.n_int_, n_float_ + other.n_float_); }
+  RegUse operator-(RegUse other) const {
+    return RegUse(n_int_ - other.n_int_, n_float_ - other.n_float_); }
+  bool operator==(RegUse other) const {
+    return n_int_ == other.n_int_ && n_float_ == other.n_float_; }
+  // True if either counter differs.  (The previous '&&' form reported
+  // e.g. (1, 2) and (1, 3) as "not different".)
+  bool operator!=(RegUse other) const {
+    return !(*this == other); }
+  bool operator<=(RegUse other) const {
+    return n_int_ <= other.n_int_ && n_float_ <= other.n_float_; }
+  bool operator<(RegUse other) const {
+    return n_int_ < other.n_int_ && n_float_ < other.n_float_; }
+  bool operator>=(RegUse other) const {
+    return n_int_ >= other.n_int_ && n_float_ >= other.n_float_; }
+  bool operator>(RegUse other) const {
+    return n_int_ > other.n_int_ && n_float_ > other.n_float_; }
+  RegUse& operator+=(const RegUse& other) {
+    n_int_ += other.n_int_; n_float_ += other.n_float_; return *this;}
+  RegUse& operator-=(const RegUse& other) {
+    n_int_ -= other.n_int_; n_float_ -= other.n_float_; return *this;}
+
+  friend raw_ostream& operator<<(raw_ostream &O, const RegUse& reg);
+};
+
+// Prints a RegUse as "(<int regs>, <float regs>)" and returns the stream.
+raw_ostream& operator<<(raw_ostream &O, const RegUse& reg) {
+  return O << "(" << reg.n_int_ << ", " << reg.n_float_ << ")";
+}
+
+// TODO: Find a better way to determine the architecture
+// Selects the byval rule table from the prefix of the target triple.
+// Returns 0 when the rewrite flag is off or for i686.  tli is dereferenced
+// whenever the flag is on.
+const TypeRewriteRule* GetByvalRewriteRulesForTarget(
+    const TargetLowering* tli) {
+  if (FlagEnableCcRewrite) {
+    const StringRef triple = tli->getTargetMachine().getTargetTriple();
+
+    if (triple.find("x86_64") == 0) return ByvalRulesX8664;
+    if (triple.find("i686") == 0) return 0;
+    if (triple.find("armv7a") == 0) return ByvalRulesARM;
+
+    llvm_unreachable("Unknown arch");
+  }
+  return 0;
+}
+
+// TODO: Find a better way to determine the architecture
+// Selects the sret rule table from the prefix of the target triple.
+// Returns 0 when the rewrite flag is off or for i686.  tli is dereferenced
+// whenever the flag is on.
+const TypeRewriteRule* GetSretRewriteRulesForTarget(
+    const TargetLowering* tli) {
+  if (FlagEnableCcRewrite) {
+    const StringRef triple = tli->getTargetMachine().getTargetTriple();
+
+    if (triple.find("x86_64") == 0) return SretRulesX8664;
+    if (triple.find("i686") == 0) return 0;
+    if (triple.find("armv7a") == 0) return SretRulesARM;
+
+    llvm_unreachable("Unknown arch");
+  }
+  return 0;
+}
+
+// TODO: Find a better way to determine the architecture
+// Describes the number of registers available for function
+// argument passing which may affect rewrite decisions on
+// some platforms.  Yields (0, 0) when the rewrite flag is off; tli is
+// dereferenced whenever the flag is on.
+RegUse GetAvailableRegsForTarget(
+    const TargetLowering* tli) {
+  if (!FlagEnableCcRewrite) return RegUse(0, 0);
+
+  const StringRef triple = tli->getTargetMachine().getTargetTriple();
+
+  if (triple.find("x86_64") == 0) {
+    // integer: RDI, RSI, RDX, RCX, R8, R9
+    // float XMM0, ..., XMM7
+    return RegUse(6, 8);
+  }
+  if (triple.find("i686") == 0) {
+    // unused
+    return RegUse(0, 0);
+  }
+  if (triple.find("armv7a") == 0) {
+    // no constraints enforced here - the backend handles all the details
+    const uint32_t unlimited = std::numeric_limits<uint32_t>::max();
+    return RegUse(unlimited, unlimited);
+  }
+
+  llvm_unreachable("Unknown arch");
+  return 0;
+}
+
+// This class represents a bitcode rewrite pass which ensures
+// that all ppapi interfaces are calling convention compatible
+// with gcc. This pass is architecture dependent.
+struct NaClCcRewrite : public FunctionPass {
+ static char ID; // Pass identification, replacement for typeid
+ // Rule tables and register budget chosen once, at construction, from tli.
+ const TypeRewriteRule* SretRewriteRules;
+ const TypeRewriteRule* ByvalRewriteRules;
+ const RegUse AvailableRegs;
+
+ // NOTE(review): tli defaults to 0, but the three lookup helpers below
+ // dereference it whenever -nacl-cc-rewrite is set -- confirm that callers
+ // always pass a TargetLowering in that configuration.
+ explicit NaClCcRewrite(const TargetLowering *tli = 0)
+ : FunctionPass(ID),
+ SretRewriteRules(GetSretRewriteRulesForTarget(tli)),
+ ByvalRewriteRules(GetByvalRewriteRulesForTarget(tli)),
+ AvailableRegs(GetAvailableRegsForTarget(tli)) {
+ initializeNaClCcRewritePass(*PassRegistry::getPassRegistry());
+ }
+
+ // main pass entry point
+ bool runOnFunction(Function &F);
+
+ private:
+ void RewriteCallsite(Instruction* call, LLVMContext& C);
+ void RewriteFunctionPrologAndEpilog(Function& F);
+};
+
+// Definition of the static pass ID declared above.
+char NaClCcRewrite::ID = 0;
+
+// This is only used for dst side of rules.
+// Maps a one-letter type specifier to its LLVM type:
+//   'i' -> i32, 'l' -> i64, 'd' -> double, 'f' -> float.
+// Any other specifier is printed to the debug stream and treated as
+// unreachable.
+Type* GetElementaryType(char c, LLVMContext& C) {
+  if (c == 'i') return Type::getInt32Ty(C);
+  if (c == 'l') return Type::getInt64Ty(C);
+  if (c == 'd') return Type::getDoubleTy(C);
+  if (c == 'f') return Type::getFloatTy(C);
+
+  dbgs() << c << "\n";
+  llvm_unreachable("Unknown type specifier");
+  return 0;
+}
+
+// This is only used for the dst side of a rule.
+// Returns the size in bytes of a one-letter type specifier:
+// 4 for 'i'/'f', 8 for 'l'/'d'.  Any other specifier is unreachable.
+int GetElementaryTypeWidth(char c) {
+  const bool is_four_bytes = (c == 'i' || c == 'f');
+  const bool is_eight_bytes = (c == 'l' || c == 'd');
+  if (!is_four_bytes && !is_eight_bytes) {
+    llvm_unreachable("Unknown type specifier");
+    return 0;
+  }
+  return is_four_bytes ? 4 : 8;
+}
+
+// Check whether a type matches the *src* side pattern of a rewrite rule.
+//
+// The pattern is consumed recursively: 'pattern' is advanced past every
+// character that was examined, so nested calls continue where the previous
+// one stopped.  After a failed match the position of 'pattern' is not
+// meaningful; callers restart from the beginning of the rule.
+//
+// Returns true if 'type' matches the next pattern element, false otherwise
+// (including when the pattern is exhausted or a ')' is hit early).
+// A malformed pattern (missing '(' or an unknown specifier) is unreachable.
+bool HasRewriteType(const Type* type, const char*& pattern) {
+  switch (*pattern++) {
+    case '\0':
+      return false;
+    case ')':
+      return false;
+    case 's':  // struct and union are currently not distinguished
+    {
+      if (*pattern++ != '(') llvm_unreachable("malformed type pattern");
+      if (!type->isStructTy()) return false;
+      // check struct members
+      const StructType* st = cast<StructType>(type);
+      for (StructType::element_iterator it = st->element_begin(),
+               end = st->element_end();
+           it != end;
+           ++it) {
+        if (!HasRewriteType(*it, pattern)) return false;
+      }
+      // ensure we reached the end
+      int c = *pattern++;
+      return c == ')';
+    }
+    case 'c':
+      return type->isIntegerTy(8);
+    case 'i':
+      return type->isIntegerTy(32);
+    case 'l':
+      return type->isIntegerTy(64);
+    case 'd':
+      return type->isDoubleTy();
+    case 'f':
+      return type->isFloatTy();
+    case 'F':
+      return type->isFunctionTy();
+    case 'p':  // untyped pointer
+      return type->isPointerTy();
+    case 'P':  // typed pointer
+    {
+      if (*pattern++ != '(') llvm_unreachable("malformed type pattern");
+      if (!type->isPointerTy()) return false;
+      // The type is known to be a pointer here, so use the asserting
+      // cast<> rather than dereferencing an unchecked dyn_cast<> result.
+      Type* pointee = cast<PointerType>(type)->getElementType();
+      if (!HasRewriteType(pointee, pattern)) return false;
+      int c = *pattern++;
+      return c == ')';
+    }
+    default:
+      llvm_unreachable("Unknown type specifier");
+      return false;
+  }
+}
+
+// Counts the registers named by a rule: one integer register per 'i'/'l'
+// and one float register per 'd'/'f'.  The dst side is scanned unless dst
+// is exactly "C", in which case the src side is scanned instead.
+RegUse RegUseForRewriteRule(const TypeRewriteRule* rule) {
+  const char* cursor = rule->dst;
+  if (std::string("C") == rule->dst) cursor = rule->src;
+
+  RegUse total(0, 0);
+  for (; *cursor != '\0'; ++cursor) {
+    const char spec = *cursor;
+    // Note, we only support a subset here, complex types (s, P)
+    // would require more work
+    if (spec == 'i' || spec == 'l') {
+      total += RegUse::OneIntReg();
+    } else if (spec == 'd' || spec == 'f') {
+      total += RegUse::OneFloatReg();
+    } else {
+      dbgs() << spec << "\n";
+      llvm_unreachable("unexpected return type");
+    }
+  }
+  return total;
+}
+
+// Note, this only has to be accurate for x86-64 and is intentionally
+// quite strict so that we know when to add support for new types.
+// Ideally, unexpected types would be flagged by a bitcode checker.
+//
+// Pointers and integers of at most 64 bits take one integer register;
+// float and double take one float register.  Anything else is printed to
+// the debug stream and treated as unreachable.
+RegUse RegUseForType(const Type* t) {
+  if (t->isPointerTy()) return RegUse::OnePointerReg();
+  if (t->isFloatTy() || t->isDoubleTy()) return RegUse::OneFloatReg();
+
+  if (t->isIntegerTy()) {
+    // x86-64 assumption here - use "register info" to make this better
+    const IntegerType* int_type = dyn_cast<const IntegerType>(t);
+    if (int_type->getBitWidth() <= 64) return RegUse::OneIntReg();
+  }
+
+  dbgs() << *const_cast<Type*>(t) << "\n";
+  llvm_unreachable("unexpected type in RegUseForType");
+}
+
+// Match a type against a set of rewrite rules.
+// Walks the table up to its terminating entry (name == 0) and returns the
+// first rule whose src pattern matches, or 0 if none does or the table
+// pointer itself is 0.
+const TypeRewriteRule* MatchRewriteRules(
+    const Type* type, const TypeRewriteRule* rules) {
+  if (!rules) return 0;
+
+  const TypeRewriteRule* candidate = rules;
+  while (candidate->name != 0) {
+    // HasRewriteType advances its pattern argument, so hand it a copy.
+    const char* cursor = candidate->src;
+    if (HasRewriteType(type, cursor)) return candidate;
+    ++candidate;
+  }
+  return 0;
+}
+
+// Same as MatchRewriteRules but "dereference" type first.
+// Returns 0 when t is not a pointer type.
+const TypeRewriteRule* MatchRewriteRulesPointee(const Type* t,
+                                                const TypeRewriteRule* Rules) {
+  // sret and byval are both modelled as pointers
+  if (const PointerType* as_pointer = dyn_cast<PointerType>(t))
+    return MatchRewriteRules(as_pointer->getElementType(), Rules);
+  return 0;
+}
+
+// Builds the type "pointer to non-vararg function taking 'arguments' and
+// returning 'result_type'".
+// Note, the attributes are not part of the type but are stored
+// with the CallInst and/or the Function (if any)
+Type* CreateFunctionPointerType(Type* result_type,
+                                std::vector<Type*>& arguments) {
+  const bool is_var_arg = false;
+  return PointerType::getUnqual(
+      FunctionType::get(result_type, arguments, is_var_arg));
+}
+
+// Determines whether a function body needs a rewrite.
+//
+// Returns false outright if any argument is a vector or array.  Otherwise
+// the arguments are scanned in order while 'available' is reduced by the
+// register use of each argument already seen; the result is true as soon as
+//  * a byval argument matches a byval rule that still fits into the
+//    remaining registers, or
+//  * an sret argument matches an sret rule.
+bool FunctionNeedsRewrite(const Function* fun,
+                          const TypeRewriteRule* ByvalRewriteRules,
+                          const TypeRewriteRule* SretRewriteRules,
+                          RegUse available) {
+  // TODO: can this be detected on indirect callsites as well.
+  //       if we skip the rewrite for the function body
+  //       we also need to skip it at the callsites
+  // if (F.isVarArg()) return false;
+  typedef Function::const_arg_iterator ArgIter;
+
+  // Vectors and Arrays are not supported for compatibility
+  for (ArgIter arg = fun->arg_begin(), last = fun->arg_end();
+       arg != last;
+       ++arg) {
+    const Type* arg_type = arg->getType();
+    if (isa<VectorType>(arg_type) || isa<ArrayType>(arg_type)) return false;
+  }
+
+  for (ArgIter arg = fun->arg_begin(), last = fun->arg_end();
+       arg != last;
+       ++arg) {
+    const Type* arg_type = arg->getType();
+    // byval and srets are modelled as pointers (to structs)
+    if (arg_type->isPointerTy()) {
+      Type* pointee = dyn_cast<PointerType>(arg_type)->getElementType();
+
+      if (ByvalRewriteRules && arg->hasByValAttr()) {
+        const TypeRewriteRule* matched =
+            MatchRewriteRules(pointee, ByvalRewriteRules);
+        if (matched != 0 && RegUseForRewriteRule(matched) <= available)
+          return true;
+      } else if (SretRewriteRules && arg->hasStructRetAttr()) {
+        if (MatchRewriteRules(pointee, SretRewriteRules) != 0)
+          return true;
+      }
+    }
+    // This argument stays as it is, so it uses up its own registers.
+    available -= RegUseForType(arg_type);
+  }
+  return false;
+}
+
+// Used for sret rewrites to determine the new function result type.
+//  * dst "l" or "d": the corresponding elementary type.
+//  * dst "s(...)":   a non-packed struct of the listed elementary types.
+//  * anything else:  printed to the debug stream and treated as unreachable.
+Type* GetNewReturnType(Type* type,
+                       const TypeRewriteRule* rule,
+                       LLVMContext& C) {
+  const std::string dst(rule->dst);
+  if (dst == "l" || dst == "d")
+    return GetElementaryType(rule->dst[0], C);
+
+  if (rule->dst[0] != 's') {
+    dbgs() << *type << " " << rule->name << "\n";
+    llvm_unreachable("unexpected return type");
+    return 0;
+  }
+
+  std::vector<Type*> fields;
+  // start after the leading 's(' and stop at the closing ')'
+  for (const char* cp = rule->dst + 2; *cp != ')'; ++cp) {
+    fields.push_back(GetElementaryType(*cp, C));
+  }
+  return StructType::get(C, fields, false /* isPacked */);
+}
+
+// Rewrite sret parameter while rewriting a function.
+//
+// Replaces all uses of the sret pointer 'orig_val' with a fresh alloca in the
+// entry block, and replaces every (void) return in F with a return of the
+// value loaded from that alloca, reinterpreted as the new return type derived
+// from 'rule'.  Returns that new return type.
+Type* RewriteFunctionSret(Function& F,
+ Value* orig_val,
+ const TypeRewriteRule* rule) {
+ LLVMContext& C = F.getContext();
+ BasicBlock& entry = F.getEntryBlock();
+ // all new setup instructions go in front of the first entry instruction
+ Instruction* before = &(entry.front());
+ Type* old_type = orig_val->getType();
+ Type* old_pointee = dyn_cast<PointerType>(old_type)->getElementType();
+ Type* new_type = GetNewReturnType(old_type, rule, C);
+ // create a temporary to hold the return value as we no longer pass
+ // in the pointer
+ AllocaInst* tmp_ret = new AllocaInst(old_pointee, "result", before);
+ orig_val->replaceAllUsesWith(tmp_ret);
+ // view the temporary through a pointer to the new return type
+ CastInst* cast_ret = CastInst::CreatePointerCast(
+ tmp_ret,
+ PointerType::getUnqual(new_type),
+ "byval_cast",
+ before);
+ for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) {
+ for (BasicBlock::iterator II = BI->begin(), IE = BI->end();
+ II != IE;
+ /* see below */) {
+ Instruction* inst = II;
+ // the return instruction is erased below, so advance the iterator
+ // before touching it (this is still a little iffy)
+ ++II;
+ ReturnInst* ret = dyn_cast<ReturnInst>(inst);
+ if (ret) {
+ // an sret function returns void before the rewrite
+ if (ret->getReturnValue() != 0)
+ llvm_unreachable("expected a void return");
+ // load the return value from temporary
+ Value *ret_val = new LoadInst(cast_ret, "load_result", ret);
+ // return that loaded value and delete the return instruction
+ ReturnInst::Create(C, ret_val, ret);
+ ret->eraseFromParent();
+ }
+ }
+ }
+ return new_type;
+}
+
+// Rewrite one byval function parameter while rewriting a function.
+//
+// All uses of the byval pointer 'byval' are redirected to a fresh alloca in
+// the entry block.  For every type letter on the dst side of 'rule' a new
+// Argument is created on F and stored into the alloca at the running byte
+// offset, so the alloca ends up holding the struct reassembled from the
+// split-up arguments.
+//
+// Each new argument is appended to 'new_arguments', together with an empty
+// Attributes entry in 'new_attributes'.
+void FixFunctionByvalsParameter(Function& F,
+                                std::vector<Argument*>& new_arguments,
+                                std::vector<Attributes>& new_attributes,
+                                Value* byval,
+                                const TypeRewriteRule* rule) {
+  LLVMContext& C = F.getContext();
+  BasicBlock& entry = F.getEntryBlock();
+  Instruction* before = &(entry.front());
+  // Keep the name prefix in an owned string.  A Twine only points at its
+  // operands, so a Twine local built from the temporary StringRef returned
+  // by getName() would dangle once that statement has finished.
+  const std::string prefix = (byval->getName() + "_split").str();
+  // byval parameters are modelled as pointers, so cast<> is safe here
+  Type* pointee = cast<PointerType>(byval->getType())->getElementType();
+  AllocaInst* tmp_param = new AllocaInst(pointee, prefix + "_param", before);
+  byval->replaceAllUsesWith(tmp_param);
+  // convert byval pointer to char pointer
+  Value* base = CastInst::CreatePointerCast(
+      tmp_param, PointerType::getInt8PtrTy(C), prefix + "_base", before);
+
+  int width = 0;
+  const char* pattern = rule->dst;
+  for (int offset = 0; *pattern; ++pattern, offset += width) {
+    width = GetElementaryTypeWidth(*pattern);
+    Type* elem_type = GetElementaryType(*pattern, C);
+    Argument* arg = new Argument(elem_type, prefix, &F);
+    Type* elem_ptr_type = PointerType::getUnqual(elem_type);
+    // the code below generates something like:
+    // <CHAR-PTR> = getelementptr i8* <BASE>, i32 <OFFSET-FROM-BASE>
+    // <PTR> = bitcast i8* <CHAR-PTR> to <TYPE>*
+    // store <ARG> <TYPE>* <ELEM-PTR>
+    ConstantInt* baseOffset = ConstantInt::get(Type::getInt32Ty(C), offset);
+    Value *v;
+    v = GetElementPtrInst::Create(base, baseOffset, prefix + "_base_add",
+                                  before);
+    v = CastInst::CreatePointerCast(v, elem_ptr_type, prefix + "_cast",
+                                    before);
+    v = new StoreInst(arg, v, before);
+
+    new_arguments.push_back(arg);
+    new_attributes.push_back(Attributes());
+  }
+}
+
+// Change function signature to reflect all the rewrites.
+// This includes function type/signature and attributes.
+void UpdateFunctionSignature(Function &F,
+ Type* new_result_type,
+ std::vector<Argument*>& new_arguments,
+ std::vector<Attributes>& new_attributes) {
+ DEBUG(dbgs() << "PHASE PROTOTYPE UPDATE\n");
+ if (new_result_type) {
+ DEBUG(dbgs() << "NEW RESULT TYPE: " << *new_result_type << "\n");
+ }
+ // Update function type
+ FunctionType* old_fun_type = F.getFunctionType();
+ std::vector<Type*> new_types;
+ for (size_t i = 0; i < new_arguments.size(); ++i) {
+ new_types.push_back(new_arguments[i]->getType());
+ }
+
+ FunctionType* new_fun_type = FunctionType::get(
+ new_result_type ? new_result_type : old_fun_type->getReturnType(),
+ new_types,
+ false);
+ F.setType(PointerType::getUnqual(new_fun_type));
+
+ Function::ArgumentListType& args = F.getArgumentList();
+ DEBUG(dbgs() << "PHASE ARGUMENT DEL " << args.size() << "\n");
+ while (args.size()) {
+ Argument* arg = args.begin();
+ DEBUG(dbgs() << "DEL " << arg->getArgNo() << " " << arg->getName() << "\n");
+ args.remove(args.begin());
+ }
+
+ DEBUG(dbgs() << "PHASE ARGUMENT ADD " << new_arguments.size() << "\n");
+ for (size_t i = 0; i < new_arguments.size(); ++i) {
+ Argument* arg = new_arguments[i];
+ DEBUG(dbgs() << "ADD " << i << " " << arg->getName() << "\n");