aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/CodeGen/SelectionDAG/LegalizeTypes.h1
-rw-r--r--lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp60
-rw-r--r--test/CodeGen/NVPTX/vector-select.ll16
3 files changed, 76 insertions, 1 deletions
diff --git a/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 20b7ce6b15..8464b7d808 100644
--- a/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -578,6 +578,7 @@ private:
// Vector Operand Splitting: <128 x ty> -> 2 x <64 x ty>.
bool SplitVectorOperand(SDNode *N, unsigned OpNo);
+ SDValue SplitVecOp_VSELECT(SDNode *N, unsigned OpNo);
SDValue SplitVecOp_UnaryOp(SDNode *N);
SDValue SplitVecOp_BITCAST(SDNode *N);
diff --git a/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index d51a6eb192..595d83b716 100644
--- a/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1030,7 +1030,9 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::STORE:
Res = SplitVecOp_STORE(cast<StoreSDNode>(N), OpNo);
break;
-
+ case ISD::VSELECT:
+ Res = SplitVecOp_VSELECT(N, OpNo);
+ break;
case ISD::CTTZ:
case ISD::CTLZ:
case ISD::CTPOP:
@@ -1064,6 +1066,62 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
return false;
}
+SDValue DAGTypeLegalizer::SplitVecOp_VSELECT(SDNode *N, unsigned OpNo) {
+ // The only possibility for an illegal operand is the mask, since result type
+ // legalization would have handled this node already otherwise.
+ assert(OpNo == 0 && "Illegal operand must be mask");
+
+ SDValue Mask = N->getOperand(0);
+ SDValue Src0 = N->getOperand(1);
+ SDValue Src1 = N->getOperand(2);
+ DebugLoc DL = N->getDebugLoc();
+ EVT MaskVT = Mask.getValueType();
+ assert(MaskVT.isVector() && "VSELECT without a vector mask?");
+
+ SDValue Lo, Hi;
+ GetSplitVector(N->getOperand(0), Lo, Hi);
+
+ unsigned LoNumElts = Lo.getValueType().getVectorNumElements();
+ unsigned HiNumElts = Hi.getValueType().getVectorNumElements();
+ assert(LoNumElts == HiNumElts && "Asymmetric vector split?");
+
+ EVT LoOpVT = EVT::getVectorVT(*DAG.getContext(),
+ Src0.getValueType().getVectorElementType(),
+ LoNumElts);
+ EVT LoMaskVT = EVT::getVectorVT(*DAG.getContext(),
+ MaskVT.getVectorElementType(),
+ LoNumElts);
+ EVT HiOpVT = EVT::getVectorVT(*DAG.getContext(),
+ Src0.getValueType().getVectorElementType(),
+ HiNumElts);
+ EVT HiMaskVT = EVT::getVectorVT(*DAG.getContext(),
+ MaskVT.getVectorElementType(),
+ HiNumElts);
+
+ SDValue LoOp0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LoOpVT, Src0,
+ DAG.getIntPtrConstant(0));
+ SDValue LoOp1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LoOpVT, Src1,
+ DAG.getIntPtrConstant(0));
+
+ SDValue HiOp0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HiOpVT, Src0,
+ DAG.getIntPtrConstant(LoNumElts));
+ SDValue HiOp1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HiOpVT, Src1,
+ DAG.getIntPtrConstant(LoNumElts));
+
+ SDValue LoMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LoMaskVT, Mask,
+ DAG.getIntPtrConstant(0));
+ SDValue HiMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HiMaskVT, Mask,
+ DAG.getIntPtrConstant(LoNumElts));
+
+ SDValue LoSelect = DAG.getNode(ISD::VSELECT, DL, LoOpVT, LoMask, LoOp0,
+ LoOp1);
+ SDValue HiSelect = DAG.getNode(ISD::VSELECT, DL, HiOpVT, HiMask, HiOp0,
+ HiOp1);
+
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, Src0.getValueType(), LoSelect,
+ HiSelect);
+}
+
SDValue DAGTypeLegalizer::SplitVecOp_UnaryOp(SDNode *N) {
// The result has a legal vector type, but the input needs splitting.
EVT ResVT = N->getValueType(0);
diff --git a/test/CodeGen/NVPTX/vector-select.ll b/test/CodeGen/NVPTX/vector-select.ll
new file mode 100644
index 0000000000..11893df103
--- /dev/null
+++ b/test/CodeGen/NVPTX/vector-select.ll
@@ -0,0 +1,16 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_20
+
+; This test makes sure that vector selects are scalarized by the type legalizer.
+; If not, type legalization will fail.
+
+define void @foo(<2 x i32> addrspace(1)* %def_a, <2 x i32> addrspace(1)* %def_b, <2 x i32> addrspace(1)* %def_c) {
+entry:
+ %tmp4 = load <2 x i32> addrspace(1)* %def_a
+ %tmp6 = load <2 x i32> addrspace(1)* %def_c
+ %tmp8 = load <2 x i32> addrspace(1)* %def_b
+ %0 = icmp sge <2 x i32> %tmp4, zeroinitializer
+ %cond = select <2 x i1> %0, <2 x i32> %tmp6, <2 x i32> %tmp8
+ store <2 x i32> %cond, <2 x i32> addrspace(1)* %def_c
+ ret void
+}