From 639e60808d04613fb5404e898cc5fc8455d44bdc Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Wed, 2 Sep 2020 20:44:12 -0700
Subject: [PATCH] [CodeGenPrepare][X86] Teach optimizeGatherScatterInst to
 turn a splat pointer into GEP with scalar base and 0 index

This helps SelectionDAGBuilder recognize the splat can be used as a
uniform base.

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D86371
---
 include/llvm/Analysis/VectorUtils.h          |   2 +-
 lib/Analysis/VectorUtils.cpp                 |   2 +-
 lib/CodeGen/CodeGenPrepare.cpp               | 174 ++++++++++--------
 test/CodeGen/X86/masked_gather_scatter.ll    |  44 ++---
 .../CodeGenPrepare/X86/gather-scatter-opt.ll |  12 +-
 5 files changed, 128 insertions(+), 106 deletions(-)

diff --git a/include/llvm/Analysis/VectorUtils.h b/include/llvm/Analysis/VectorUtils.h
index 074960e7ced..8498335bf78 100644
--- a/include/llvm/Analysis/VectorUtils.h
+++ b/include/llvm/Analysis/VectorUtils.h
@@ -358,7 +358,7 @@ int getSplatIndex(ArrayRef<int> Mask);
 /// Get splat value if the input is a splat vector or return nullptr.
 /// The value may be extracted from a splat constants vector or from
 /// a sequence of instructions that broadcast a single value into a vector.
-const Value *getSplatValue(const Value *V);
+Value *getSplatValue(const Value *V);
 
 /// Return true if each element of the vector value \p V is poisoned or equal to
 /// every other non-poisoned element. If an index element is specified, either
diff --git a/lib/Analysis/VectorUtils.cpp b/lib/Analysis/VectorUtils.cpp
index 0bc8b7281d9..e241300dd2e 100644
--- a/lib/Analysis/VectorUtils.cpp
+++ b/lib/Analysis/VectorUtils.cpp
@@ -342,7 +342,7 @@ int llvm::getSplatIndex(ArrayRef<int> Mask) {
 /// This function is not fully general. It checks only 2 cases:
 /// the input value is (1) a splat constant vector or (2) a sequence
 /// of instructions that broadcasts a scalar at element 0.
-const llvm::Value *llvm::getSplatValue(const Value *V) {
+Value *llvm::getSplatValue(const Value *V) {
   if (isa<VectorType>(V->getType()))
     if (auto *C = dyn_cast<Constant>(V))
       return C->getSplatValue();
diff --git a/lib/CodeGen/CodeGenPrepare.cpp b/lib/CodeGen/CodeGenPrepare.cpp
index 3272f36a143..9a4ed2fab60 100644
--- a/lib/CodeGen/CodeGenPrepare.cpp
+++ b/lib/CodeGen/CodeGenPrepare.cpp
@@ -5314,88 +5314,112 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr,
 /// zero index.
 bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
                                                Value *Ptr) {
-  const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
-  if (!GEP || !GEP->hasIndices())
+  // FIXME: Support scalable vectors.
+  if (isa<ScalableVectorType>(Ptr->getType()))
     return false;
 
-  // If the GEP and the gather/scatter aren't in the same BB, don't optimize.
-  // FIXME: We should support this by sinking the GEP.
-  if (MemoryInst->getParent() != GEP->getParent())
-    return false;
-
-  SmallVector<Value *, 2> Ops(GEP->op_begin(), GEP->op_end());
-
-  bool RewriteGEP = false;
-
-  if (Ops[0]->getType()->isVectorTy()) {
-    Ops[0] = const_cast<Value *>(getSplatValue(Ops[0]));
-    if (!Ops[0])
-      return false;
-    RewriteGEP = true;
-  }
-
-  unsigned FinalIndex = Ops.size() - 1;
-
-  // Ensure all but the last index is 0.
-  // FIXME: This isn't strictly required. All that's required is that they are
-  // all scalars or splats.
-  for (unsigned i = 1; i < FinalIndex; ++i) {
-    auto *C = dyn_cast<Constant>(Ops[i]);
-    if (!C)
-      return false;
-    if (isa<VectorType>(C->getType()))
-      C = C->getSplatValue();
-    auto *CI = dyn_cast_or_null<ConstantInt>(C);
-    if (!CI || !CI->isZero())
-      return false;
-    // Scalarize the index if needed.
-    Ops[i] = CI;
-  }
-
-  // Try to scalarize the final index.
-  if (Ops[FinalIndex]->getType()->isVectorTy()) {
-    if (Value *V = const_cast<Value *>(getSplatValue(Ops[FinalIndex]))) {
-      auto *C = dyn_cast<ConstantInt>(V);
-      // Don't scalarize all zeros vector.
-      if (!C || !C->isZero()) {
-        Ops[FinalIndex] = V;
-        RewriteGEP = true;
-      }
-    }
-  }
-
-  // If we made any changes or the we have extra operands, we need to generate
-  // new instructions.
-  if (!RewriteGEP && Ops.size() == 2)
-    return false;
-
-  unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
-
-  IRBuilder<> Builder(MemoryInst);
-
-  Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
   Value *NewAddr;
-  // If the final index isn't a vector, emit a scalar GEP containing all ops
-  // and a vector GEP with all zeroes final index.
-  if (!Ops[FinalIndex]->getType()->isVectorTy()) {
-    NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front());
-    auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
-    NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy));
-  } else {
-    Value *Base = Ops[0];
-    Value *Index = Ops[FinalIndex];
+
+  if (const auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
+    // Don't optimize GEPs that don't have indices.
+    if (!GEP->hasIndices())
+      return false;
 
-    // Create a scalar GEP if there are more than 2 operands.
-    if (Ops.size() != 2) {
-      // Replace the last index with 0.
-      Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy);
-      Base = Builder.CreateGEP(Base, makeArrayRef(Ops).drop_front());
+    // If the GEP and the gather/scatter aren't in the same BB, don't optimize.
+    // FIXME: We should support this by sinking the GEP.
+    if (MemoryInst->getParent() != GEP->getParent())
+      return false;
+
+    SmallVector<Value *, 2> Ops(GEP->op_begin(), GEP->op_end());
+
+    bool RewriteGEP = false;
+
+    if (Ops[0]->getType()->isVectorTy()) {
+      Ops[0] = getSplatValue(Ops[0]);
+      if (!Ops[0])
+        return false;
+      RewriteGEP = true;
     }
 
-    // Now create the GEP with scalar pointer and vector index.
-    NewAddr = Builder.CreateGEP(Base, Index);
+    unsigned FinalIndex = Ops.size() - 1;
+
+    // Ensure all but the last index is 0.
+    // FIXME: This isn't strictly required. All that's required is that they are
+    // all scalars or splats.
+    for (unsigned i = 1; i < FinalIndex; ++i) {
+      auto *C = dyn_cast<Constant>(Ops[i]);
+      if (!C)
+        return false;
+      if (isa<VectorType>(C->getType()))
+        C = C->getSplatValue();
+      auto *CI = dyn_cast_or_null<ConstantInt>(C);
+      if (!CI || !CI->isZero())
+        return false;
+      // Scalarize the index if needed.
+      Ops[i] = CI;
+    }
+
+    // Try to scalarize the final index.
+    if (Ops[FinalIndex]->getType()->isVectorTy()) {
+      if (Value *V = getSplatValue(Ops[FinalIndex])) {
+        auto *C = dyn_cast<ConstantInt>(V);
+        // Don't scalarize all zeros vector.
+        if (!C || !C->isZero()) {
+          Ops[FinalIndex] = V;
+          RewriteGEP = true;
+        }
+      }
+    }
+
+    // If we made any changes or the we have extra operands, we need to generate
+    // new instructions.
+    if (!RewriteGEP && Ops.size() == 2)
+      return false;
+
+    unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
+
+    IRBuilder<> Builder(MemoryInst);
+
+    Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
+
+    // If the final index isn't a vector, emit a scalar GEP containing all ops
+    // and a vector GEP with all zeroes final index.
+    if (!Ops[FinalIndex]->getType()->isVectorTy()) {
+      NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front());
+      auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
+      NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy));
+    } else {
+      Value *Base = Ops[0];
+      Value *Index = Ops[FinalIndex];
+
+      // Create a scalar GEP if there are more than 2 operands.
+      if (Ops.size() != 2) {
+        // Replace the last index with 0.
+        Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy);
+        Base = Builder.CreateGEP(Base, makeArrayRef(Ops).drop_front());
+      }
+
+      // Now create the GEP with scalar pointer and vector index.
+      NewAddr = Builder.CreateGEP(Base, Index);
+    }
+  } else if (!isa<Constant>(Ptr)) {
+    // Not a GEP, maybe its a splat and we can create a GEP to enable
+    // SelectionDAGBuilder to use it as a uniform base.
+    Value *V = getSplatValue(Ptr);
+    if (!V)
+      return false;
+
+    unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
+
+    IRBuilder<> Builder(MemoryInst);
+
+    // Emit a vector GEP with a scalar pointer and all 0s vector index.
+    Type *ScalarIndexTy = DL->getIndexType(V->getType()->getScalarType());
+    auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
+    NewAddr = Builder.CreateGEP(V, Constant::getNullValue(IndexTy));
+  } else {
+    // Constant, SelectionDAGBuilder knows to check if its a splat.
+    return false;
   }
 
   MemoryInst->replaceUsesOfWith(Ptr, NewAddr);
diff --git a/test/CodeGen/X86/masked_gather_scatter.ll b/test/CodeGen/X86/masked_gather_scatter.ll
index c5781e83407..88418fd85fe 100644
--- a/test/CodeGen/X86/masked_gather_scatter.ll
+++ b/test/CodeGen/X86/masked_gather_scatter.ll
@@ -3323,14 +3323,13 @@ define void @scatter_16i64_constant_indices(i32* %ptr, <16 x i1> %mask, <16 x i3
 define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
 ; KNL_64-LABEL: splat_ptr_gather:
 ; KNL_64:       # %bb.0:
-; KNL_64-NEXT:    # kill: def $xmm1 killed $xmm1 def $ymm1
+; KNL_64-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
 ; KNL_64-NEXT:    vpslld $31, %xmm0, %xmm0
 ; KNL_64-NEXT:    vptestmd %zmm0, %zmm0, %k0
 ; KNL_64-NEXT:    kshiftlw $12, %k0, %k0
 ; KNL_64-NEXT:    kshiftrw $12, %k0, %k1
-; KNL_64-NEXT:    vmovq %rdi, %xmm0
-; KNL_64-NEXT:    vpbroadcastq %xmm0, %ymm0
-; KNL_64-NEXT:    vpgatherqd (,%zmm0), %ymm1 {%k1}
+; KNL_64-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; KNL_64-NEXT:    vpgatherdd (%rdi,%zmm0,4), %zmm1 {%k1}
 ; KNL_64-NEXT:    vmovdqa %xmm1, %xmm0
 ; KNL_64-NEXT:    vzeroupper
 ; KNL_64-NEXT:    retq
 ;
@@ -3342,8 +3341,9 @@ define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthr
 ; KNL_32-NEXT:    vptestmd %zmm0, %zmm0, %k0
 ; KNL_32-NEXT:    kshiftlw $12, %k0, %k0
 ; KNL_32-NEXT:    kshiftrw $12, %k0, %k1
-; KNL_32-NEXT:    vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; KNL_32-NEXT:    vpgatherdd (,%zmm0), %zmm1 {%k1}
+; KNL_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; KNL_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; KNL_32-NEXT:    vpgatherdd (%eax,%zmm0,4), %zmm1 {%k1}
 ; KNL_32-NEXT:    vmovdqa %xmm1, %xmm0
 ; KNL_32-NEXT:    vzeroupper
 ; KNL_32-NEXT:    retl
 ;
@@ -3352,18 +3352,18 @@ define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthr
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vpslld $31, %xmm0, %xmm0
 ; SKX-NEXT:    vpmovd2m %xmm0, %k1
-; SKX-NEXT:    vpbroadcastq %rdi, %ymm0
-; SKX-NEXT:    vpgatherqd (,%ymm0), %xmm1 {%k1}
+; SKX-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; SKX-NEXT:    vpgatherdd (%rdi,%xmm0,4), %xmm1 {%k1}
 ; SKX-NEXT:    vmovdqa %xmm1, %xmm0
-; SKX-NEXT:    vzeroupper
 ; SKX-NEXT:    retq
 ;
 ; SKX_32-LABEL: splat_ptr_gather:
 ; SKX_32:       # %bb.0:
 ; SKX_32-NEXT:    vpslld $31, %xmm0, %xmm0
 ; SKX_32-NEXT:    vpmovd2m %xmm0, %k1
-; SKX_32-NEXT:    vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; SKX_32-NEXT:    vpgatherdd (,%xmm0), %xmm1 {%k1}
+; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; SKX_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; SKX_32-NEXT:    vpgatherdd (%eax,%xmm0,4), %xmm1 {%k1}
 ; SKX_32-NEXT:    vmovdqa %xmm1, %xmm0
 ; SKX_32-NEXT:    retl
   %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
@@ -3376,14 +3376,13 @@ declare <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*>, i32, <4 x i1>,
 define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; KNL_64-LABEL: splat_ptr_scatter:
 ; KNL_64:       # %bb.0:
-; KNL_64-NEXT:    # kill: def $xmm1 killed $xmm1 def $ymm1
+; KNL_64-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
 ; KNL_64-NEXT:    vpslld $31, %xmm0, %xmm0
 ; KNL_64-NEXT:    vptestmd %zmm0, %zmm0, %k0
 ; KNL_64-NEXT:    kshiftlw $12, %k0, %k0
 ; KNL_64-NEXT:    kshiftrw $12, %k0, %k1
-; KNL_64-NEXT:    vmovq %rdi, %xmm0
-; KNL_64-NEXT:    vpbroadcastq %xmm0, %ymm0
-; KNL_64-NEXT:    vpscatterqd %ymm1, (,%zmm0) {%k1}
+; KNL_64-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; KNL_64-NEXT:    vpscatterdd %zmm1, (%rdi,%zmm0,4) {%k1}
 ; KNL_64-NEXT:    vzeroupper
 ; KNL_64-NEXT:    retq
 ;
@@ -3394,8 +3393,9 @@ define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; KNL_32-NEXT:    vptestmd %zmm0, %zmm0, %k0
 ; KNL_32-NEXT:    kshiftlw $12, %k0, %k0
 ; KNL_32-NEXT:    kshiftrw $12, %k0, %k1
-; KNL_32-NEXT:    vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; KNL_32-NEXT:    vpscatterdd %zmm1, (,%zmm0) {%k1}
+; KNL_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; KNL_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; KNL_32-NEXT:    vpscatterdd %zmm1, (%eax,%zmm0,4) {%k1}
 ; KNL_32-NEXT:    vzeroupper
 ; KNL_32-NEXT:    retl
 ;
@@ -3403,17 +3403,17 @@ define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vpslld $31, %xmm0, %xmm0
 ; SKX-NEXT:    vpmovd2m %xmm0, %k1
-; SKX-NEXT:    vpbroadcastq %rdi, %ymm0
-; SKX-NEXT:    vpscatterqd %xmm1, (,%ymm0) {%k1}
-; SKX-NEXT:    vzeroupper
+; SKX-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; SKX-NEXT:    vpscatterdd %xmm1, (%rdi,%xmm0,4) {%k1}
 ; SKX-NEXT:    retq
 ;
 ; SKX_32-LABEL: splat_ptr_scatter:
 ; SKX_32:       # %bb.0:
 ; SKX_32-NEXT:    vpslld $31, %xmm0, %xmm0
 ; SKX_32-NEXT:    vpmovd2m %xmm0, %k1
-; SKX_32-NEXT:    vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; SKX_32-NEXT:    vpscatterdd %xmm1, (,%xmm0) {%k1}
+; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; SKX_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; SKX_32-NEXT:    vpscatterdd %xmm1, (%eax,%xmm0,4) {%k1}
 ; SKX_32-NEXT:    retl
   %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
   %2 = shufflevector <4 x i32*> %1, <4 x i32*> undef, <4 x i32> zeroinitializer
diff --git a/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll b/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
index c1674ad4ca4..adb1930ca78 100644
--- a/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
+++ b/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
@@ -87,10 +87,9 @@ define <4 x i32> @global_struct_splat() {
 
 define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
 ; CHECK-LABEL: @splat_ptr_gather(
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x i32*> undef, i32* [[PTR:%.*]], i32 0
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32*> [[TMP1]], <4 x i32*> undef, <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP3:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> [[TMP2]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
-; CHECK-NEXT:    ret <4 x i32> [[TMP3]]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, i32* [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
+; CHECK-NEXT:    ret <4 x i32> [[TMP2]]
 ;
   %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
   %2 = shufflevector <4 x i32*> %1, <4 x i32*> undef, <4 x i32> zeroinitializer
@@ -100,9 +99,8 @@ define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthr
 
 define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; CHECK-LABEL: @splat_ptr_scatter(
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x i32*> undef, i32* [[PTR:%.*]], i32 0
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32*> [[TMP1]], <4 x i32*> undef, <4 x i32> zeroinitializer
-; CHECK-NEXT:    call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> [[VAL:%.*]], <4 x i32*> [[TMP2]], i32 4, <4 x i1> [[MASK:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, i32* [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> [[VAL:%.*]], <4 x i32*> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0