aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCraig Topper <craig.topper@gmail.com>2011-11-19 07:07:26 +0000
committerCraig Topper <craig.topper@gmail.com>2011-11-19 07:07:26 +0000
commit1666cb6d637af89a752d2be938be53be5253bdfd (patch)
tree2d4e1908aa224a8ec6fa456b9207a4a1a1474957
parent787a88ff18b3fe9e8f66eda63544eb10a0c806ef (diff)
Extend VPBLENDVB and VPSIGN lowering to work for AVX2.
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@144987 91177308-0d34-0410-b5e6-96231b3b80d8
-rw-r--r--lib/Target/X86/X86ISelLowering.cpp169
-rw-r--r--lib/Target/X86/X86InstrFragmentsSIMD.td6
-rw-r--r--lib/Target/X86/X86InstrSSE.td63
-rw-r--r--test/CodeGen/X86/avx2-logic.ll29
4 files changed, 156 insertions, 111 deletions
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 6a14f220a4..b45d3f617a 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -13859,98 +13859,105 @@ static SDValue PerformOrCombine(SDNode *N, SelectionDAG &DAG,
return R;
EVT VT = N->getValueType(0);
- if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64 && VT != MVT::v2i64)
- return SDValue();
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
// look for psign/blend
- if (Subtarget->hasSSSE3() || Subtarget->hasAVX()) {
- if (VT == MVT::v2i64) {
- // Canonicalize pandn to RHS
- if (N0.getOpcode() == X86ISD::ANDNP)
- std::swap(N0, N1);
- // or (and (m, x), (pandn m, y))
- if (N0.getOpcode() == ISD::AND && N1.getOpcode() == X86ISD::ANDNP) {
- SDValue Mask = N1.getOperand(0);
- SDValue X = N1.getOperand(1);
- SDValue Y;
- if (N0.getOperand(0) == Mask)
- Y = N0.getOperand(1);
- if (N0.getOperand(1) == Mask)
- Y = N0.getOperand(0);
-
- // Check to see if the mask appeared in both the AND and ANDNP and
- if (!Y.getNode())
- return SDValue();
-
- // Validate that X, Y, and Mask are BIT_CONVERTS, and see through them.
- if (Mask.getOpcode() != ISD::BITCAST ||
- X.getOpcode() != ISD::BITCAST ||
- Y.getOpcode() != ISD::BITCAST)
- return SDValue();
-
- // Look through mask bitcast.
- Mask = Mask.getOperand(0);
- EVT MaskVT = Mask.getValueType();
-
- // Validate that the Mask operand is a vector sra node. The sra node
- // will be an intrinsic.
- if (Mask.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
- return SDValue();
-
- // FIXME: what to do for bytes, since there is a psignb/pblendvb, but
- // there is no psrai.b
- switch (cast<ConstantSDNode>(Mask.getOperand(0))->getZExtValue()) {
- case Intrinsic::x86_sse2_psrai_w:
- case Intrinsic::x86_sse2_psrai_d:
- break;
- default: return SDValue();
- }
+ if (VT == MVT::v2i64 || VT == MVT::v4i64) {
+ if (!(Subtarget->hasSSSE3() || Subtarget->hasAVX()) ||
+ (VT == MVT::v4i64 && !Subtarget->hasAVX2()))
+ return SDValue();
- // Check that the SRA is all signbits.
- SDValue SraC = Mask.getOperand(2);
- unsigned SraAmt = cast<ConstantSDNode>(SraC)->getZExtValue();
- unsigned EltBits = MaskVT.getVectorElementType().getSizeInBits();
- if ((SraAmt + 1) != EltBits)
- return SDValue();
-
- DebugLoc DL = N->getDebugLoc();
-
- // Now we know we at least have a plendvb with the mask val. See if
- // we can form a psignb/w/d.
- // psign = x.type == y.type == mask.type && y = sub(0, x);
- X = X.getOperand(0);
- Y = Y.getOperand(0);
- if (Y.getOpcode() == ISD::SUB && Y.getOperand(1) == X &&
- ISD::isBuildVectorAllZeros(Y.getOperand(0).getNode()) &&
- X.getValueType() == MaskVT && X.getValueType() == Y.getValueType()){
- unsigned Opc = 0;
- switch (EltBits) {
- case 8: Opc = X86ISD::PSIGNB; break;
- case 16: Opc = X86ISD::PSIGNW; break;
- case 32: Opc = X86ISD::PSIGND; break;
- default: break;
- }
- if (Opc) {
- SDValue Sign = DAG.getNode(Opc, DL, MaskVT, X, Mask.getOperand(1));
- return DAG.getNode(ISD::BITCAST, DL, MVT::v2i64, Sign);
- }
+ // Canonicalize pandn to RHS
+ if (N0.getOpcode() == X86ISD::ANDNP)
+ std::swap(N0, N1);
+ // or (and (m, x), (pandn m, y))
+ if (N0.getOpcode() == ISD::AND && N1.getOpcode() == X86ISD::ANDNP) {
+ SDValue Mask = N1.getOperand(0);
+ SDValue X = N1.getOperand(1);
+ SDValue Y;
+ if (N0.getOperand(0) == Mask)
+ Y = N0.getOperand(1);
+ if (N0.getOperand(1) == Mask)
+ Y = N0.getOperand(0);
+
+ // Check to see if the mask appeared in both the AND and ANDNP and
+ if (!Y.getNode())
+ return SDValue();
+
+ // Validate that X, Y, and Mask are BIT_CONVERTS, and see through them.
+ if (Mask.getOpcode() != ISD::BITCAST ||
+ X.getOpcode() != ISD::BITCAST ||
+ Y.getOpcode() != ISD::BITCAST)
+ return SDValue();
+
+ // Look through mask bitcast.
+ Mask = Mask.getOperand(0);
+ EVT MaskVT = Mask.getValueType();
+
+ // Validate that the Mask operand is a vector sra node. The sra node
+ // will be an intrinsic.
+ if (Mask.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
+ return SDValue();
+
+ // FIXME: what to do for bytes, since there is a psignb/pblendvb, but
+ // there is no psrai.b
+ switch (cast<ConstantSDNode>(Mask.getOperand(0))->getZExtValue()) {
+ case Intrinsic::x86_sse2_psrai_w:
+ case Intrinsic::x86_sse2_psrai_d:
+ case Intrinsic::x86_avx2_psrai_w:
+ case Intrinsic::x86_avx2_psrai_d:
+ break;
+ default: return SDValue();
+ }
+
+ // Check that the SRA is all signbits.
+ SDValue SraC = Mask.getOperand(2);
+ unsigned SraAmt = cast<ConstantSDNode>(SraC)->getZExtValue();
+ unsigned EltBits = MaskVT.getVectorElementType().getSizeInBits();
+ if ((SraAmt + 1) != EltBits)
+ return SDValue();
+
+ DebugLoc DL = N->getDebugLoc();
+
+ // Now we know we at least have a plendvb with the mask val. See if
+ // we can form a psignb/w/d.
+ // psign = x.type == y.type == mask.type && y = sub(0, x);
+ X = X.getOperand(0);
+ Y = Y.getOperand(0);
+ if (Y.getOpcode() == ISD::SUB && Y.getOperand(1) == X &&
+ ISD::isBuildVectorAllZeros(Y.getOperand(0).getNode()) &&
+ X.getValueType() == MaskVT && X.getValueType() == Y.getValueType()){
+ unsigned Opc = 0;
+ switch (EltBits) {
+ case 8: Opc = X86ISD::PSIGNB; break;
+ case 16: Opc = X86ISD::PSIGNW; break;
+ case 32: Opc = X86ISD::PSIGND; break;
+ default: break;
+ }
+ if (Opc) {
+ SDValue Sign = DAG.getNode(Opc, DL, MaskVT, X, Mask.getOperand(1));
+ return DAG.getNode(ISD::BITCAST, DL, VT, Sign);
}
- // PBLENDVB only available on SSE 4.1
- if (!(Subtarget->hasSSE41() || Subtarget->hasAVX()))
- return SDValue();
-
- X = DAG.getNode(ISD::BITCAST, DL, MVT::v16i8, X);
- Y = DAG.getNode(ISD::BITCAST, DL, MVT::v16i8, Y);
- Mask = DAG.getNode(ISD::BITCAST, DL, MVT::v16i8, Mask);
- Mask = DAG.getNode(ISD::VSELECT, DL, MVT::v16i8, Mask, X, Y);
- return DAG.getNode(ISD::BITCAST, DL, MVT::v2i64, Mask);
}
+ // PBLENDVB only available on SSE 4.1
+ if (!(Subtarget->hasSSE41() || Subtarget->hasAVX()))
+ return SDValue();
+
+ EVT BlendVT = (VT == MVT::v4i64) ? MVT::v32i8 : MVT::v16i8;
+
+ X = DAG.getNode(ISD::BITCAST, DL, BlendVT, X);
+ Y = DAG.getNode(ISD::BITCAST, DL, BlendVT, Y);
+ Mask = DAG.getNode(ISD::BITCAST, DL, BlendVT, Mask);
+ Mask = DAG.getNode(ISD::VSELECT, DL, BlendVT, Mask, X, Y);
+ return DAG.getNode(ISD::BITCAST, DL, VT, Mask);
}
}
+ if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
+ return SDValue();
+
// fold (or (x << c) | (y >> (64 - c))) ==> (shld64 x, y, c)
if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
std::swap(N0, N1);
diff --git a/lib/Target/X86/X86InstrFragmentsSIMD.td b/lib/Target/X86/X86InstrFragmentsSIMD.td
index 6fd2efdab8..3a2ba184d4 100644
--- a/lib/Target/X86/X86InstrFragmentsSIMD.td
+++ b/lib/Target/X86/X86InstrFragmentsSIMD.td
@@ -52,13 +52,13 @@ def X86andnp : SDNode<"X86ISD::ANDNP",
SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>,
SDTCisSameAs<0,2>]>>;
def X86psignb : SDNode<"X86ISD::PSIGNB",
- SDTypeProfile<1, 2, [SDTCisVT<0, v16i8>, SDTCisSameAs<0,1>,
+ SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>,
SDTCisSameAs<0,2>]>>;
def X86psignw : SDNode<"X86ISD::PSIGNW",
- SDTypeProfile<1, 2, [SDTCisVT<0, v8i16>, SDTCisSameAs<0,1>,
+ SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>,
SDTCisSameAs<0,2>]>>;
def X86psignd : SDNode<"X86ISD::PSIGND",
- SDTypeProfile<1, 2, [SDTCisVT<0, v4i32>, SDTCisSameAs<0,1>,
+ SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>,
SDTCisSameAs<0,2>]>>;
def X86pextrb : SDNode<"X86ISD::PEXTRB",
SDTypeProfile<1, 2, [SDTCisVT<0, i32>, SDTCisPtrTy<2>]>>;
diff --git a/lib/Target/X86/X86InstrSSE.td b/lib/Target/X86/X86InstrSSE.td
index c66240fae4..6be366bf13 100644
--- a/lib/Target/X86/X86InstrSSE.td
+++ b/lib/Target/X86/X86InstrSSE.td
@@ -3824,51 +3824,51 @@ let ExeDomain = SSEPackedInt in {
let Predicates = [HasAVX] in {
def : Pat<(int_x86_sse2_psll_dq VR128:$src1, imm:$src2),
- (v2i64 (VPSLLDQri VR128:$src1, (BYTE_imm imm:$src2)))>;
+ (VPSLLDQri VR128:$src1, (BYTE_imm imm:$src2))>;
def : Pat<(int_x86_sse2_psrl_dq VR128:$src1, imm:$src2),
- (v2i64 (VPSRLDQri VR128:$src1, (BYTE_imm imm:$src2)))>;
+ (VPSRLDQri VR128:$src1, (BYTE_imm imm:$src2))>;
def : Pat<(int_x86_sse2_psll_dq_bs VR128:$src1, imm:$src2),
- (v2i64 (VPSLLDQri VR128:$src1, imm:$src2))>;
+ (VPSLLDQri VR128:$src1, imm:$src2)>;
def : Pat<(int_x86_sse2_psrl_dq_bs VR128:$src1, imm:$src2),
- (v2i64 (VPSRLDQri VR128:$src1, imm:$src2))>;
+ (VPSRLDQri VR128:$src1, imm:$src2)>;
def : Pat<(v2f64 (X86fsrl VR128:$src1, i32immSExt8:$src2)),
- (v2f64 (VPSRLDQri VR128:$src1, (BYTE_imm imm:$src2)))>;
+ (VPSRLDQri VR128:$src1, (BYTE_imm imm:$src2))>;
// Shift up / down and insert zero's.
def : Pat<(v2i64 (X86vshl VR128:$src, (i8 imm:$amt))),
- (v2i64 (VPSLLDQri VR128:$src, (BYTE_imm imm:$amt)))>;
+ (VPSLLDQri VR128:$src, (BYTE_imm imm:$amt))>;
def : Pat<(v2i64 (X86vshr VR128:$src, (i8 imm:$amt))),
- (v2i64 (VPSRLDQri VR128:$src, (BYTE_imm imm:$amt)))>;
+ (VPSRLDQri VR128:$src, (BYTE_imm imm:$amt))>;
}
let Predicates = [HasAVX2] in {
def : Pat<(int_x86_avx2_psll_dq VR256:$src1, imm:$src2),
- (v4i64 (VPSLLDQYri VR256:$src1, (BYTE_imm imm:$src2)))>;
+ (VPSLLDQYri VR256:$src1, (BYTE_imm imm:$src2))>;
def : Pat<(int_x86_avx2_psrl_dq VR256:$src1, imm:$src2),
- (v4i64 (VPSRLDQYri VR256:$src1, (BYTE_imm imm:$src2)))>;
+ (VPSRLDQYri VR256:$src1, (BYTE_imm imm:$src2))>;
def : Pat<(int_x86_avx2_psll_dq_bs VR256:$src1, imm:$src2),
- (v4i64 (VPSLLDQYri VR256:$src1, imm:$src2))>;
+ (VPSLLDQYri VR256:$src1, imm:$src2)>;
def : Pat<(int_x86_avx2_psrl_dq_bs VR256:$src1, imm:$src2),
- (v4i64 (VPSRLDQYri VR256:$src1, imm:$src2))>;
+ (VPSRLDQYri VR256:$src1, imm:$src2)>;
}
let Predicates = [HasSSE2] in {
def : Pat<(int_x86_sse2_psll_dq VR128:$src1, imm:$src2),
- (v2i64 (PSLLDQri VR128:$src1, (BYTE_imm imm:$src2)))>;
+ (PSLLDQri VR128:$src1, (BYTE_imm imm:$src2))>;
def : Pat<(int_x86_sse2_psrl_dq VR128:$src1, imm:$src2),
- (v2i64 (PSRLDQri VR128:$src1, (BYTE_imm imm:$src2)))>;
+ (PSRLDQri VR128:$src1, (BYTE_imm imm:$src2))>;
def : Pat<(int_x86_sse2_psll_dq_bs VR128:$src1, imm:$src2),
- (v2i64 (PSLLDQri VR128:$src1, imm:$src2))>;
+ (PSLLDQri VR128:$src1, imm:$src2)>;
def : Pat<(int_x86_sse2_psrl_dq_bs VR128:$src1, imm:$src2),
- (v2i64 (PSRLDQri VR128:$src1, imm:$src2))>;
+ (PSRLDQri VR128:$src1, imm:$src2)>;
def : Pat<(v2f64 (X86fsrl VR128:$src1, i32immSExt8:$src2)),
- (v2f64 (PSRLDQri VR128:$src1, (BYTE_imm imm:$src2)))>;
+ (PSRLDQri VR128:$src1, (BYTE_imm imm:$src2))>;
// Shift up / down and insert zero's.
def : Pat<(v2i64 (X86vshl VR128:$src, (i8 imm:$amt))),
- (v2i64 (PSLLDQri VR128:$src, (BYTE_imm imm:$amt)))>;
+ (PSLLDQri VR128:$src, (BYTE_imm imm:$amt))>;
def : Pat<(v2i64 (X86vshr VR128:$src, (i8 imm:$amt))),
- (v2i64 (PSRLDQri VR128:$src, (BYTE_imm imm:$amt)))>;
+ (PSRLDQri VR128:$src, (BYTE_imm imm:$amt))>;
}
//===---------------------------------------------------------------------===//
@@ -5316,11 +5316,11 @@ let isCommutable = 0 in {
int_x86_avx2_pmadd_ub_sw>, VEX_4V;
defm VPSHUFB : SS3I_binop_rm_int_y<0x00, "vpshufb", memopv32i8,
int_x86_avx2_pshuf_b>, VEX_4V;
- defm VPSIGNB : SS3I_binop_rm_int_y<0x08, "vpsignb", memopv16i8,
+ defm VPSIGNB : SS3I_binop_rm_int_y<0x08, "vpsignb", memopv32i8,
int_x86_avx2_psign_b>, VEX_4V;
- defm VPSIGNW : SS3I_binop_rm_int_y<0x09, "vpsignw", memopv8i16,
+ defm VPSIGNW : SS3I_binop_rm_int_y<0x09, "vpsignw", memopv16i16,
int_x86_avx2_psign_w>, VEX_4V;
- defm VPSIGND : SS3I_binop_rm_int_y<0x0A, "vpsignd", memopv4i32,
+ defm VPSIGND : SS3I_binop_rm_int_y<0x0A, "vpsignd", memopv8i32,
int_x86_avx2_psign_d>, VEX_4V;
}
defm VPMULHRSW : SS3I_binop_rm_int_y<0x0B, "vpmulhrsw", memopv16i16,
@@ -5363,11 +5363,11 @@ let Predicates = [HasSSSE3] in {
def : Pat<(X86pshufb VR128:$src, (bc_v16i8 (memopv2i64 addr:$mask))),
(PSHUFBrm128 VR128:$src, addr:$mask)>;
- def : Pat<(X86psignb VR128:$src1, VR128:$src2),
+ def : Pat<(v16i8 (X86psignb VR128:$src1, VR128:$src2)),
(PSIGNBrr128 VR128:$src1, VR128:$src2)>;
- def : Pat<(X86psignw VR128:$src1, VR128:$src2),
+ def : Pat<(v8i16 (X86psignw VR128:$src1, VR128:$src2)),
(PSIGNWrr128 VR128:$src1, VR128:$src2)>;
- def : Pat<(X86psignd VR128:$src1, VR128:$src2),
+ def : Pat<(v4i32 (X86psignd VR128:$src1, VR128:$src2)),
(PSIGNDrr128 VR128:$src1, VR128:$src2)>;
}
@@ -5377,14 +5377,23 @@ let Predicates = [HasAVX] in {
def : Pat<(X86pshufb VR128:$src, (bc_v16i8 (memopv2i64 addr:$mask))),
(VPSHUFBrm128 VR128:$src, addr:$mask)>;
- def : Pat<(X86psignb VR128:$src1, VR128:$src2),
+ def : Pat<(v16i8 (X86psignb VR128:$src1, VR128:$src2)),
(VPSIGNBrr128 VR128:$src1, VR128:$src2)>;
- def : Pat<(X86psignw VR128:$src1, VR128:$src2),
+ def : Pat<(v8i16 (X86psignw VR128:$src1, VR128:$src2)),
(VPSIGNWrr128 VR128:$src1, VR128:$src2)>;
- def : Pat<(X86psignd VR128:$src1, VR128:$src2),
+ def : Pat<(v4i32 (X86psignd VR128:$src1, VR128:$src2)),
(VPSIGNDrr128 VR128:$src1, VR128:$src2)>;
}
+let Predicates = [HasAVX2] in {
+ def : Pat<(v32i8 (X86psignb VR256:$src1, VR256:$src2)),
+ (VPSIGNBrr256 VR256:$src1, VR256:$src2)>;
+ def : Pat<(v16i16 (X86psignw VR256:$src1, VR256:$src2)),
+ (VPSIGNWrr256 VR256:$src1, VR256:$src2)>;
+ def : Pat<(v8i32 (X86psignd VR256:$src1, VR256:$src2)),
+ (VPSIGNDrr256 VR256:$src1, VR256:$src2)>;
+}
+
//===---------------------------------------------------------------------===//
// SSSE3 - Packed Align Instruction Patterns
//===---------------------------------------------------------------------===//
diff --git a/test/CodeGen/X86/avx2-logic.ll b/test/CodeGen/X86/avx2-logic.ll
index f1c294c066..aa3d37d553 100644
--- a/test/CodeGen/X86/avx2-logic.ll
+++ b/test/CodeGen/X86/avx2-logic.ll
@@ -53,3 +53,32 @@ define <32 x i8> @vpblendvb(<32 x i8> %x, <32 x i8> %y) {
%min = select <32 x i1> %min_is_x, <32 x i8> %x, <32 x i8> %y
ret <32 x i8> %min
}
+
+define <8 x i32> @signd(<8 x i32> %a, <8 x i32> %b) nounwind {
+entry:
+; CHECK: signd:
+; CHECK: psignd
+; CHECK-NOT: sub
+; CHECK: ret
+ %b.lobit = ashr <8 x i32> %b, <i32 31, i32 31, i32 31, i32 31, i32 31, i32 31, i32 31, i32 31>
+ %sub = sub nsw <8 x i32> zeroinitializer, %a
+ %0 = xor <8 x i32> %b.lobit, <i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1>
+ %1 = and <8 x i32> %a, %0
+ %2 = and <8 x i32> %b.lobit, %sub
+ %cond = or <8 x i32> %1, %2
+ ret <8 x i32> %cond
+}
+
+define <8 x i32> @blendvb(<8 x i32> %b, <8 x i32> %a, <8 x i32> %c) nounwind {
+entry:
+; CHECK: blendvb:
+; CHECK: pblendvb
+; CHECK: ret
+ %b.lobit = ashr <8 x i32> %b, <i32 31, i32 31, i32 31, i32 31, i32 31, i32 31, i32 31, i32 31>
+ %sub = sub nsw <8 x i32> zeroinitializer, %a
+ %0 = xor <8 x i32> %b.lobit, <i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1>
+ %1 = and <8 x i32> %c, %0
+ %2 = and <8 x i32> %a, %b.lobit
+ %cond = or <8 x i32> %1, %2
+ ret <8 x i32> %cond
+}