From cf6bdfc0260b441fd889103b002c7413c428c79b Mon Sep 17 00:00:00 2001
From: Fraser Cormack
Date: Fri, 23 Jul 2021 11:13:08 +0100
Subject: [PATCH] [SelectionDAG] Support scalable splats in U(ADD|SUB)SAT
 combines

This patch builds on top of D106575 in which scalable-vector splats were
supported in `ISD::matchBinaryPredicate`.

It teaches the DAGCombiner how to perform a variety of the pre-existing
saturating add/sub combines on scalable-vector types.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D106652
---
 lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 36 +++++++++-------
 test/CodeGen/RISCV/rvv/combine-sats.ll   | 54 ++++++------------------
 2 files changed, 34 insertions(+), 56 deletions(-)

diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index db4685944b7..182c29eea7c 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10025,10 +10025,10 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
       // If it's on the left side invert the predicate to simplify logic below.
       SDValue Other;
       ISD::CondCode SatCC = CC;
-      if (ISD::isBuildVectorAllOnes(N1.getNode())) {
+      if (ISD::isConstantSplatVectorAllOnes(N1.getNode())) {
         Other = N2;
         SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
-      } else if (ISD::isBuildVectorAllOnes(N2.getNode())) {
+      } else if (ISD::isConstantSplatVectorAllOnes(N2.getNode())) {
         Other = N1;
       }

@@ -10049,7 +10049,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
             (OpLHS == CondLHS || OpRHS == CondLHS))
           return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);

-        if (isa<BuildVectorSDNode>(OpRHS) && isa<BuildVectorSDNode>(CondRHS) &&
+        if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
+            (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
+             OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
             CondLHS == OpLHS) {
           // If the RHS is a constant we have to reverse the const
           // canonicalization.
@@ -10070,10 +10072,10 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
       // the left side invert the predicate to simplify logic below.
       SDValue Other;
       ISD::CondCode SatCC = CC;
-      if (ISD::isBuildVectorAllZeros(N1.getNode())) {
+      if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
         Other = N2;
         SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
-      } else if (ISD::isBuildVectorAllZeros(N2.getNode())) {
+      } else if (ISD::isConstantSplatVectorAllZeros(N2.getNode())) {
         Other = N1;
       }

@@ -10102,8 +10104,10 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
             Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
           return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);

-        if (auto *OpRHSBV = dyn_cast<BuildVectorSDNode>(OpRHS)) {
-          if (isa<BuildVectorSDNode>(CondRHS)) {
+        if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
+            OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
+          if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
+              CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
             // If the RHS is a constant we have to reverse the const
             // canonicalization.
             // x > C-1 ? x+-C : 0 --> usubsat x, C
@@ -10125,15 +10129,15 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
             // FIXME: Would it be better to use computeKnownBits to determine
             //        whether it's safe to decanonicalize the xor?
             // x s< 0 ? x^C : 0 --> usubsat x, C
-            if (auto *OpRHSConst = OpRHSBV->getConstantSplatNode()) {
-              if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
-                  ISD::isBuildVectorAllZeros(CondRHS.getNode()) &&
-                  OpRHSConst->getAPIntValue().isSignMask()) {
-                // Note that we have to rebuild the RHS constant here to
-                // ensure we don't rely on particular values of undef lanes.
-                OpRHS = DAG.getConstant(OpRHSConst->getAPIntValue(), DL, VT);
-                return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
-              }
+            APInt SplatValue;
+            if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
+                ISD::isConstantSplatVector(OpRHS.getNode(), SplatValue) &&
+                ISD::isConstantSplatVectorAllZeros(CondRHS.getNode()) &&
+                SplatValue.isSignMask()) {
+              // Note that we have to rebuild the RHS constant here to
+              // ensure we don't rely on particular values of undef lanes.
+              OpRHS = DAG.getConstant(SplatValue, DL, VT);
+              return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
             }
           }
         }
diff --git a/test/CodeGen/RISCV/rvv/combine-sats.ll b/test/CodeGen/RISCV/rvv/combine-sats.ll
index 197e00bb947..ada166fa672 100644
--- a/test/CodeGen/RISCV/rvv/combine-sats.ll
+++ b/test/CodeGen/RISCV/rvv/combine-sats.ll
@@ -101,10 +101,7 @@ define <vscale x 2 x i64> @vselect_sub_nxv2i64(<vscale x 2 x i64> %a0, <vscale x
   %cmp = icmp uge <vscale x 2 x i64> %a0, %a1
   %v1 = sub <vscale x 2 x i64> %a0, %a1
@@ -131,9 +128,7 @@ define <vscale x 8 x i16> @vselect_sub_2_nxv8i16(<vscale x 8 x i16> %x, i16 zero
 ; CHECK-LABEL: vselect_sub_2_nxv8i16:
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vsetvli a1, zero, e16, m2, ta, mu
-; CHECK-NEXT:    vmsltu.vx v0, v8, a0
-; CHECK-NEXT:    vsub.vx v26, v8, a0
-; CHECK-NEXT:    vmerge.vim v8, v26, 0, v0
+; CHECK-NEXT:    vssubu.vx v8, v8, a0
 ; CHECK-NEXT:    ret
 entry:
   %0 = insertelement <vscale x 8 x i16> undef, i16 %w, i32 0
@@ -163,11 +158,9 @@ define <2 x i64> @vselect_add_const_v2i64(<2 x i64> %a0) {
 define <vscale x 2 x i64> @vselect_add_const_nxv2i64(<vscale x 2 x i64> %a0) {
 ; CHECK-LABEL: vselect_add_const_nxv2i64:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, mu
-; CHECK-NEXT:    vadd.vi v26, v8, -6
-; CHECK-NEXT:    vmsgtu.vi v0, v8, 5
-; CHECK-NEXT:    vmv.v.i v28, 0
-; CHECK-NEXT:    vmerge.vvm v8, v28, v26, v0
+; CHECK-NEXT:    addi a0, zero, 6
+; CHECK-NEXT:    vsetvli a1, zero, e64, m2, ta, mu
+; CHECK-NEXT:    vssubu.vx v8, v8, a0
 ; CHECK-NEXT:    ret
   %cm1 = insertelement <vscale x 2 x i64> poison, i64 -6, i32 0
   %splatcm1 = shufflevector <vscale x 2 x i64> %cm1, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
@@ -205,27 +198,17 @@ define <vscale x 2 x i16> @vselect_add_const_signbit_nxv2i16(<vscale x 2 x i16>
 ; RV32-LABEL: vselect_add_const_signbit_nxv2i16:
 ; RV32:       # %bb.0:
 ; RV32-NEXT:    lui a0, 8
-; RV32-NEXT:    addi a0, a0, -2
+; RV32-NEXT:    addi a0, a0, -1
 ; RV32-NEXT:    vsetvli a1, zero, e16, mf2, ta, mu
-; RV32-NEXT:    vmsgtu.vx v0, v8, a0
-; RV32-NEXT:    lui a0, 1048568
-; RV32-NEXT:    addi a0, a0, 1
-; RV32-NEXT:    vadd.vx v25, v8, a0
-; RV32-NEXT:    vmv.v.i v26, 0
-; RV32-NEXT:    vmerge.vvm v8, v26, v25, v0
+; RV32-NEXT:    vssubu.vx v8, v8, a0
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: vselect_add_const_signbit_nxv2i16:
 ; RV64:       # %bb.0:
 ; RV64-NEXT:    lui a0, 8
-; RV64-NEXT:    addiw a0, a0, -2
+; RV64-NEXT:    addiw a0, a0, -1
 ; RV64-NEXT:    vsetvli a1, zero, e16, mf2, ta, mu
-; RV64-NEXT:    vmsgtu.vx v0, v8, a0
-; RV64-NEXT:    lui a0, 1048568
-; RV64-NEXT:    addiw a0, a0, 1
-; RV64-NEXT:    vadd.vx v25, v8, a0
-; RV64-NEXT:    vmv.v.i v26, 0
-; RV64-NEXT:    vmerge.vvm v8, v26, v25, v0
+; RV64-NEXT:    vssubu.vx v8, v8, a0
 ; RV64-NEXT:    ret
   %cm1 = insertelement <vscale x 2 x i16> poison, i16 32766, i32 0
   %splatcm1 = shufflevector <vscale x 2 x i16> %cm1, <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer
@@ -255,12 +238,9 @@ define <2 x i16> @vselect_xor_const_signbit_v2i16(<2 x i16> %a0) {
 define <vscale x 2 x i16> @vselect_xor_const_signbit_nxv2i16(<vscale x 2 x i16> %a0) {
 ; CHECK-LABEL: vselect_xor_const_signbit_nxv2i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, mu
-; CHECK-NEXT:    vmsle.vi v0, v8, -1
-; CHECK-NEXT:    vmv.v.i v25, 0
-; CHECK-NEXT:    lui a0, 1048568
-; CHECK-NEXT:    vxor.vx v26, v8, a0
-; CHECK-NEXT:    vmerge.vvm v8, v25, v26, v0
+; CHECK-NEXT:    lui a0, 8
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, mu
+; CHECK-NEXT:    vssubu.vx v8, v8, a0
 ; CHECK-NEXT:    ret
   %cmp = icmp slt <vscale x 2 x i16> %a0, zeroinitializer
   %ins = insertelement <vscale x 2 x i16> poison, i16 -32768, i32 0
@@ -291,10 +271,7 @@ define <vscale x 2 x i64> @vselect_add_nxv2i64(<vscale x 2 x i64> %a0, <vscale x
   %v1 = add <vscale x 2 x i64> %a0, %a1
   %cmp = icmp ule <vscale x 2 x i64> %a0, %v1
@@ -323,10 +300,7 @@ define <vscale x 2 x i64> @vselect_add_const_2_nxv2i64(<vscale x 2 x i64> %a0) {
 ; CHECK-LABEL: vselect_add_const_2_nxv2i64:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, mu
-; CHECK-NEXT:    vadd.vi v26, v8, 6
-; CHECK-NEXT:    vmsleu.vi v0, v8, -7
-; CHECK-NEXT:    vmv.v.i v28, -1
-; CHECK-NEXT:    vmerge.vvm v8, v28, v26, v0
+; CHECK-NEXT:    vsaddu.vi v8, v8, 6
 ; CHECK-NEXT:    ret
   %cm1 = insertelement <vscale x 2 x i64> poison, i64 6, i32 0
   %splatcm1 = shufflevector <vscale x 2 x i64> %cm1, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
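
Note (illustration only, not part of the patch): the IR shape this change lets
the generic combiner fold for scalable vectors mirrors the fixed-length cases
already covered above. A hypothetical function such as @usubsat_example below
(the name and element type are chosen only for this sketch) matches the
"x >= y ? x-y : 0 --> usubsat x, y" pattern and, per the updated tests, would
be expected to select to a single vssubu.vv on RISC-V instead of a
compare/subtract/merge sequence:

define <vscale x 2 x i64> @usubsat_example(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y) {
  ; x u>= y ? x-y : 0  -->  usubsat x, y
  %cmp = icmp uge <vscale x 2 x i64> %x, %y
  %sub = sub <vscale x 2 x i64> %x, %y
  %sel = select <vscale x 2 x i1> %cmp, <vscale x 2 x i64> %sub, <vscale x 2 x i64> zeroinitializer
  ret <vscale x 2 x i64> %sel
}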