mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-11-22 02:33:06 +01:00
[Matrix] RAUW should only replace an instruction in ShapeMap if supportsShapeInfo
As an instruction is replaced in optimizeTransposes RAUW will replace it in the ShapeMap (ShapeMap is ValueMap so that uses are updated). In finalizeLowering however we skip updating uses if they are in the ShapeMap since they will be lowered separately at which point we pick up the lowered operands. In the testcase what happened was that since we replaced the doubled-transpose with the shuffle, it ended up in the ShapeMap. As we lowered the columnwise-load the use in the shuffle was not updated. Then as we removed the original columnwise-load we changed that to an undef. I.e. we ended up with: ``` %shuf = shufflevector <8 x double> undef, <8 x double> poison, <6 x i32> ^^^^^ <i32 0, i32 1, i32 2, i32 4, i32 5, i32 6> ``` Besides the fix itself, I have fortified this last bit. As we change uses to undef when removing instruction we track the undefed instruction to make sure we eventually remove those too. This would have caught the issue at compile time. Differential Revision: https://reviews.llvm.org/D106714
This commit is contained in:
parent
6516543c4b
commit
add64be20a
@ -685,6 +685,19 @@ public:
|
||||
|
||||
/// Try moving transposes in order to fold them away or into multiplies.
|
||||
void optimizeTransposes() {
|
||||
auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
|
||||
// We need to remove Old from the ShapeMap otherwise RAUW will replace it
|
||||
// with New. We should only add New it it supportsShapeInfo so we insert
|
||||
// it conditionally instead.
|
||||
auto S = ShapeMap.find(&Old);
|
||||
if (S != ShapeMap.end()) {
|
||||
ShapeMap.erase(S);
|
||||
if (supportsShapeInfo(New))
|
||||
ShapeMap.insert({New, S->second});
|
||||
}
|
||||
Old.replaceAllUsesWith(New);
|
||||
};
|
||||
|
||||
// First sink all transposes inside matmuls, hoping that we end up with NN,
|
||||
// NT or TN variants.
|
||||
for (BasicBlock &BB : reverse(Func)) {
|
||||
@ -717,7 +730,7 @@ public:
|
||||
Value *TATA;
|
||||
if (match(TA,
|
||||
m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
|
||||
I.replaceAllUsesWith(TATA);
|
||||
ReplaceAllUsesWith(I, TATA);
|
||||
EraseFromParent(&I);
|
||||
EraseFromParent(TA);
|
||||
}
|
||||
@ -740,8 +753,7 @@ public:
|
||||
NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
|
||||
K->getZExtValue(),
|
||||
R->getZExtValue(), "mmul");
|
||||
setShapeInfo(NewInst, {C, R});
|
||||
I.replaceAllUsesWith(NewInst);
|
||||
ReplaceAllUsesWith(I, NewInst);
|
||||
EraseFromParent(&I);
|
||||
EraseFromParent(TA);
|
||||
}
|
||||
@ -774,8 +786,7 @@ public:
|
||||
setShapeInfo(M, {C, R});
|
||||
Value *NewInst = Builder.CreateMatrixTranspose(M, R->getZExtValue(),
|
||||
C->getZExtValue());
|
||||
setShapeInfo(NewInst, {C, R});
|
||||
I->replaceAllUsesWith(NewInst);
|
||||
ReplaceAllUsesWith(*I, NewInst);
|
||||
if (I->use_empty())
|
||||
I->eraseFromParent();
|
||||
if (A->use_empty())
|
||||
@ -879,10 +890,30 @@ public:
|
||||
|
||||
// Delete the instructions backwards, as it has a reduced likelihood of
|
||||
// having to update as many def-use and use-def chains.
|
||||
//
|
||||
// Because we add to ToRemove during fusion we can't guarantee that defs
|
||||
// are before uses. Change uses to undef temporarily as these should get
|
||||
// removed as well.
|
||||
//
|
||||
// For verification, we keep track of where we changed uses to undefs in
|
||||
// UndefedInsts and then check that we in fact remove them.
|
||||
SmallSet<Instruction *, 16> UndefedInsts;
|
||||
for (auto *Inst : reverse(ToRemove)) {
|
||||
if (!Inst->use_empty())
|
||||
Inst->replaceAllUsesWith(UndefValue::get(Inst->getType()));
|
||||
for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
|
||||
Use &U = *I++;
|
||||
if (auto *Undefed = dyn_cast<Instruction>(U.getUser()))
|
||||
UndefedInsts.insert(Undefed);
|
||||
U.set(UndefValue::get(Inst->getType()));
|
||||
}
|
||||
Inst->eraseFromParent();
|
||||
UndefedInsts.erase(Inst);
|
||||
}
|
||||
if (!UndefedInsts.empty()) {
|
||||
// If we didn't remove all undefed instructions, it's a hard error.
|
||||
dbgs() << "Undefed but present instructions:\n";
|
||||
for (auto *I : UndefedInsts)
|
||||
dbgs() << *I << "\n";
|
||||
llvm_unreachable("Undefed but instruction not removed");
|
||||
}
|
||||
|
||||
return Changed;
|
||||
|
@ -986,6 +986,32 @@ entry:
|
||||
ret <4 x float> %m
|
||||
}
|
||||
|
||||
define <6 x double> @transpose_of_transpose_of_non_matrix_op(double* %a) {
|
||||
; CHECK-LABEL: @transpose_of_transpose_of_non_matrix_op(
|
||||
; CHECK-NEXT: [[VEC_CAST:%.*]] = bitcast double* [[A:%.*]] to <2 x double>*
|
||||
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST]], align 8
|
||||
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, double* [[A]], i64 4
|
||||
; CHECK-NEXT: [[VEC_CAST1:%.*]] = bitcast double* [[VEC_GEP]] to <2 x double>*
|
||||
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
|
||||
; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr double, double* [[A]], i64 8
|
||||
; CHECK-NEXT: [[VEC_CAST4:%.*]] = bitcast double* [[VEC_GEP3]] to <2 x double>*
|
||||
; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST4]], align 8
|
||||
; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr double, double* [[A]], i64 12
|
||||
; CHECK-NEXT: [[VEC_CAST7:%.*]] = bitcast double* [[VEC_GEP6]] to <2 x double>*
|
||||
; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST7]], align 8
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x double> [[COL_LOAD]], <2 x double> [[COL_LOAD2]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
|
||||
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x double> [[COL_LOAD5]], <2 x double> [[COL_LOAD8]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
|
||||
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <4 x double> [[TMP1]], <4 x double> [[TMP2]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
|
||||
; CHECK-NEXT: [[SHUF:%.*]] = shufflevector <8 x double> [[TMP3]], <8 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 4, i32 5, i32 6>
|
||||
; CHECK-NEXT: ret <6 x double> [[SHUF]]
|
||||
;
|
||||
%load = call <8 x double> @llvm.matrix.column.major.load.v8f64(double* %a, i64 4, i1 false, i32 2, i32 4)
|
||||
%shuf = shufflevector <8 x double> %load, <8 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 4, i32 5, i32 6>
|
||||
%t = call <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double> %shuf, i32 3, i32 2)
|
||||
%tt = call <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double> %t, i32 2, i32 3)
|
||||
ret <6 x double> %tt
|
||||
}
|
||||
|
||||
declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
|
||||
declare <12 x double> @llvm.matrix.multiply.v12f64.v6f64.v8f64(<6 x double>, <8 x double>, i32, i32, i32)
|
||||
declare <8 x double> @llvm.matrix.multiply.v8f64.v6f64.v12f64(<6 x double> %a, <12 x double>, i32, i32, i32)
|
||||
@ -995,3 +1021,4 @@ declare <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double>, i32, i32)
|
||||
declare <8 x double> @llvm.matrix.transpose.v8f64.v8f64(<8 x double>, i32, i32)
|
||||
declare <12 x double> @llvm.matrix.transpose.v12f64.v12f64(<12 x double>, i32, i32)
|
||||
declare <4 x float> @llvm.matrix.transpose.v4f32(<4 x float>, i32, i32)
|
||||
declare <8 x double> @llvm.matrix.column.major.load.v8f64(double*, i64, i1, i32, i32)
|
||||
|
Loading…
Reference in New Issue
Block a user