aboutsummaryrefslogtreecommitdiff
path: root/lib/Target/PTX
diff options
context:
space:
mode:
authorDan Bailey <dan@dneg.com>2011-11-11 14:45:12 +0000
committerDan Bailey <dan@dneg.com>2011-11-11 14:45:12 +0000
commit96e6458903ab0799542365cac98653c207984162 (patch)
tree814e117c707bad049894f165f2f72f07a11028f8 /lib/Target/PTX
parentb812ee6d7841a617df15aa5d03c8994f223af860 (diff)
allow non-device function calls in PTX when natively handling device-side printf
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@144388 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/Target/PTX')
-rw-r--r--lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp18
-rw-r--r--lib/Target/PTX/PTXAsmPrinter.cpp30
-rw-r--r--lib/Target/PTX/PTXAsmPrinter.h2
-rw-r--r--lib/Target/PTX/PTXISelLowering.cpp102
4 files changed, 129 insertions, 23 deletions
diff --git a/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp b/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp
index aabb404dad..2f6c92d11c 100644
--- a/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp
+++ b/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp
@@ -96,9 +96,23 @@ void PTXInstPrinter::printCall(const MCInst *MI, raw_ostream &O) {
O << "), ";
}
- O << *(MI->getOperand(Index++).getExpr()) << ", (";
-
+ const MCExpr* Expr = MI->getOperand(Index++).getExpr();
unsigned NumArgs = MI->getOperand(Index++).getImm();
+
+ // if the function call is to printf or puts, change to vprintf
+ if (const MCSymbolRefExpr *SymRefExpr = dyn_cast<MCSymbolRefExpr>(Expr)) {
+ const MCSymbol &Sym = SymRefExpr->getSymbol();
+ if (Sym.getName() == "printf" || Sym.getName() == "puts") {
+ O << "vprintf";
+ } else {
+ O << Sym.getName();
+ }
+ } else {
+ O << *Expr;
+ }
+
+ O << ", (";
+
if (NumArgs > 0) {
printOperand(MI, Index++, O);
for (unsigned i = 1; i < NumArgs; ++i) {
diff --git a/lib/Target/PTX/PTXAsmPrinter.cpp b/lib/Target/PTX/PTXAsmPrinter.cpp
index 45a6afc858..bdf238b1b0 100644
--- a/lib/Target/PTX/PTXAsmPrinter.cpp
+++ b/lib/Target/PTX/PTXAsmPrinter.cpp
@@ -165,6 +165,11 @@ void PTXAsmPrinter::EmitStartOfAsmFile(Module &M)
OutStreamer.AddBlankLine();
+ // declare external functions
+ for (Module::const_iterator i = M.begin(), e = M.end();
+ i != e; ++i)
+ EmitFunctionDeclaration(i);
+
// declare global variables
for (Module::const_global_iterator i = M.global_begin(), e = M.global_end();
i != e; ++i)
@@ -454,6 +459,31 @@ void PTXAsmPrinter::EmitFunctionEntryLabel() {
OutStreamer.EmitRawText(os.str());
}
+void PTXAsmPrinter::EmitFunctionDeclaration(const Function* func)
+{
+ const PTXSubtarget& ST = TM.getSubtarget<PTXSubtarget>();
+
+ std::string decl = "";
+
+ // hard-coded emission of extern vprintf function
+
+ if (func->getName() == "printf" || func->getName() == "puts") {
+ decl += ".extern .func (.param .b32 __param_1) vprintf (.param .b";
+ if (ST.is64Bit())
+ decl += "64";
+ else
+ decl += "32";
+ decl += " __param_2, .param .b";
+ if (ST.is64Bit())
+ decl += "64";
+ else
+ decl += "32";
+ decl += " __param_3)\n";
+ }
+
+ OutStreamer.EmitRawText(Twine(decl));
+}
+
unsigned PTXAsmPrinter::GetOrCreateSourceID(StringRef FileName,
StringRef DirName) {
// If FE did not provide a file name, then assume stdin.
diff --git a/lib/Target/PTX/PTXAsmPrinter.h b/lib/Target/PTX/PTXAsmPrinter.h
index 538c0802a2..d5ea4dbc59 100644
--- a/lib/Target/PTX/PTXAsmPrinter.h
+++ b/lib/Target/PTX/PTXAsmPrinter.h
@@ -47,7 +47,7 @@ public:
private:
void EmitVariableDeclaration(const GlobalVariable *gv);
- void EmitFunctionDeclaration();
+ void EmitFunctionDeclaration(const Function* func);
StringMap<unsigned> SourceIdMap;
}; // class PTXAsmPrinter
diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp
index 3307d91a61..7f55871f63 100644
--- a/lib/Target/PTX/PTXISelLowering.cpp
+++ b/lib/Target/PTX/PTXISelLowering.cpp
@@ -20,6 +20,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
@@ -352,40 +353,101 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
SmallVectorImpl<SDValue> &InVals) const {
MachineFunction& MF = DAG.getMachineFunction();
- PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
- PTXParamManager &PM = MFI->getParamManager();
-
+ PTXMachineFunctionInfo *PTXMFI = MF.getInfo<PTXMachineFunctionInfo>();
+ PTXParamManager &PM = PTXMFI->getParamManager();
+ MachineFrameInfo *MFI = MF.getFrameInfo();
+
assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() &&
"Calls are not handled for the target device");
+ // Identify the callee function
+ const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
+ const Function *function = cast<Function>(GV);
+
+ // allow non-device calls only for printf
+ bool isPrintf = function->getName() == "printf" || function->getName() == "puts";
+
+ assert((isPrintf || function->getCallingConv() == CallingConv::PTX_Device) &&
+ "PTX function calls must be to PTX device functions");
+
+ unsigned outSize = isPrintf ? 2 : Outs.size();
+
std::vector<SDValue> Ops;
// The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs]
- Ops.resize(Outs.size() + Ins.size() + 4);
+ Ops.resize(outSize + Ins.size() + 4);
Ops[0] = Chain;
// Identify the callee function
- const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
- assert(cast<Function>(GV)->getCallingConv() == CallingConv::PTX_Device &&
- "PTX function calls must be to PTX device functions");
Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
Ops[Ins.size()+2] = Callee;
- // Generate STORE_PARAM nodes for each function argument. In PTX, function
- // arguments are explicitly stored into .param variables and passed as
- // arguments. There is no register/stack-based calling convention in PTX.
- Ops[Ins.size()+3] = DAG.getTargetConstant(OutVals.size(), MVT::i32);
- for (unsigned i = 0; i != OutVals.size(); ++i) {
- unsigned Size = OutVals[i].getValueType().getSizeInBits();
- unsigned Param = PM.addLocalParam(Size);
- const std::string &ParamName = PM.getParamName(Param);
- SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
- MVT::Other);
+ // #Outs
+ Ops[Ins.size()+3] = DAG.getTargetConstant(outSize, MVT::i32);
+
+ if (isPrintf) {
+ // first argument is the address of the global string variable in memory
+ unsigned Param0 = PM.addLocalParam(getPointerTy().getSizeInBits());
+ SDValue ParamValue0 = DAG.getTargetExternalSymbol(PM.getParamName(Param0).c_str(),
+ MVT::Other);
Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
- ParamValue, OutVals[i]);
- Ops[i+Ins.size()+4] = ParamValue;
- }
+ ParamValue0, OutVals[0]);
+ Ops[Ins.size()+4] = ParamValue0;
+
+ // alignment is the maximum size of all the arguments
+ unsigned alignment = 0;
+ for (unsigned i = 1; i < OutVals.size(); ++i) {
+ alignment = std::max(alignment,
+ OutVals[i].getValueType().getSizeInBits());
+ }
+
+ // size is the alignment multiplied by the number of arguments
+ unsigned size = alignment * (OutVals.size() - 1);
+
+ // second argument is the address of the stack object (unless no arguments)
+ unsigned Param1 = PM.addLocalParam(getPointerTy().getSizeInBits());
+ SDValue ParamValue1 = DAG.getTargetExternalSymbol(PM.getParamName(Param1).c_str(),
+ MVT::Other);
+ Ops[Ins.size()+5] = ParamValue1;
+
+ if (size > 0)
+ {
+ // create a local stack object to store the arguments
+ unsigned StackObject = MFI->CreateStackObject(size / 8, alignment / 8, false);
+ SDValue FrameIndex = DAG.getFrameIndex(StackObject, getPointerTy());
+
+ // store each of the arguments to the stack in turn
+ for (unsigned int i = 1; i != OutVals.size(); i++) {
+ SDValue FrameAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), FrameIndex, DAG.getTargetConstant((i - 1) * 8, getPointerTy()));
+ Chain = DAG.getStore(Chain, dl, OutVals[i], FrameAddr,
+ MachinePointerInfo(),
+ false, false, 0);
+ }
+ // copy the address of the local frame index to get the address in non-local space
+ SDValue genericAddr = DAG.getNode(PTXISD::COPY_ADDRESS, dl, getPointerTy(), FrameIndex);
+
+ // store this address in the second argument
+ Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, ParamValue1, genericAddr);
+ }
+ }
+ else
+ {
+ // Generate STORE_PARAM nodes for each function argument. In PTX, function
+ // arguments are explicitly stored into .param variables and passed as
+ // arguments. There is no register/stack-based calling convention in PTX.
+ for (unsigned i = 0; i != OutVals.size(); ++i) {
+ unsigned Size = OutVals[i].getValueType().getSizeInBits();
+ unsigned Param = PM.addLocalParam(Size);
+ const std::string &ParamName = PM.getParamName(Param);
+ SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
+ MVT::Other);
+ Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
+ ParamValue, OutVals[i]);
+ Ops[i+Ins.size()+4] = ParamValue;
+ }
+ }
+
std::vector<SDValue> InParams;
// Generate list of .param variables to hold the return value(s).