From 96e6458903ab0799542365cac98653c207984162 Mon Sep 17 00:00:00 2001 From: Dan Bailey Date: Fri, 11 Nov 2011 14:45:12 +0000 Subject: allow non-device function calls in PTX when natively handling device-side printf git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@144388 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp | 18 ++++- lib/Target/PTX/PTXAsmPrinter.cpp | 30 ++++++++ lib/Target/PTX/PTXAsmPrinter.h | 2 +- lib/Target/PTX/PTXISelLowering.cpp | 102 +++++++++++++++++++++----- 4 files changed, 129 insertions(+), 23 deletions(-) (limited to 'lib/Target/PTX') diff --git a/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp b/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp index aabb404dad..2f6c92d11c 100644 --- a/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp +++ b/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp @@ -96,9 +96,23 @@ void PTXInstPrinter::printCall(const MCInst *MI, raw_ostream &O) { O << "), "; } - O << *(MI->getOperand(Index++).getExpr()) << ", ("; - + const MCExpr* Expr = MI->getOperand(Index++).getExpr(); unsigned NumArgs = MI->getOperand(Index++).getImm(); + + // if the function call is to printf or puts, change to vprintf + if (const MCSymbolRefExpr *SymRefExpr = dyn_cast(Expr)) { + const MCSymbol &Sym = SymRefExpr->getSymbol(); + if (Sym.getName() == "printf" || Sym.getName() == "puts") { + O << "vprintf"; + } else { + O << Sym.getName(); + } + } else { + O << *Expr; + } + + O << ", ("; + if (NumArgs > 0) { printOperand(MI, Index++, O); for (unsigned i = 1; i < NumArgs; ++i) { diff --git a/lib/Target/PTX/PTXAsmPrinter.cpp b/lib/Target/PTX/PTXAsmPrinter.cpp index 45a6afc858..bdf238b1b0 100644 --- a/lib/Target/PTX/PTXAsmPrinter.cpp +++ b/lib/Target/PTX/PTXAsmPrinter.cpp @@ -165,6 +165,11 @@ void PTXAsmPrinter::EmitStartOfAsmFile(Module &M) OutStreamer.AddBlankLine(); + // declare external functions + for (Module::const_iterator i = M.begin(), e = M.end(); + i != e; ++i) + EmitFunctionDeclaration(i); + // declare global variables for (Module::const_global_iterator i = M.global_begin(), e = M.global_end(); i != e; ++i) @@ -454,6 +459,31 @@ void PTXAsmPrinter::EmitFunctionEntryLabel() { OutStreamer.EmitRawText(os.str()); } +void PTXAsmPrinter::EmitFunctionDeclaration(const Function* func) +{ + const PTXSubtarget& ST = TM.getSubtarget(); + + std::string decl = ""; + + // hard-coded emission of extern vprintf function + + if (func->getName() == "printf" || func->getName() == "puts") { + decl += ".extern .func (.param .b32 __param_1) vprintf (.param .b"; + if (ST.is64Bit()) + decl += "64"; + else + decl += "32"; + decl += " __param_2, .param .b"; + if (ST.is64Bit()) + decl += "64"; + else + decl += "32"; + decl += " __param_3)\n"; + } + + OutStreamer.EmitRawText(Twine(decl)); +} + unsigned PTXAsmPrinter::GetOrCreateSourceID(StringRef FileName, StringRef DirName) { // If FE did not provide a file name, then assume stdin. diff --git a/lib/Target/PTX/PTXAsmPrinter.h b/lib/Target/PTX/PTXAsmPrinter.h index 538c0802a2..d5ea4dbc59 100644 --- a/lib/Target/PTX/PTXAsmPrinter.h +++ b/lib/Target/PTX/PTXAsmPrinter.h @@ -47,7 +47,7 @@ public: private: void EmitVariableDeclaration(const GlobalVariable *gv); - void EmitFunctionDeclaration(); + void EmitFunctionDeclaration(const Function* func); StringMap SourceIdMap; }; // class PTXAsmPrinter diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp index 3307d91a61..7f55871f63 100644 --- a/lib/Target/PTX/PTXISelLowering.cpp +++ b/lib/Target/PTX/PTXISelLowering.cpp @@ -20,6 +20,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" @@ -352,40 +353,101 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee, SmallVectorImpl &InVals) const { MachineFunction& MF = DAG.getMachineFunction(); - PTXMachineFunctionInfo *MFI = MF.getInfo(); - PTXParamManager &PM = MFI->getParamManager(); - + PTXMachineFunctionInfo *PTXMFI = MF.getInfo(); + PTXParamManager &PM = PTXMFI->getParamManager(); + MachineFrameInfo *MFI = MF.getFrameInfo(); + assert(getTargetMachine().getSubtarget().callsAreHandled() && "Calls are not handled for the target device"); + // Identify the callee function + const GlobalValue *GV = cast(Callee)->getGlobal(); + const Function *function = cast(GV); + + // allow non-device calls only for printf + bool isPrintf = function->getName() == "printf" || function->getName() == "puts"; + + assert((isPrintf || function->getCallingConv() == CallingConv::PTX_Device) && + "PTX function calls must be to PTX device functions"); + + unsigned outSize = isPrintf ? 2 : Outs.size(); + std::vector Ops; // The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs] - Ops.resize(Outs.size() + Ins.size() + 4); + Ops.resize(outSize + Ins.size() + 4); Ops[0] = Chain; // Identify the callee function - const GlobalValue *GV = cast(Callee)->getGlobal(); - assert(cast(GV)->getCallingConv() == CallingConv::PTX_Device && - "PTX function calls must be to PTX device functions"); Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy()); Ops[Ins.size()+2] = Callee; - // Generate STORE_PARAM nodes for each function argument. In PTX, function - // arguments are explicitly stored into .param variables and passed as - // arguments. There is no register/stack-based calling convention in PTX. - Ops[Ins.size()+3] = DAG.getTargetConstant(OutVals.size(), MVT::i32); - for (unsigned i = 0; i != OutVals.size(); ++i) { - unsigned Size = OutVals[i].getValueType().getSizeInBits(); - unsigned Param = PM.addLocalParam(Size); - const std::string &ParamName = PM.getParamName(Param); - SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), - MVT::Other); + // #Outs + Ops[Ins.size()+3] = DAG.getTargetConstant(outSize, MVT::i32); + + if (isPrintf) { + // first argument is the address of the global string variable in memory + unsigned Param0 = PM.addLocalParam(getPointerTy().getSizeInBits()); + SDValue ParamValue0 = DAG.getTargetExternalSymbol(PM.getParamName(Param0).c_str(), + MVT::Other); Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, - ParamValue, OutVals[i]); - Ops[i+Ins.size()+4] = ParamValue; - } + ParamValue0, OutVals[0]); + Ops[Ins.size()+4] = ParamValue0; + + // alignment is the maximum size of all the arguments + unsigned alignment = 0; + for (unsigned i = 1; i < OutVals.size(); ++i) { + alignment = std::max(alignment, + OutVals[i].getValueType().getSizeInBits()); + } + + // size is the alignment multiplied by the number of arguments + unsigned size = alignment * (OutVals.size() - 1); + + // second argument is the address of the stack object (unless no arguments) + unsigned Param1 = PM.addLocalParam(getPointerTy().getSizeInBits()); + SDValue ParamValue1 = DAG.getTargetExternalSymbol(PM.getParamName(Param1).c_str(), + MVT::Other); + Ops[Ins.size()+5] = ParamValue1; + + if (size > 0) + { + // create a local stack object to store the arguments + unsigned StackObject = MFI->CreateStackObject(size / 8, alignment / 8, false); + SDValue FrameIndex = DAG.getFrameIndex(StackObject, getPointerTy()); + + // store each of the arguments to the stack in turn + for (unsigned int i = 1; i != OutVals.size(); i++) { + SDValue FrameAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), FrameIndex, DAG.getTargetConstant((i - 1) * 8, getPointerTy())); + Chain = DAG.getStore(Chain, dl, OutVals[i], FrameAddr, + MachinePointerInfo(), + false, false, 0); + } + // copy the address of the local frame index to get the address in non-local space + SDValue genericAddr = DAG.getNode(PTXISD::COPY_ADDRESS, dl, getPointerTy(), FrameIndex); + + // store this address in the second argument + Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, ParamValue1, genericAddr); + } + } + else + { + // Generate STORE_PARAM nodes for each function argument. In PTX, function + // arguments are explicitly stored into .param variables and passed as + // arguments. There is no register/stack-based calling convention in PTX. + for (unsigned i = 0; i != OutVals.size(); ++i) { + unsigned Size = OutVals[i].getValueType().getSizeInBits(); + unsigned Param = PM.addLocalParam(Size); + const std::string &ParamName = PM.getParamName(Param); + SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), + MVT::Other); + Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, + ParamValue, OutVals[i]); + Ops[i+Ins.size()+4] = ParamValue; + } + } + std::vector InParams; // Generate list of .param variables to hold the return value(s). -- cgit v1.2.3-18-g5258