1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-26 04:32:44 +01:00

[DAGCombiner] tighten fast-math constraints for fma fold

fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)

This is only allowed when "reassoc" is present on the fadd.

As discussed in D80801, this transform goes beyond
what is allowed by "contract" FMF (-ffp-contract=fast).
That is because we are fusing the trailing add of 'E' with the
multiply C*D, but without "reassoc", the order of operations as written
mandates that the products A*B and C*D are added together before adding in 'E'.

I've added this example to the LangRef to try to clarify the
meaning of "contract". If that seems reasonable, we should
probably do something similar for the clang docs because
there does not appear to be any formal spec for the behavior
of -ffp-contract=fast.

Differential Revision: https://reviews.llvm.org/D82499
This commit is contained in:
Sanjay Patel 2020-07-12 08:51:49 -04:00
parent 55e235a8c2
commit 9bc8c780ab
4 changed files with 28 additions and 11 deletions

View File

@ -2778,7 +2778,9 @@ floating-point transformations.
``contract``
Allow floating-point contraction (e.g. fusing a multiply followed by an
addition into a fused multiply-and-add).
addition into a fused multiply-and-add). This does not enable reassociating
to form arbitrary contractions. For example, ``(a*b) + (c*d) + e`` cannot
be transformed into ``(a*b) + ((c*d) + e)`` to create two fma operations.
``afn``
Approximate functions - Allow substitution of approximate calculations for

View File

@ -11986,6 +11986,8 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
SDNodeFlags Flags = N->getFlags();
bool CanFuse = Options.UnsafeFPMath || isContractable(N);
bool CanReassociate =
Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
CanFuse || HasFMAD);
// If the addition is not contractable, do not combine.
@ -12028,13 +12030,14 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
// fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
// fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
// This requires reassociation because it changes the order of operations.
SDValue FMA, E;
if (CanFuse && N0.getOpcode() == PreferredFusedOpcode &&
if (CanReassociate && N0.getOpcode() == PreferredFusedOpcode &&
N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() &&
N0.getOperand(2).hasOneUse()) {
FMA = N0;
E = N1;
} else if (CanFuse && N1.getOpcode() == PreferredFusedOpcode &&
} else if (CanReassociate && N1.getOpcode() == PreferredFusedOpcode &&
N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() &&
N1.getOperand(2).hasOneUse()) {
FMA = N1;

View File

@ -207,6 +207,10 @@ define double @fadd_fma_fmul_1(double %a, double %b, double %c, double %d, doubl
ret double %a2
}
; Minimum FMF - the 1st fadd is contracted because that combines
; fmul+fadd as specified by the order of operations; the 2nd fadd
; requires reassociation to fuse with c*d.
define float @fadd_fma_fmul_fmf(float %a, float %b, float %c, float %d, float %n0) nounwind {
; CHECK-LABEL: fadd_fma_fmul_fmf:
; CHECK: // %bb.0:
@ -220,13 +224,14 @@ define float @fadd_fma_fmul_fmf(float %a, float %b, float %c, float %d, float %n
ret float %a2
}
; Minimum FMF, commute final add operands, change type.
; Not minimum FMF.
define float @fadd_fma_fmul_2(float %a, float %b, float %c, float %d, float %n0) nounwind {
; CHECK-LABEL: fadd_fma_fmul_2:
; CHECK: // %bb.0:
; CHECK-NEXT: fmadd s2, s2, s3, s4
; CHECK-NEXT: fmul s2, s2, s3
; CHECK-NEXT: fmadd s0, s0, s1, s2
; CHECK-NEXT: fadd s0, s4, s0
; CHECK-NEXT: ret
%m1 = fmul float %a, %b
%m2 = fmul float %c, %d

View File

@ -1821,6 +1821,10 @@ define double @fadd_fma_fmul_1(double %a, double %b, double %c, double %d, doubl
ret double %a2
}
; Minimum FMF - the 1st fadd is contracted because that combines
; fmul+fadd as specified by the order of operations; the 2nd fadd
; requires reassociation to fuse with c*d.
define float @fadd_fma_fmul_fmf(float %a, float %b, float %c, float %d, float %n0) nounwind {
; FMA-LABEL: fadd_fma_fmul_fmf:
; FMA: # %bb.0:
@ -1846,25 +1850,28 @@ define float @fadd_fma_fmul_fmf(float %a, float %b, float %c, float %d, float %n
ret float %a2
}
; Minimum FMF, commute final add operands, change type.
; Not minimum FMF.
define float @fadd_fma_fmul_2(float %a, float %b, float %c, float %d, float %n0) nounwind {
; FMA-LABEL: fadd_fma_fmul_2:
; FMA: # %bb.0:
; FMA-NEXT: vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
; FMA-NEXT: vfmadd213ss {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
; FMA-NEXT: vmulss %xmm3, %xmm2, %xmm2
; FMA-NEXT: vfmadd231ss {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
; FMA-NEXT: vaddss %xmm2, %xmm4, %xmm0
; FMA-NEXT: retq
;
; FMA4-LABEL: fadd_fma_fmul_2:
; FMA4: # %bb.0:
; FMA4-NEXT: vfmaddss {{.*#+}} xmm2 = (xmm2 * xmm3) + xmm4
; FMA4-NEXT: vmulss %xmm3, %xmm2, %xmm2
; FMA4-NEXT: vfmaddss {{.*#+}} xmm0 = (xmm0 * xmm1) + xmm2
; FMA4-NEXT: vaddss %xmm0, %xmm4, %xmm0
; FMA4-NEXT: retq
;
; AVX512-LABEL: fadd_fma_fmul_2:
; AVX512: # %bb.0:
; AVX512-NEXT: vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
; AVX512-NEXT: vfmadd213ss {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
; AVX512-NEXT: vmulss %xmm3, %xmm2, %xmm2
; AVX512-NEXT: vfmadd231ss {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
; AVX512-NEXT: vaddss %xmm2, %xmm4, %xmm0
; AVX512-NEXT: retq
%m1 = fmul float %a, %b
%m2 = fmul float %c, %d