From add64be20a055e7e29265f5adc309b0532930ba4 Mon Sep 17 00:00:00 2001 From: Adam Nemet Date: Fri, 23 Jul 2021 12:27:25 -0700 Subject: [PATCH] [Matrix] RAUW should only replace an instruction in ShapeMap if supportsShapeInfo As an instruction is replaced in optimizeTransposes RAUW will replace it in the ShapeMap (ShapeMap is ValueMap so that uses are updated). In finalizeLowering however we skip updating uses if they are in the ShapeMap since they will be lowered separately at which point we pick up the lowered operands. In the testcase what happened was that since we replaced the doubled-transpose with the shuffle, it ended up in the ShapeMap. As we lowered the columnwise-load the use in the shuffle was not updated. Then as we removed the original columnwise-load we changed that to an undef. I.e. we ended up with: ``` %shuf = shufflevector <8 x double> undef, <8 x double> poison, <6 x i32> ^^^^^ ``` Besides the fix itself, I have fortified this last bit. As we change uses to undef when removing instruction we track the undefed instruction to make sure we eventually remove those too. This would have caught the issue at compile time. Differential Revision: https://reviews.llvm.org/D106714 --- .../Scalar/LowerMatrixIntrinsics.cpp | 45 ++++++++++++++++--- .../LowerMatrixIntrinsics/transpose-opts.ll | 27 +++++++++++ 2 files changed, 65 insertions(+), 7 deletions(-) diff --git a/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 8030a613c81..ab75cd3f566 100644 --- a/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -685,6 +685,19 @@ public: /// Try moving transposes in order to fold them away or into multiplies. void optimizeTransposes() { + auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) { + // We need to remove Old from the ShapeMap otherwise RAUW will replace it + // with New. We should only add New it it supportsShapeInfo so we insert + // it conditionally instead. + auto S = ShapeMap.find(&Old); + if (S != ShapeMap.end()) { + ShapeMap.erase(S); + if (supportsShapeInfo(New)) + ShapeMap.insert({New, S->second}); + } + Old.replaceAllUsesWith(New); + }; + // First sink all transposes inside matmuls, hoping that we end up with NN, // NT or TN variants. for (BasicBlock &BB : reverse(Func)) { @@ -717,7 +730,7 @@ public: Value *TATA; if (match(TA, m_Intrinsic(m_Value(TATA)))) { - I.replaceAllUsesWith(TATA); + ReplaceAllUsesWith(I, TATA); EraseFromParent(&I); EraseFromParent(TA); } @@ -740,8 +753,7 @@ public: NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(), K->getZExtValue(), R->getZExtValue(), "mmul"); - setShapeInfo(NewInst, {C, R}); - I.replaceAllUsesWith(NewInst); + ReplaceAllUsesWith(I, NewInst); EraseFromParent(&I); EraseFromParent(TA); } @@ -774,8 +786,7 @@ public: setShapeInfo(M, {C, R}); Value *NewInst = Builder.CreateMatrixTranspose(M, R->getZExtValue(), C->getZExtValue()); - setShapeInfo(NewInst, {C, R}); - I->replaceAllUsesWith(NewInst); + ReplaceAllUsesWith(*I, NewInst); if (I->use_empty()) I->eraseFromParent(); if (A->use_empty()) @@ -879,10 +890,30 @@ public: // Delete the instructions backwards, as it has a reduced likelihood of // having to update as many def-use and use-def chains. + // + // Because we add to ToRemove during fusion we can't guarantee that defs + // are before uses. Change uses to undef temporarily as these should get + // removed as well. + // + // For verification, we keep track of where we changed uses to undefs in + // UndefedInsts and then check that we in fact remove them. + SmallSet UndefedInsts; for (auto *Inst : reverse(ToRemove)) { - if (!Inst->use_empty()) - Inst->replaceAllUsesWith(UndefValue::get(Inst->getType())); + for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) { + Use &U = *I++; + if (auto *Undefed = dyn_cast(U.getUser())) + UndefedInsts.insert(Undefed); + U.set(UndefValue::get(Inst->getType())); + } Inst->eraseFromParent(); + UndefedInsts.erase(Inst); + } + if (!UndefedInsts.empty()) { + // If we didn't remove all undefed instructions, it's a hard error. + dbgs() << "Undefed but present instructions:\n"; + for (auto *I : UndefedInsts) + dbgs() << *I << "\n"; + llvm_unreachable("Undefed but instruction not removed"); } return Changed; diff --git a/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll b/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll index 25c27a960b4..2a7ac8278d5 100644 --- a/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll +++ b/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll @@ -986,6 +986,32 @@ entry: ret <4 x float> %m } +define <6 x double> @transpose_of_transpose_of_non_matrix_op(double* %a) { +; CHECK-LABEL: @transpose_of_transpose_of_non_matrix_op( +; CHECK-NEXT: [[VEC_CAST:%.*]] = bitcast double* [[A:%.*]] to <2 x double>* +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST]], align 8 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, double* [[A]], i64 4 +; CHECK-NEXT: [[VEC_CAST1:%.*]] = bitcast double* [[VEC_GEP]] to <2 x double>* +; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST1]], align 8 +; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr double, double* [[A]], i64 8 +; CHECK-NEXT: [[VEC_CAST4:%.*]] = bitcast double* [[VEC_GEP3]] to <2 x double>* +; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST4]], align 8 +; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr double, double* [[A]], i64 12 +; CHECK-NEXT: [[VEC_CAST7:%.*]] = bitcast double* [[VEC_GEP6]] to <2 x double>* +; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST7]], align 8 +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x double> [[COL_LOAD]], <2 x double> [[COL_LOAD2]], <4 x i32> +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x double> [[COL_LOAD5]], <2 x double> [[COL_LOAD8]], <4 x i32> +; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <4 x double> [[TMP1]], <4 x double> [[TMP2]], <8 x i32> +; CHECK-NEXT: [[SHUF:%.*]] = shufflevector <8 x double> [[TMP3]], <8 x double> poison, <6 x i32> +; CHECK-NEXT: ret <6 x double> [[SHUF]] +; + %load = call <8 x double> @llvm.matrix.column.major.load.v8f64(double* %a, i64 4, i1 false, i32 2, i32 4) + %shuf = shufflevector <8 x double> %load, <8 x double> poison, <6 x i32> + %t = call <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double> %shuf, i32 3, i32 2) + %tt = call <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double> %t, i32 2, i32 3) + ret <6 x double> %tt +} + declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32) declare <12 x double> @llvm.matrix.multiply.v12f64.v6f64.v8f64(<6 x double>, <8 x double>, i32, i32, i32) declare <8 x double> @llvm.matrix.multiply.v8f64.v6f64.v12f64(<6 x double> %a, <12 x double>, i32, i32, i32) @@ -995,3 +1021,4 @@ declare <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double>, i32, i32) declare <8 x double> @llvm.matrix.transpose.v8f64.v8f64(<8 x double>, i32, i32) declare <12 x double> @llvm.matrix.transpose.v12f64.v12f64(<12 x double>, i32, i32) declare <4 x float> @llvm.matrix.transpose.v4f32(<4 x float>, i32, i32) +declare <8 x double> @llvm.matrix.column.major.load.v8f64(double*, i64, i1, i32, i32)