diff options
author | Justin Holewinski <justin.holewinski@gmail.com> | 2011-09-23 16:48:41 +0000 |
---|---|---|
committer | Justin Holewinski <justin.holewinski@gmail.com> | 2011-09-23 16:48:41 +0000 |
commit | 75d809599b52dc13c41b5b7afebc5b4f078395b3 (patch) | |
tree | 256347d6590d5dc2d5ef450d58384ef012f90338 | |
parent | 0353dab90ec502b02cbf2cee845e07d51627248b (diff) |
PTX: Handle function call return values
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@140386 91177308-0d34-0410-b5e6-96231b3b80d8
-rw-r--r-- | lib/Target/PTX/PTXAsmPrinter.cpp | 33 | ||||
-rw-r--r-- | lib/Target/PTX/PTXISelLowering.cpp | 40 | ||||
-rw-r--r-- | test/CodeGen/PTX/simple-call.ll | 13 |
3 files changed, 68 insertions, 18 deletions
diff --git a/lib/Target/PTX/PTXAsmPrinter.cpp b/lib/Target/PTX/PTXAsmPrinter.cpp index 77164cac88..d2b7c5f6b5 100644 --- a/lib/Target/PTX/PTXAsmPrinter.cpp +++ b/lib/Target/PTX/PTXAsmPrinter.cpp @@ -677,21 +677,36 @@ printPredicateOperand(const MachineInstr *MI, raw_ostream &O) { void PTXAsmPrinter:: printCall(const MachineInstr *MI, raw_ostream &O) { - O << "\tcall.uni\t"; + // The first two operands are the predicate slot + unsigned Index = 2; + while (!MI->getOperand(Index).isGlobal()) { + if (Index == 2) { + O << "("; + } else { + O << ", "; + } + printParamOperand(MI, Index, O); + Index++; + } - const GlobalValue *Address = MI->getOperand(2).getGlobal(); - O << Address->getName() << ", ("; + if (Index != 2) { + O << "), "; + } - // (0,1) : predicate register/flag - // (2) : callee - for (unsigned i = 3; i < MI->getNumOperands(); ++i) { - //const MachineOperand& MO = MI->getOperand(i); + assert(MI->getOperand(Index).isGlobal() && + "A GlobalAddress must follow the return arguments"); + + const GlobalValue *Address = MI->getOperand(Index).getGlobal(); + O << Address->getName() << ", ("; + Index++; - printParamOperand(MI, i, O); - if (i < MI->getNumOperands()-1) { + while (Index < MI->getNumOperands()) { + printParamOperand(MI, Index, O); + if (Index < MI->getNumOperands()-1) { O << ", "; } + Index++; } O << ")"; diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp index 3fdfcdf574..053e140efe 100644 --- a/lib/Target/PTX/PTXISelLowering.cpp +++ b/lib/Target/PTX/PTXISelLowering.cpp @@ -16,6 +16,7 @@ #include "PTXMachineFunctionInfo.h" #include "PTXRegisterInfo.h" #include "PTXSubtarget.h" +#include "llvm/Function.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/MachineFunction.h" @@ -440,15 +441,22 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee, assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() && "Calls are not handled for the target device"); - // Is there a more "LLVM"-way to create a variable-length array of values? - SDValue* ops = new SDValue[OutVals.size() + 2]; + std::vector<SDValue> Ops; + // The layout of the ops will be [Chain, Ins, Callee, Outs] + Ops.resize(Outs.size() + Ins.size() + 2); - ops[0] = Chain; + Ops[0] = Chain; if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Callee)) { const GlobalValue *GV = G->getGlobal(); - Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy()); - ops[1] = Callee; + if (const Function *F = dyn_cast<Function>(GV)) { + assert(F->getCallingConv() == CallingConv::PTX_Device && + "PTX function calls must be to PTX device functions"); + Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy()); + Ops[Ins.size()+1] = Callee; + } else { + assert(false && "GlobalValue is not a function"); + } } else { assert(false && "Function must be a GlobalAddressSDNode"); } @@ -459,14 +467,28 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee, SDValue Index = DAG.getTargetConstant(Param, MVT::i32); Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, Index, OutVals[i]); - ops[i+2] = Index; + Ops[i+Ins.size()+2] = Index; } - ops[0] = Chain; + std::vector<unsigned> InParams; - Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, ops, OutVals.size()+2); + for (unsigned i = 0; i < Ins.size(); ++i) { + unsigned Size = Ins[i].VT.getStoreSizeInBits(); + unsigned Param = PM.addLocalParam(Size); + SDValue Index = DAG.getTargetConstant(Param, MVT::i32); + Ops[i+1] = Index; + InParams.push_back(Param); + } - delete [] ops; + Ops[0] = Chain; + + Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, &Ops[0], Ops.size()); + + for (unsigned i = 0; i < Ins.size(); ++i) { + SDValue Index = DAG.getTargetConstant(InParams[i], MVT::i32); + SDValue Load = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain, Index); + InVals.push_back(Load); + } return Chain; } diff --git a/test/CodeGen/PTX/simple-call.ll b/test/CodeGen/PTX/simple-call.ll index 1e980655d3..77ea29eae8 100644 --- a/test/CodeGen/PTX/simple-call.ll +++ b/test/CodeGen/PTX/simple-call.ll @@ -12,3 +12,16 @@ define ptx_device float @test_call(float %x, float %y) { call void @test_add(float %a, float %y) ret float %a } + +define ptx_device float @test_compute(float %x, float %y) { +; CHECK: ret; + %z = fadd float %x, %y + ret float %z +} + +define ptx_device float @test_call_compute(float %x, float %y) { +; CHECK: call.uni (__localparam_{{[0-9]+}}), test_compute, (__localparam_{{[0-9]+}}, __localparam_{{[0-9]+}}) + %z = call float @test_compute(float %x, float %y) + ret float %z +} + |