1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-26 12:43:36 +01:00

[DAGCombiner] reassociate reciprocal sqrt expression to eliminate FP division, part 2

Follow-up to D82716 / rGea71ba11ab11
We do not have the fabs removal fold in IR yet for the case
where the sqrt operand is repeated, so that's another potential
improvement.
This commit is contained in:
Sanjay Patel 2020-08-07 16:57:27 -04:00
parent d58ed6d476
commit 33d6e8d6a8
2 changed files with 67 additions and 54 deletions

View File

@ -13313,21 +13313,26 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
} }
if (Sqrt.getNode()) { if (Sqrt.getNode()) {
// If the other multiply operand is known positive, pull it into the // If the other multiply operand is known positive, pull it into the
// sqrt. That will eliminate the division if we convert to an estimate: // sqrt. That will eliminate the division if we convert to an estimate.
// X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
// TODO: Also fold the case where A == Z (fabs is missing).
if (Flags.hasAllowReassociation() && N1.hasOneUse() && if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse() && N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
Y.getOpcode() == ISD::FABS && Y.hasOneUse()) { SDValue A;
SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, Y.getOperand(0), if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
Y.getOperand(0), Flags); A = Y.getOperand(0);
SDValue AAZ = else if (Y == Sqrt.getOperand(0))
DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0), Flags); A = Y;
if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags)) if (A) {
return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt, Flags); // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
// X / (A * sqrt(A)) --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, A, A, Flags);
SDValue AAZ =
DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0), Flags);
if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt, Flags);
// Estimate creation failed. Clean up speculatively created nodes. // Estimate creation failed. Clean up speculatively created nodes.
recursivelyDeleteUnusedNodes(AAZ.getNode()); recursivelyDeleteUnusedNodes(AAZ.getNode());
}
} }
// We found a FSQRT, so try to make this fold: // We found a FSQRT, so try to make this fold:

View File

@ -803,38 +803,43 @@ define double @div_sqrt_fabs_f64(double %x, double %y, double %z) {
define float @div_sqrt_f32(float %x, float %y) { define float @div_sqrt_f32(float %x, float %y) {
; SSE-LABEL: div_sqrt_f32: ; SSE-LABEL: div_sqrt_f32:
; SSE: # %bb.0: ; SSE: # %bb.0:
; SSE-NEXT: rsqrtss %xmm1, %xmm2 ; SSE-NEXT: movaps %xmm1, %xmm2
; SSE-NEXT: movaps %xmm1, %xmm3 ; SSE-NEXT: mulss %xmm1, %xmm2
; SSE-NEXT: mulss %xmm2, %xmm3 ; SSE-NEXT: mulss %xmm1, %xmm2
; SSE-NEXT: mulss %xmm2, %xmm3 ; SSE-NEXT: xorps %xmm1, %xmm1
; SSE-NEXT: addss {{.*}}(%rip), %xmm3 ; SSE-NEXT: rsqrtss %xmm2, %xmm1
; SSE-NEXT: mulss {{.*}}(%rip), %xmm2 ; SSE-NEXT: mulss %xmm1, %xmm2
; SSE-NEXT: mulss %xmm3, %xmm2 ; SSE-NEXT: mulss %xmm1, %xmm2
; SSE-NEXT: divss %xmm1, %xmm2 ; SSE-NEXT: addss {{.*}}(%rip), %xmm2
; SSE-NEXT: mulss %xmm2, %xmm0 ; SSE-NEXT: mulss {{.*}}(%rip), %xmm1
; SSE-NEXT: mulss %xmm0, %xmm1
; SSE-NEXT: mulss %xmm2, %xmm1
; SSE-NEXT: movaps %xmm1, %xmm0
; SSE-NEXT: retq ; SSE-NEXT: retq
; ;
; AVX1-LABEL: div_sqrt_f32: ; AVX1-LABEL: div_sqrt_f32:
; AVX1: # %bb.0: ; AVX1: # %bb.0:
; AVX1-NEXT: vmulss %xmm1, %xmm1, %xmm2
; AVX1-NEXT: vmulss %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vrsqrtss %xmm1, %xmm1, %xmm2 ; AVX1-NEXT: vrsqrtss %xmm1, %xmm1, %xmm2
; AVX1-NEXT: vmulss %xmm2, %xmm1, %xmm3 ; AVX1-NEXT: vmulss %xmm2, %xmm1, %xmm1
; AVX1-NEXT: vmulss %xmm2, %xmm3, %xmm3 ; AVX1-NEXT: vmulss %xmm2, %xmm1, %xmm1
; AVX1-NEXT: vaddss {{.*}}(%rip), %xmm3, %xmm3 ; AVX1-NEXT: vaddss {{.*}}(%rip), %xmm1, %xmm1
; AVX1-NEXT: vmulss {{.*}}(%rip), %xmm2, %xmm2 ; AVX1-NEXT: vmulss {{.*}}(%rip), %xmm2, %xmm2
; AVX1-NEXT: vmulss %xmm3, %xmm2, %xmm2 ; AVX1-NEXT: vmulss %xmm0, %xmm2, %xmm0
; AVX1-NEXT: vdivss %xmm1, %xmm2, %xmm1 ; AVX1-NEXT: vmulss %xmm0, %xmm1, %xmm0
; AVX1-NEXT: vmulss %xmm1, %xmm0, %xmm0
; AVX1-NEXT: retq ; AVX1-NEXT: retq
; ;
; AVX512-LABEL: div_sqrt_f32: ; AVX512-LABEL: div_sqrt_f32:
; AVX512: # %bb.0: ; AVX512: # %bb.0:
; AVX512-NEXT: vmulss %xmm1, %xmm1, %xmm2
; AVX512-NEXT: vmulss %xmm1, %xmm2, %xmm1
; AVX512-NEXT: vrsqrtss %xmm1, %xmm1, %xmm2 ; AVX512-NEXT: vrsqrtss %xmm1, %xmm1, %xmm2
; AVX512-NEXT: vmulss %xmm2, %xmm1, %xmm3 ; AVX512-NEXT: vmulss %xmm2, %xmm1, %xmm1
; AVX512-NEXT: vfmadd213ss {{.*#+}} xmm3 = (xmm2 * xmm3) + mem ; AVX512-NEXT: vfmadd213ss {{.*#+}} xmm1 = (xmm2 * xmm1) + mem
; AVX512-NEXT: vmulss {{.*}}(%rip), %xmm2, %xmm2 ; AVX512-NEXT: vmulss {{.*}}(%rip), %xmm2, %xmm2
; AVX512-NEXT: vmulss %xmm3, %xmm2, %xmm2 ; AVX512-NEXT: vmulss %xmm0, %xmm2, %xmm0
; AVX512-NEXT: vdivss %xmm1, %xmm2, %xmm1 ; AVX512-NEXT: vmulss %xmm0, %xmm1, %xmm0
; AVX512-NEXT: vmulss %xmm1, %xmm0, %xmm0
; AVX512-NEXT: retq ; AVX512-NEXT: retq
%s = call fast float @llvm.sqrt.f32(float %y) %s = call fast float @llvm.sqrt.f32(float %y)
%m = fmul fast float %s, %y %m = fmul fast float %s, %y
@ -850,39 +855,42 @@ define float @div_sqrt_f32(float %x, float %y) {
define <4 x float> @div_sqrt_v4f32(<4 x float> %x, <4 x float> %y) { define <4 x float> @div_sqrt_v4f32(<4 x float> %x, <4 x float> %y) {
; SSE-LABEL: div_sqrt_v4f32: ; SSE-LABEL: div_sqrt_v4f32:
; SSE: # %bb.0: ; SSE: # %bb.0:
; SSE-NEXT: rsqrtps %xmm1, %xmm2 ; SSE-NEXT: movaps %xmm1, %xmm2
; SSE-NEXT: movaps %xmm1, %xmm3 ; SSE-NEXT: mulps %xmm1, %xmm2
; SSE-NEXT: mulps %xmm2, %xmm3 ; SSE-NEXT: mulps %xmm1, %xmm2
; SSE-NEXT: mulps %xmm2, %xmm3 ; SSE-NEXT: rsqrtps %xmm2, %xmm1
; SSE-NEXT: addps {{.*}}(%rip), %xmm3 ; SSE-NEXT: mulps %xmm1, %xmm2
; SSE-NEXT: mulps {{.*}}(%rip), %xmm2 ; SSE-NEXT: mulps %xmm1, %xmm2
; SSE-NEXT: mulps %xmm3, %xmm2 ; SSE-NEXT: addps {{.*}}(%rip), %xmm2
; SSE-NEXT: divps %xmm1, %xmm2 ; SSE-NEXT: mulps {{.*}}(%rip), %xmm1
; SSE-NEXT: mulps %xmm2, %xmm0 ; SSE-NEXT: mulps %xmm2, %xmm1
; SSE-NEXT: mulps %xmm1, %xmm0
; SSE-NEXT: retq ; SSE-NEXT: retq
; ;
; AVX1-LABEL: div_sqrt_v4f32: ; AVX1-LABEL: div_sqrt_v4f32:
; AVX1: # %bb.0: ; AVX1: # %bb.0:
; AVX1-NEXT: vmulps %xmm1, %xmm1, %xmm2
; AVX1-NEXT: vmulps %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vrsqrtps %xmm1, %xmm2 ; AVX1-NEXT: vrsqrtps %xmm1, %xmm2
; AVX1-NEXT: vmulps %xmm2, %xmm1, %xmm3 ; AVX1-NEXT: vmulps %xmm2, %xmm1, %xmm1
; AVX1-NEXT: vmulps %xmm2, %xmm3, %xmm3 ; AVX1-NEXT: vmulps %xmm2, %xmm1, %xmm1
; AVX1-NEXT: vaddps {{.*}}(%rip), %xmm3, %xmm3 ; AVX1-NEXT: vaddps {{.*}}(%rip), %xmm1, %xmm1
; AVX1-NEXT: vmulps {{.*}}(%rip), %xmm2, %xmm2 ; AVX1-NEXT: vmulps {{.*}}(%rip), %xmm2, %xmm2
; AVX1-NEXT: vmulps %xmm3, %xmm2, %xmm2 ; AVX1-NEXT: vmulps %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vdivps %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vmulps %xmm1, %xmm0, %xmm0 ; AVX1-NEXT: vmulps %xmm1, %xmm0, %xmm0
; AVX1-NEXT: retq ; AVX1-NEXT: retq
; ;
; AVX512-LABEL: div_sqrt_v4f32: ; AVX512-LABEL: div_sqrt_v4f32:
; AVX512: # %bb.0: ; AVX512: # %bb.0:
; AVX512-NEXT: vmulps %xmm1, %xmm1, %xmm2
; AVX512-NEXT: vmulps %xmm1, %xmm2, %xmm1
; AVX512-NEXT: vrsqrtps %xmm1, %xmm2 ; AVX512-NEXT: vrsqrtps %xmm1, %xmm2
; AVX512-NEXT: vmulps %xmm2, %xmm1, %xmm3 ; AVX512-NEXT: vmulps %xmm2, %xmm1, %xmm1
; AVX512-NEXT: vbroadcastss {{.*#+}} xmm4 = [-3.0E+0,-3.0E+0,-3.0E+0,-3.0E+0] ; AVX512-NEXT: vbroadcastss {{.*#+}} xmm3 = [-3.0E+0,-3.0E+0,-3.0E+0,-3.0E+0]
; AVX512-NEXT: vfmadd231ps {{.*#+}} xmm4 = (xmm2 * xmm3) + xmm4 ; AVX512-NEXT: vfmadd231ps {{.*#+}} xmm3 = (xmm2 * xmm1) + xmm3
; AVX512-NEXT: vbroadcastss {{.*#+}} xmm3 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1] ; AVX512-NEXT: vbroadcastss {{.*#+}} xmm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
; AVX512-NEXT: vmulps %xmm3, %xmm2, %xmm2 ; AVX512-NEXT: vmulps %xmm1, %xmm2, %xmm1
; AVX512-NEXT: vmulps %xmm4, %xmm2, %xmm2 ; AVX512-NEXT: vmulps %xmm3, %xmm1, %xmm1
; AVX512-NEXT: vdivps %xmm1, %xmm2, %xmm1
; AVX512-NEXT: vmulps %xmm1, %xmm0, %xmm0 ; AVX512-NEXT: vmulps %xmm1, %xmm0, %xmm0
; AVX512-NEXT: retq ; AVX512-NEXT: retq
%s = call <4 x float> @llvm.sqrt.v4f32(<4 x float> %y) %s = call <4 x float> @llvm.sqrt.v4f32(<4 x float> %y)