
[CodeGenPrepare][X86] Teach optimizeGatherScatterInst to turn a splat pointer into GEP with scalar base and 0 index

This helps SelectionDAGBuilder recognize the splat can be used as a uniform base.
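
For example, a masked gather or scatter whose pointer operand is a splat of a scalar pointer is now rewritten into a GEP with that scalar base and an all-zeroes vector index. A minimal before/after sketch in LLVM IR, mirroring the splat_ptr_gather test updated below (value names are illustrative):

  ; Before: the splat hides the uniform base from SelectionDAGBuilder.
  %b = insertelement <4 x i32*> undef, i32* %ptr, i32 0
  %splat = shufflevector <4 x i32*> %b, <4 x i32*> undef, <4 x i32> zeroinitializer
  %v = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> %splat, i32 4, <4 x i1> %mask, <4 x i32> %passthru)

  ; After: a GEP with scalar base %ptr and a zero vector index, which
  ; SelectionDAGBuilder recognizes as a uniform base.
  %gep = getelementptr i32, i32* %ptr, <4 x i64> zeroinitializer
  %v = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> %gep, i32 4, <4 x i1> %mask, <4 x i32> %passthru)

As the X86 test updates below show, the scalar base can then be folded into the gather/scatter addressing mode (e.g. (%rdi,%zmm0,4)) instead of first being broadcast into a vector register.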

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D86371
Author: Craig Topper
Date:   2020-09-02 20:44:12 -07:00
Commit: 639e60808d
Parent: e667967956

5 changed files with 128 additions and 106 deletions

llvm/include/llvm/Analysis/VectorUtils.h

@@ -358,7 +358,7 @@ int getSplatIndex(ArrayRef<int> Mask);
 /// Get splat value if the input is a splat vector or return nullptr.
 /// The value may be extracted from a splat constants vector or from
 /// a sequence of instructions that broadcast a single value into a vector.
-const Value *getSplatValue(const Value *V);
+Value *getSplatValue(const Value *V);
 
 /// Return true if each element of the vector value \p V is poisoned or equal to
 /// every other non-poisoned element. If an index element is specified, either

llvm/lib/Analysis/VectorUtils.cpp

@@ -342,7 +342,7 @@ int llvm::getSplatIndex(ArrayRef<int> Mask) {
 /// This function is not fully general. It checks only 2 cases:
 /// the input value is (1) a splat constant vector or (2) a sequence
 /// of instructions that broadcasts a scalar at element 0.
-const llvm::Value *llvm::getSplatValue(const Value *V) {
+Value *llvm::getSplatValue(const Value *V) {
   if (isa<VectorType>(V->getType()))
     if (auto *C = dyn_cast<Constant>(V))
       return C->getSplatValue();

llvm/lib/CodeGen/CodeGenPrepare.cpp

@@ -5314,8 +5314,15 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr,
 /// zero index.
 bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
                                                Value *Ptr) {
-  const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
-  if (!GEP || !GEP->hasIndices())
+  // FIXME: Support scalable vectors.
+  if (isa<ScalableVectorType>(Ptr->getType()))
     return false;
+
+  Value *NewAddr;
+
+  if (const auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
+    // Don't optimize GEPs that don't have indices.
+    if (!GEP->hasIndices())
+      return false;
 
     // If the GEP and the gather/scatter aren't in the same BB, don't optimize.
@@ -5328,7 +5335,7 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
 
     bool RewriteGEP = false;
     if (Ops[0]->getType()->isVectorTy()) {
-      Ops[0] = const_cast<Value *>(getSplatValue(Ops[0]));
+      Ops[0] = getSplatValue(Ops[0]);
       if (!Ops[0])
         return false;
       RewriteGEP = true;
@@ -5354,7 +5361,7 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
 
     // Try to scalarize the final index.
     if (Ops[FinalIndex]->getType()->isVectorTy()) {
-      if (Value *V = const_cast<Value *>(getSplatValue(Ops[FinalIndex]))) {
+      if (Value *V = getSplatValue(Ops[FinalIndex])) {
         auto *C = dyn_cast<ConstantInt>(V);
         // Don't scalarize all zeros vector.
         if (!C || !C->isZero()) {
@@ -5375,8 +5382,6 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
 
     Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
 
-  Value *NewAddr;
-
     // If the final index isn't a vector, emit a scalar GEP containing all ops
     // and a vector GEP with all zeroes final index.
     if (!Ops[FinalIndex]->getType()->isVectorTy()) {
@@ -5397,6 +5402,25 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
       // Now create the GEP with scalar pointer and vector index.
       NewAddr = Builder.CreateGEP(Base, Index);
     }
+  } else if (!isa<Constant>(Ptr)) {
+    // Not a GEP, maybe its a splat and we can create a GEP to enable
+    // SelectionDAGBuilder to use it as a uniform base.
+    Value *V = getSplatValue(Ptr);
+    if (!V)
+      return false;
+
+    unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
+
+    IRBuilder<> Builder(MemoryInst);
+
+    // Emit a vector GEP with a scalar pointer and all 0s vector index.
+    Type *ScalarIndexTy = DL->getIndexType(V->getType()->getScalarType());
+    auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
+    NewAddr = Builder.CreateGEP(V, Constant::getNullValue(IndexTy));
+  } else {
+    // Constant, SelectionDAGBuilder knows to check if its a splat.
+    return false;
+  }
 
   MemoryInst->replaceUsesOfWith(Ptr, NewAddr);

llvm/test/CodeGen/X86/masked_gather_scatter.ll

@@ -3323,14 +3323,13 @@ define void @scatter_16i64_constant_indices(i32* %ptr, <16 x i1> %mask, <16 x i3
 define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
 ; KNL_64-LABEL: splat_ptr_gather:
 ; KNL_64:       # %bb.0:
-; KNL_64-NEXT:    # kill: def $xmm1 killed $xmm1 def $ymm1
+; KNL_64-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
 ; KNL_64-NEXT:    vpslld $31, %xmm0, %xmm0
 ; KNL_64-NEXT:    vptestmd %zmm0, %zmm0, %k0
 ; KNL_64-NEXT:    kshiftlw $12, %k0, %k0
 ; KNL_64-NEXT:    kshiftrw $12, %k0, %k1
-; KNL_64-NEXT:    vmovq %rdi, %xmm0
-; KNL_64-NEXT:    vpbroadcastq %xmm0, %ymm0
-; KNL_64-NEXT:    vpgatherqd (,%zmm0), %ymm1 {%k1}
+; KNL_64-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; KNL_64-NEXT:    vpgatherdd (%rdi,%zmm0,4), %zmm1 {%k1}
 ; KNL_64-NEXT:    vmovdqa %xmm1, %xmm0
 ; KNL_64-NEXT:    vzeroupper
 ; KNL_64-NEXT:    retq
@@ -3342,8 +3341,9 @@ define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthr
 ; KNL_32-NEXT:    vptestmd %zmm0, %zmm0, %k0
 ; KNL_32-NEXT:    kshiftlw $12, %k0, %k0
 ; KNL_32-NEXT:    kshiftrw $12, %k0, %k1
-; KNL_32-NEXT:    vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; KNL_32-NEXT:    vpgatherdd (,%zmm0), %zmm1 {%k1}
+; KNL_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; KNL_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; KNL_32-NEXT:    vpgatherdd (%eax,%zmm0,4), %zmm1 {%k1}
 ; KNL_32-NEXT:    vmovdqa %xmm1, %xmm0
 ; KNL_32-NEXT:    vzeroupper
 ; KNL_32-NEXT:    retl
@@ -3352,18 +3352,18 @@ define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthr
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vpslld $31, %xmm0, %xmm0
 ; SKX-NEXT:    vpmovd2m %xmm0, %k1
-; SKX-NEXT:    vpbroadcastq %rdi, %ymm0
-; SKX-NEXT:    vpgatherqd (,%ymm0), %xmm1 {%k1}
+; SKX-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; SKX-NEXT:    vpgatherdd (%rdi,%xmm0,4), %xmm1 {%k1}
 ; SKX-NEXT:    vmovdqa %xmm1, %xmm0
-; SKX-NEXT:    vzeroupper
 ; SKX-NEXT:    retq
 ;
 ; SKX_32-LABEL: splat_ptr_gather:
 ; SKX_32:       # %bb.0:
 ; SKX_32-NEXT:    vpslld $31, %xmm0, %xmm0
 ; SKX_32-NEXT:    vpmovd2m %xmm0, %k1
-; SKX_32-NEXT:    vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; SKX_32-NEXT:    vpgatherdd (,%xmm0), %xmm1 {%k1}
+; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; SKX_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; SKX_32-NEXT:    vpgatherdd (%eax,%xmm0,4), %xmm1 {%k1}
 ; SKX_32-NEXT:    vmovdqa %xmm1, %xmm0
 ; SKX_32-NEXT:    retl
   %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
@@ -3376,14 +3376,13 @@ declare <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*>, i32, <4 x i1>,
 define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; KNL_64-LABEL: splat_ptr_scatter:
 ; KNL_64:       # %bb.0:
-; KNL_64-NEXT:    # kill: def $xmm1 killed $xmm1 def $ymm1
+; KNL_64-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
 ; KNL_64-NEXT:    vpslld $31, %xmm0, %xmm0
 ; KNL_64-NEXT:    vptestmd %zmm0, %zmm0, %k0
 ; KNL_64-NEXT:    kshiftlw $12, %k0, %k0
 ; KNL_64-NEXT:    kshiftrw $12, %k0, %k1
-; KNL_64-NEXT:    vmovq %rdi, %xmm0
-; KNL_64-NEXT:    vpbroadcastq %xmm0, %ymm0
-; KNL_64-NEXT:    vpscatterqd %ymm1, (,%zmm0) {%k1}
+; KNL_64-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; KNL_64-NEXT:    vpscatterdd %zmm1, (%rdi,%zmm0,4) {%k1}
 ; KNL_64-NEXT:    vzeroupper
 ; KNL_64-NEXT:    retq
 ;
@@ -3394,8 +3393,9 @@ define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; KNL_32-NEXT:    vptestmd %zmm0, %zmm0, %k0
 ; KNL_32-NEXT:    kshiftlw $12, %k0, %k0
 ; KNL_32-NEXT:    kshiftrw $12, %k0, %k1
-; KNL_32-NEXT:    vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; KNL_32-NEXT:    vpscatterdd %zmm1, (,%zmm0) {%k1}
+; KNL_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; KNL_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; KNL_32-NEXT:    vpscatterdd %zmm1, (%eax,%zmm0,4) {%k1}
 ; KNL_32-NEXT:    vzeroupper
 ; KNL_32-NEXT:    retl
 ;
@@ -3403,17 +3403,17 @@ define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vpslld $31, %xmm0, %xmm0
 ; SKX-NEXT:    vpmovd2m %xmm0, %k1
-; SKX-NEXT:    vpbroadcastq %rdi, %ymm0
-; SKX-NEXT:    vpscatterqd %xmm1, (,%ymm0) {%k1}
-; SKX-NEXT:    vzeroupper
+; SKX-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; SKX-NEXT:    vpscatterdd %xmm1, (%rdi,%xmm0,4) {%k1}
 ; SKX-NEXT:    retq
 ;
 ; SKX_32-LABEL: splat_ptr_scatter:
 ; SKX_32:       # %bb.0:
 ; SKX_32-NEXT:    vpslld $31, %xmm0, %xmm0
 ; SKX_32-NEXT:    vpmovd2m %xmm0, %k1
-; SKX_32-NEXT:    vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; SKX_32-NEXT:    vpscatterdd %xmm1, (,%xmm0) {%k1}
+; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; SKX_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; SKX_32-NEXT:    vpscatterdd %xmm1, (%eax,%xmm0,4) {%k1}
 ; SKX_32-NEXT:    retl
   %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
   %2 = shufflevector <4 x i32*> %1, <4 x i32*> undef, <4 x i32> zeroinitializer

llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll

@@ -87,10 +87,9 @@ define <4 x i32> @global_struct_splat() {
 define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
 ; CHECK-LABEL: @splat_ptr_gather(
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x i32*> undef, i32* [[PTR:%.*]], i32 0
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32*> [[TMP1]], <4 x i32*> undef, <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP3:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> [[TMP2]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
-; CHECK-NEXT:    ret <4 x i32> [[TMP3]]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, i32* [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
+; CHECK-NEXT:    ret <4 x i32> [[TMP2]]
 ;
   %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
   %2 = shufflevector <4 x i32*> %1, <4 x i32*> undef, <4 x i32> zeroinitializer
@@ -100,9 +99,8 @@ define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthr
 define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; CHECK-LABEL: @splat_ptr_scatter(
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x i32*> undef, i32* [[PTR:%.*]], i32 0
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32*> [[TMP1]], <4 x i32*> undef, <4 x i32> zeroinitializer
-; CHECK-NEXT:    call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> [[VAL:%.*]], <4 x i32*> [[TMP2]], i32 4, <4 x i1> [[MASK:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, i32* [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> [[VAL:%.*]], <4 x i32*> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0