1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-23 03:02:36 +01:00

[NFC] [DAGCombine] Correct the result for sqrt even the iteration is zero

For now, we correct the result for sqrt if iteration > 0. This doesn't make
sense as they are not strict relative.

Reviewed By: dmgreen, spatel, RKSimon

Differential Revision: https://reviews.llvm.org/D94480
This commit is contained in:
QingShan Zhang 2021-01-25 04:00:32 +00:00
parent ac191922c9
commit d5b70bbb38
6 changed files with 56 additions and 47 deletions

View File

@ -4287,9 +4287,7 @@ public:
/// comparison may check if the operand is NAN, INF, zero, normal, etc. The
/// result should be used as the condition operand for a select or branch.
virtual SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
const DenormalMode &Mode) const {
return SDValue();
}
const DenormalMode &Mode) const;
/// Return a target-dependent result if the input operand is not suitable for
/// use with a square root estimate calculation.

View File

@ -22275,43 +22275,21 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
Reciprocal)) {
AddToWorklist(Est.getNode());
if (Iterations) {
if (Iterations)
Est = UseOneConstNR
? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
: buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
if (!Reciprocal) {
SDLoc DL(Op);
// Try the target specific test first.
SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));
if (!Reciprocal) {
SDLoc DL(Op);
EVT CCVT = getSetCCResultType(VT);
SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
DenormalMode DenormMode = DAG.getDenormalMode(VT);
// Try the target specific test first.
SDValue Test = TLI.getSqrtInputTest(Op, DAG, DenormMode);
if (!Test) {
// If no test provided by target, testing it with denormal inputs to
// avoid wrong estimate.
if (DenormMode.Input == DenormalMode::IEEE) {
// This is specifically a check for the handling of denormal inputs,
// not the result.
// Test = fabs(X) < SmallestNormal
const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
Test = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
} else
// Test = X == 0.0
Test = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
}
// The estimate is now completely wrong if the input was exactly 0.0 or
// possibly a denormal. Force the answer to 0.0 or value provided by
// target for those cases.
Est = DAG.getNode(
Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est);
}
// The estimate is now completely wrong if the input was exactly 0.0 or
// possibly a denormal. Force the answer to 0.0 or value provided by
// target for those cases.
Est = DAG.getNode(
Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est);
}
return Est;
}

View File

@ -5841,6 +5841,28 @@ verifyReturnAddressArgumentIsConstant(SDValue Op, SelectionDAG &DAG) const {
return false;
}
SDValue TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
const DenormalMode &Mode) const {
SDLoc DL(Op);
EVT VT = Op.getValueType();
EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
// Testing it with denormal inputs to avoid wrong estimate.
if (Mode.Input == DenormalMode::IEEE) {
// This is specifically a check for the handling of denormal inputs,
// not the result.
// Test = fabs(X) < SmallestNormal
const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
return DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
}
// Test = X == 0.0
return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
}
SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
bool LegalOps, bool OptForSize,
NegatibleCost &Cost,

View File

@ -7471,6 +7471,22 @@ static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode,
return SDValue();
}
SDValue
AArch64TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
const DenormalMode &Mode) const {
SDLoc DL(Op);
EVT VT = Op.getValueType();
EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
}
SDValue
AArch64TargetLowering::getSqrtResultForDenormInput(SDValue Op,
SelectionDAG &DAG) const {
return Op;
}
SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand,
SelectionDAG &DAG, int Enabled,
int &ExtraSteps,
@ -7494,17 +7510,8 @@ SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand,
Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, Operand, Step, Flags);
Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, Flags);
}
if (!Reciprocal) {
EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
VT);
SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
SDValue Eq = DAG.getSetCC(DL, CCVT, Operand, FPZero, ISD::SETEQ);
if (!Reciprocal)
Estimate = DAG.getNode(ISD::FMUL, DL, VT, Operand, Estimate, Flags);
// Correct the result if the operand is 0.0.
Estimate = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL,
VT, Eq, Operand, Estimate);
}
ExtraSteps = 0;
return Estimate;

View File

@ -961,6 +961,10 @@ private:
bool Reciprocal) const override;
SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
int &ExtraSteps) const override;
SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
const DenormalMode &Mode) const override;
SDValue getSqrtResultForDenormInput(SDValue Operand,
SelectionDAG &DAG) const override;
unsigned combineRepeatedFPDivisors() const override;
ConstraintType getConstraintType(StringRef Constraint) const override;

View File

@ -12133,7 +12133,7 @@ SDValue PPCTargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
if (!isTypeLegal(MVT::i1) ||
(VT != MVT::f64 &&
((VT != MVT::v2f64 && VT != MVT::v4f32) || !Subtarget.hasVSX())))
return SDValue();
return TargetLowering::getSqrtInputTest(Op, DAG, Mode);
SDLoc DL(Op);
// The output register of FTSQRT is CR field.