1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-25 20:23:11 +01:00

[ARM] Fix crash in chained BFI combine due to incorrectly RAUW'ing a node.

For a bfi chain like:
a = bfi input, x, y
b = bfi a, x', y'

The previous code was RAUW'ing a with x, mutating the second 'b' bfi, and when
SelectionDAG's CSE code ended up deleting it unexpectedly, bad things happend.
There's no need to RAUW in this case because we can just return our newly
created replacement BFI node. It also looked incorrect because it didn't account
for other users of the 'a' bfi.

Since it seems that chains of more than 2 BFI nodes are hard/impossible to
produce without this combine kicking in at some point, I've removed that
functionality since it had no test coverage.

rdar://79095399

Differential Revision: https://reviews.llvm.org/D104868
This commit is contained in:
Amara Emerson 2021-06-24 11:10:42 -07:00
parent 98c72058c5
commit e4692db44e
2 changed files with 81 additions and 58 deletions

View File

@ -13954,45 +13954,32 @@ static bool BitsProperlyConcatenate(const APInt &A, const APInt &B) {
} }
static SDValue FindBFIToCombineWith(SDNode *N) { static SDValue FindBFIToCombineWith(SDNode *N) {
// We have a BFI in N. Follow a possible chain of BFIs and find a BFI it can combine with, // We have a BFI in N. Find a BFI it can combine with, if one exists.
// if one exists.
APInt ToMask, FromMask; APInt ToMask, FromMask;
SDValue From = ParseBFI(N, ToMask, FromMask); SDValue From = ParseBFI(N, ToMask, FromMask);
SDValue To = N->getOperand(0); SDValue To = N->getOperand(0);
// Now check for a compatible BFI to merge with. We can pass through BFIs that
// aren't compatible, but not if they set the same bit in their destination as
// we do (or that of any BFI we're going to combine with).
SDValue V = To; SDValue V = To;
APInt CombinedToMask = ToMask; if (V.getOpcode() != ARMISD::BFI)
while (V.getOpcode() == ARMISD::BFI) { return SDValue();
APInt NewToMask, NewFromMask;
SDValue NewFrom = ParseBFI(V.getNode(), NewToMask, NewFromMask);
if (NewFrom != From) {
// This BFI has a different base. Keep going.
CombinedToMask |= NewToMask;
V = V.getOperand(0);
continue;
}
// Do the written bits conflict with any we've seen so far? APInt NewToMask, NewFromMask;
if ((NewToMask & CombinedToMask).getBoolValue()) SDValue NewFrom = ParseBFI(V.getNode(), NewToMask, NewFromMask);
// Conflicting bits - bail out because going further is unsafe. if (NewFrom != From)
return SDValue(); return SDValue();
// Are the new bits contiguous when combined with the old bits? // Do the written bits conflict with any we've seen so far?
if (BitsProperlyConcatenate(ToMask, NewToMask) && if ((NewToMask & ToMask).getBoolValue())
BitsProperlyConcatenate(FromMask, NewFromMask)) // Conflicting bits.
return V; return SDValue();
if (BitsProperlyConcatenate(NewToMask, ToMask) &&
BitsProperlyConcatenate(NewFromMask, FromMask))
return V;
// We've seen a write to some bits, so track it. // Are the new bits contiguous when combined with the old bits?
CombinedToMask |= NewToMask; if (BitsProperlyConcatenate(ToMask, NewToMask) &&
// Keep going... BitsProperlyConcatenate(FromMask, NewFromMask))
V = V.getOperand(0); return V;
} if (BitsProperlyConcatenate(NewToMask, ToMask) &&
BitsProperlyConcatenate(NewFromMask, FromMask))
return V;
return SDValue(); return SDValue();
} }
@ -14018,40 +14005,35 @@ static SDValue PerformBFICombine(SDNode *N,
return DCI.DAG.getNode(ARMISD::BFI, SDLoc(N), N->getValueType(0), return DCI.DAG.getNode(ARMISD::BFI, SDLoc(N), N->getValueType(0),
N->getOperand(0), N1.getOperand(0), N->getOperand(0), N1.getOperand(0),
N->getOperand(2)); N->getOperand(2));
} else if (N->getOperand(0).getOpcode() == ARMISD::BFI) { return SDValue();
// We have a BFI of a BFI. Walk up the BFI chain to see how long it goes. }
// Keep track of any consecutive bits set that all come from the same base // Look for another BFI to combine with.
// value. We can combine these together into a single BFI. SDValue CombineBFI = FindBFIToCombineWith(N);
SDValue CombineBFI = FindBFIToCombineWith(N); if (CombineBFI == SDValue())
if (CombineBFI == SDValue()) return SDValue();
return SDValue();
// We've found a BFI. // We've found a BFI.
APInt ToMask1, FromMask1; APInt ToMask1, FromMask1;
SDValue From1 = ParseBFI(N, ToMask1, FromMask1); SDValue From1 = ParseBFI(N, ToMask1, FromMask1);
APInt ToMask2, FromMask2; APInt ToMask2, FromMask2;
SDValue From2 = ParseBFI(CombineBFI.getNode(), ToMask2, FromMask2); SDValue From2 = ParseBFI(CombineBFI.getNode(), ToMask2, FromMask2);
assert(From1 == From2); assert(From1 == From2);
(void)From2; (void)From2;
// First, unlink CombineBFI. // Create a new BFI, combining the two together.
DCI.DAG.ReplaceAllUsesWith(CombineBFI, CombineBFI.getOperand(0)); APInt NewFromMask = FromMask1 | FromMask2;
// Then create a new BFI, combining the two together. APInt NewToMask = ToMask1 | ToMask2;
APInt NewFromMask = FromMask1 | FromMask2;
APInt NewToMask = ToMask1 | ToMask2;
EVT VT = N->getValueType(0); EVT VT = N->getValueType(0);
SDLoc dl(N); SDLoc dl(N);
if (NewFromMask[0] == 0) if (NewFromMask[0] == 0)
From1 = DCI.DAG.getNode( From1 = DCI.DAG.getNode(
ISD::SRL, dl, VT, From1, ISD::SRL, dl, VT, From1,
DCI.DAG.getConstant(NewFromMask.countTrailingZeros(), dl, VT)); DCI.DAG.getConstant(NewFromMask.countTrailingZeros(), dl, VT));
return DCI.DAG.getNode(ARMISD::BFI, dl, VT, N->getOperand(0), From1, return DCI.DAG.getNode(ARMISD::BFI, dl, VT, CombineBFI.getOperand(0), From1,
DCI.DAG.getConstant(~NewToMask, dl, VT)); DCI.DAG.getConstant(~NewToMask, dl, VT));
}
return SDValue();
} }
/// PerformVMOVRRDCombine - Target-specific dag combine xforms for /// PerformVMOVRRDCombine - Target-specific dag combine xforms for

View File

@ -0,0 +1,41 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -mtriple=thumbv7s | FileCheck %s
target datalayout = "e-m:o-p:32:32-f64:32:64-v64:32:64-v128:32:128-a:0:32-n32-S32"
target triple = "thumbv7s-apple-ios3.1.3"
define void @bfi_chain_cse_crash(i8* %0, i8 *%ptr) {
; CHECK-LABEL: bfi_chain_cse_crash:
; CHECK: @ %bb.0: @ %entry
; CHECK-NEXT: ldrb r2, [r0]
; CHECK-NEXT: and r3, r2, #1
; CHECK-NEXT: lsr.w r12, r2, #3
; CHECK-NEXT: bfi r3, r12, #3, #1
; CHECK-NEXT: strb r3, [r0]
; CHECK-NEXT: and r0, r2, #4
; CHECK-NEXT: bfi r0, r12, #3, #1
; CHECK-NEXT: strb r0, [r1]
; CHECK-NEXT: bx lr
entry:
%1 = load i8, i8* %0, align 1
%2 = and i8 %1, 1
%3 = select i1 false, i8 %2, i8 0
%4 = and i8 %1, 4
%5 = icmp eq i8 %4, 0
%6 = zext i8 %3 to i32
%7 = or i32 %6, 4
%8 = trunc i32 %7 to i8
%9 = select i1 %5, i8 %3, i8 %8
%10 = and i8 %1, 8
%11 = icmp eq i8 %10, 0
%12 = zext i8 %2 to i32
%13 = or i32 %12, 8
%14 = trunc i32 %13 to i8
%15 = zext i8 %9 to i32
%16 = or i32 %15, 8
%17 = trunc i32 %16 to i8
%18 = select i1 %11, i8 %2, i8 %14
%19 = select i1 %11, i8 %9, i8 %17
store i8 %18, i8* %0, align 1
store i8 %19, i8* %ptr, align 1
ret void
}