diff options
author | Justin Holewinski <justin.holewinski@gmail.com> | 2011-08-09 17:36:31 +0000 |
---|---|---|
committer | Justin Holewinski <justin.holewinski@gmail.com> | 2011-08-09 17:36:31 +0000 |
commit | 4bdd4ed5647f2f9a7b0ccdf6aba920b08ef7b153 (patch) | |
tree | b22c218f4dcf3298477d24d0e0f25edb1cf88bf5 | |
parent | 6d1fd0b979cb88809ebb77a24f4da69e1d67606b (diff) |
PTX: Add initial support for device function calls
- Calls are supported on SM 2.0+ for functions with no return values
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@137125 91177308-0d34-0410-b5e6-96231b3b80d8
-rw-r--r-- | lib/Target/PTX/PTXAsmPrinter.cpp | 43 | ||||
-rw-r--r-- | lib/Target/PTX/PTXISelLowering.cpp | 49 | ||||
-rw-r--r-- | lib/Target/PTX/PTXISelLowering.h | 13 | ||||
-rw-r--r-- | lib/Target/PTX/PTXInstrInfo.td | 26 | ||||
-rw-r--r-- | lib/Target/PTX/PTXMachineFunctionInfo.h | 9 | ||||
-rw-r--r-- | lib/Target/PTX/PTXSubtarget.h | 7 | ||||
-rw-r--r-- | test/CodeGen/PTX/simple-call.ll | 14 |
7 files changed, 158 insertions, 3 deletions
diff --git a/lib/Target/PTX/PTXAsmPrinter.cpp b/lib/Target/PTX/PTXAsmPrinter.cpp index bb48e0ab4b..fc0ec70199 100644 --- a/lib/Target/PTX/PTXAsmPrinter.cpp +++ b/lib/Target/PTX/PTXAsmPrinter.cpp @@ -70,6 +70,8 @@ public: const char *Modifier = 0); void printPredicateOperand(const MachineInstr *MI, raw_ostream &O); + void printCall(const MachineInstr *MI, raw_ostream &O); + unsigned GetOrCreateSourceID(StringRef FileName, StringRef DirName); @@ -242,6 +244,19 @@ void PTXAsmPrinter::EmitFunctionBodyStart() { OutStreamer.EmitRawText(Twine(def)); } } + + unsigned Index = 1; + // Print parameter passing params + for (PTXMachineFunctionInfo::param_iterator + i = MFI->paramBegin(), e = MFI->paramEnd(); i != e; ++i) { + std::string def = "\t.param .b"; + def += utostr(*i); + def += " __ret_"; + def += utostr(Index); + Index++; + def += ";"; + OutStreamer.EmitRawText(Twine(def)); + } } void PTXAsmPrinter::EmitInstruction(const MachineInstr *MI) { @@ -302,7 +317,11 @@ void PTXAsmPrinter::EmitInstruction(const MachineInstr *MI) { printPredicateOperand(MI, OS); // Write instruction to str - printInstruction(MI, OS); + if (MI->getOpcode() == PTX::CALL) { + printCall(MI, OS); + } else { + printInstruction(MI, OS); + } OS << ';'; OS.flush(); @@ -569,6 +588,28 @@ printPredicateOperand(const MachineInstr *MI, raw_ostream &O) { } } +void PTXAsmPrinter:: +printCall(const MachineInstr *MI, raw_ostream &O) { + + O << "\tcall.uni\t"; + + const GlobalValue *Address = MI->getOperand(2).getGlobal(); + O << Address->getName() << ", ("; + + // (0,1) : predicate register/flag + // (2) : callee + for (unsigned i = 3; i < MI->getNumOperands(); ++i) { + //const MachineOperand& MO = MI->getOperand(i); + + printReturnOperand(MI, i, O); + if (i < MI->getNumOperands()-1) { + O << ", "; + } + } + + O << ")"; +} + unsigned PTXAsmPrinter::GetOrCreateSourceID(StringRef FileName, StringRef DirName) { // If FE did not provide a file name, then assume stdin. 
diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp index 6fcf710e3f..d52aa2a01a 100644 --- a/lib/Target/PTX/PTXISelLowering.cpp +++ b/lib/Target/PTX/PTXISelLowering.cpp @@ -22,6 +22,7 @@ #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; @@ -134,6 +135,8 @@ const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const { return "PTXISD::EXIT"; case PTXISD::RET: return "PTXISD::RET"; + case PTXISD::CALL: + return "PTXISD::CALL"; } } @@ -345,3 +348,49 @@ SDValue PTXTargetLowering:: return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag); } } + +SDValue +PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee, + CallingConv::ID CallConv, bool isVarArg, + bool &isTailCall, + const SmallVectorImpl<ISD::OutputArg> &Outs, + const SmallVectorImpl<SDValue> &OutVals, + const SmallVectorImpl<ISD::InputArg> &Ins, + DebugLoc dl, SelectionDAG &DAG, + SmallVectorImpl<SDValue> &InVals) const { + + MachineFunction& MF = DAG.getMachineFunction(); + PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>(); + const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>(); + + assert(ST.callsAreHandled() && "Calls are not handled for the target device"); + + // Is there a more "LLVM"-way to create a variable-length array of values? 
+ SDValue* ops = new SDValue[OutVals.size() + 2]; + + ops[0] = Chain; + + if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Callee)) { + const GlobalValue *GV = G->getGlobal(); + Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy()); + ops[1] = Callee; + } else { + assert(false && "Function must be a GlobalAddressSDNode"); + } + + for (unsigned i = 0; i != OutVals.size(); ++i) { + unsigned Size = OutVals[i].getValueType().getSizeInBits(); + SDValue Index = DAG.getTargetConstant(MFI->getNextParam(Size), MVT::i32); + Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, + Index, OutVals[i]); + ops[i+2] = Index; + } + + ops[0] = Chain; + + Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, ops, OutVals.size()+2); + + delete [] ops; + + return Chain; +} diff --git a/lib/Target/PTX/PTXISelLowering.h b/lib/Target/PTX/PTXISelLowering.h index 43185416e1..f99ac7bc78 100644 --- a/lib/Target/PTX/PTXISelLowering.h +++ b/lib/Target/PTX/PTXISelLowering.h @@ -28,7 +28,8 @@ namespace PTXISD { STORE_PARAM, EXIT, RET, - COPY_ADDRESS + COPY_ADDRESS, + CALL }; } // namespace PTXISD @@ -60,6 +61,16 @@ class PTXTargetLowering : public TargetLowering { DebugLoc dl, SelectionDAG &DAG) const; + virtual SDValue + LowerCall(SDValue Chain, SDValue Callee, + CallingConv::ID CallConv, bool isVarArg, + bool &isTailCall, + const SmallVectorImpl<ISD::OutputArg> &Outs, + const SmallVectorImpl<SDValue> &OutVals, + const SmallVectorImpl<ISD::InputArg> &Ins, + DebugLoc dl, SelectionDAG &DAG, + SmallVectorImpl<SDValue> &InVals) const; + virtual MVT::SimpleValueType getSetCCResultType(EVT VT) const; private: diff --git a/lib/Target/PTX/PTXInstrInfo.td b/lib/Target/PTX/PTXInstrInfo.td index 6bfe906d40..11caa7f1f9 100644 --- a/lib/Target/PTX/PTXInstrInfo.td +++ b/lib/Target/PTX/PTXInstrInfo.td @@ -168,6 +168,18 @@ def MEMret : Operand<i32> { let MIOperandInfo = (ops i32imm); } +// def SDT_PTXCallSeqStart : SDCallSeqStart<[SDTCisVT<0, i32>]>; +// def SDT_PTXCallSeqEnd : 
SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>; + +// def PTXcallseq_start : SDNode<"ISD::CALLSEQ_START", SDT_PTXCallSeqStart, +// [SDNPHasChain, SDNPOutGlue]>; +// def PTXcallseq_end : SDNode<"ISD::CALLSEQ_END", SDT_PTXCallSeqEnd, +// [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue]>; + +def PTXcall : SDNode<"PTXISD::CALL", SDTNone, + [SDNPHasChain, SDNPVariadic, SDNPOptInGlue, SDNPOutGlue]>; + + // Branch & call targets have OtherVT type. def brtarget : Operand<OtherVT>; def calltarget : Operand<i32>; @@ -1073,6 +1085,11 @@ let isReturn = 1, isTerminator = 1, isBarrier = 1 in { def RET : InstPTX<(outs), (ins), "ret", [(PTXret)]>; } +let hasSideEffects = 1 in { + def CALL : InstPTX<(outs), (ins), "call", [(PTXcall)]>; +} + + ///===- Spill Instructions ------------------------------------------------===// // Special instructions used for stack spilling def STACKSTOREI16 : InstPTX<(outs), (ins i32imm:$d, RegI16:$a), @@ -1097,6 +1114,15 @@ def STACKLOADF32 : InstPTX<(outs), (ins RegF32:$d, i32imm:$a), def STACKLOADF64 : InstPTX<(outs), (ins RegF64:$d, i32imm:$a), "mov.f64\t$d, s$a", []>; + +// Call handling +// def ADJCALLSTACKUP : +// InstPTX<(outs), (ins i32imm:$amt1, i32imm:$amt2), "", +// [(PTXcallseq_end timm:$amt1, timm:$amt2)]>; +// def ADJCALLSTACKDOWN : +// InstPTX<(outs), (ins i32imm:$amt), "", +// [(PTXcallseq_start timm:$amt)]>; + ///===- Intrinsic Instructions --------------------------------------------===// include "PTXIntrinsicInstrInfo.td" diff --git a/lib/Target/PTX/PTXMachineFunctionInfo.h b/lib/Target/PTX/PTXMachineFunctionInfo.h index 9d65f5bd1a..a3b0f324fe 100644 --- a/lib/Target/PTX/PTXMachineFunctionInfo.h +++ b/lib/Target/PTX/PTXMachineFunctionInfo.h @@ -27,6 +27,7 @@ private: bool is_kernel; std::vector<unsigned> reg_arg, reg_local_var; std::vector<unsigned> reg_ret; + std::vector<unsigned> call_params; bool _isDoneAddArg; public: @@ -56,6 +57,7 @@ public: typedef std::vector<unsigned>::const_iterator reg_iterator; typedef 
std::vector<unsigned>::const_reverse_iterator reg_reverse_iterator; typedef std::vector<unsigned>::const_iterator ret_iterator; + typedef std::vector<unsigned>::const_iterator param_iterator; bool argRegEmpty() const { return reg_arg.empty(); } int getNumArg() const { return reg_arg.size(); } @@ -73,6 +75,13 @@ public: ret_iterator retRegBegin() const { return reg_ret.begin(); } ret_iterator retRegEnd() const { return reg_ret.end(); } + param_iterator paramBegin() const { return call_params.begin(); } + param_iterator paramEnd() const { return call_params.end(); } + unsigned getNextParam(unsigned size) { + call_params.push_back(size); + return call_params.size()-1; + } + bool isArgReg(unsigned reg) const { return std::find(reg_arg.begin(), reg_arg.end(), reg) != reg_arg.end(); } diff --git a/lib/Target/PTX/PTXSubtarget.h b/lib/Target/PTX/PTXSubtarget.h index 0921f1f22c..0404200992 100644 --- a/lib/Target/PTX/PTXSubtarget.h +++ b/lib/Target/PTX/PTXSubtarget.h @@ -114,7 +114,12 @@ class StringRef; (PTXTarget >= PTX_COMPUTE_2_0 && PTXTarget < PTX_LAST_COMPUTE); } - void ParseSubtargetFeatures(StringRef CPU, StringRef FS); + bool callsAreHandled() const { + return (PTXTarget >= PTX_SM_2_0 && PTXTarget < PTX_LAST_SM) || + (PTXTarget >= PTX_COMPUTE_2_0 && PTXTarget < PTX_LAST_COMPUTE); + } + + void ParseSubtargetFeatures(StringRef CPU, StringRef FS); }; // class PTXSubtarget } // namespace llvm diff --git a/test/CodeGen/PTX/simple-call.ll b/test/CodeGen/PTX/simple-call.ll new file mode 100644 index 0000000000..36f6d8c2a9 --- /dev/null +++ b/test/CodeGen/PTX/simple-call.ll @@ -0,0 +1,14 @@ +; RUN: llc < %s -march=ptx32 -mattr=sm20 | FileCheck %s + +define ptx_device void @test_add(float %x, float %y) { +; CHECK: ret; + %z = fadd float %x, %y + ret void +} + +define ptx_device float @test_call(float %x, float %y) { + %a = fadd float %x, %y +; CHECK: call.uni test_add, (__ret_{{[0-9]+}}, __ret_{{[0-9]+}}); + call void @test_add(float %a, float %y) + ret float %a +} |