
[Matrix] Update shape propagation to iterate until done.

This patch updates the shape propagation to iterate until no new shape
information is discovered.

As the initial seed for the forward propagation, we use the matrix intrinsic
instructions. Both propagateShapeForward and propagateShapeBackward now
return new work lists containing the instructions to be used in the next
iteration. When propagating forward, we record all instructions for which we
added new shape information; when propagating backward, we record all users
of the instructions for which we added new shape information.
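
The resulting control flow is a straightforward fixed-point worklist loop: seed with the matrix intrinsics, alternate the forward and backward passes (each returning the seeds for the next round), and stop once neither discovers new shape information. The following is a minimal stand-alone C++ sketch of that scheme, not the pass itself: Inst, Shape and the trivial shape computation are illustrative stand-ins for llvm::Instruction, ShapeInfo and the real per-opcode logic.

// Toy model of the fixed-point driver added by this patch. Plain structs
// stand in for llvm::Instruction and ShapeInfo so the control flow can be
// shown in isolation; the shape computation itself is deliberately trivial.
#include <map>
#include <vector>

struct Inst {
  std::vector<Inst *> Operands;
  std::vector<Inst *> Users;
  bool SeedsShape = false; // e.g. a matrix intrinsic with explicit dimensions
};

struct Shape {
  unsigned Rows = 0, Cols = 0;
};

static std::map<Inst *, Shape> ShapeMap;

// Record a shape only if none is known yet; report whether new info was added.
static bool setShape(Inst *I, Shape S) { return ShapeMap.insert({I, S}).second; }

// Forward pass: assign shapes to instructions reachable from the seeds and
// return the instructions that received new shape information this round.
static std::vector<Inst *> propagateShapeForward(std::vector<Inst *> WorkList) {
  std::vector<Inst *> NewWorkList;
  while (!WorkList.empty()) {
    Inst *I = WorkList.back();
    WorkList.pop_back();
    // Placeholder shape; the real pass derives it per opcode from the operands.
    Shape S;
    S.Rows = S.Cols = 4;
    if (setShape(I, S)) {
      NewWorkList.push_back(I);
      for (Inst *U : I->Users)
        if (!ShapeMap.count(U))
          WorkList.push_back(U);
    }
  }
  return NewWorkList;
}

// Backward pass: push known shapes onto operands and return the users of
// operands that received new shape information, as seeds for the next round.
static std::vector<Inst *> propagateShapeBackward(std::vector<Inst *> WorkList) {
  std::vector<Inst *> NewWorkList;
  while (!WorkList.empty()) {
    Inst *I = WorkList.back();
    WorkList.pop_back();
    for (Inst *Op : I->Operands)
      if (setShape(Op, ShapeMap[I]))
        for (Inst *U : Op->Users)
          NewWorkList.push_back(U);
  }
  return NewWorkList;
}

// Driver: seed with shape-carrying instructions and iterate to a fixed point.
void propagateShapes(const std::vector<Inst *> &AllInsts) {
  std::vector<Inst *> WorkList;
  for (Inst *I : AllInsts)
    if (I->SeedsShape)
      WorkList.push_back(I);
  while (!WorkList.empty()) {
    WorkList = propagateShapeForward(WorkList);
    WorkList = propagateShapeBackward(WorkList);
  }
}

Returning fresh work lists from each pass, rather than mutating a shared one, is what lets the outer loop's termination condition be a simple emptiness check.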

Reviewers: anemet, Gerolf, reames, hfinkel, andrew.w.kaylor

Reviewed By: anemet

Differential Revision: https://reviews.llvm.org/D70901
Florian Hahn 2020-01-09 10:23:34 +00:00
parent 0a96d7d106
commit 4d6151bec8
2 changed files with 145 additions and 42 deletions


@@ -10,9 +10,6 @@
//
// TODO:
// * Implement multiply & add fusion
// * Implement shape propagation
// * Implement optimizations to reduce or eliminate shufflevector uses by using
// shape information.
// * Add remark, summarizing the available matrix optimization opportunities.
//
//===----------------------------------------------------------------------===//
@@ -321,32 +318,12 @@ public:
}
/// Propagate the shape information of instructions to their users.
void propagateShapeForward() {
// The work list contains instructions for which we can compute the shape,
// either based on the information provided by matrix intrinsics or known
// shapes of operands.
SmallVector<Instruction *, 8> WorkList;
// Initialize the work list with ops carrying shape information. Initially
// only the shape of matrix intrinsics is known.
for (BasicBlock &BB : Func)
for (Instruction &Inst : BB) {
IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
if (!II)
continue;
switch (II->getIntrinsicID()) {
case Intrinsic::matrix_multiply:
case Intrinsic::matrix_transpose:
case Intrinsic::matrix_columnwise_load:
case Intrinsic::matrix_columnwise_store:
WorkList.push_back(&Inst);
break;
default:
break;
}
}
/// The work list contains instructions for which we can compute the shape,
/// either based on the information provided by matrix intrinsics or known
/// shapes of operands.
SmallVector<Instruction *, 32>
propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
SmallVector<Instruction *, 32> NewWorkList;
// Pop an element for which we guaranteed to have at least one of the
// operand shapes. Add the shape for this and then add users to the work
// list.
@@ -395,20 +372,29 @@ public:
}
}
if (Propagate)
if (Propagate) {
NewWorkList.push_back(Inst);
for (auto *User : Inst->users())
if (ShapeMap.count(User) == 0)
WorkList.push_back(cast<Instruction>(User));
}
}
return NewWorkList;
}
/// Propagate the shape to operands of instructions with shape information.
void propagateShapeBackward() {
SmallVector<Value *, 8> WorkList;
// Worklist contains instruction for which we already know the shape.
for (auto &V : ShapeMap)
WorkList.push_back(V.first);
/// \p Worklist contains the instruction for which we already know the shape.
SmallVector<Instruction *, 32>
propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
SmallVector<Instruction *, 32> NewWorkList;
auto pushInstruction = [](Value *V,
SmallVectorImpl<Instruction *> &WorkList) {
Instruction *I = dyn_cast<Instruction>(V);
if (I)
WorkList.push_back(I);
};
// Pop an element with known shape. Traverse the operands, if their shape
// derives from the result shape and is unknown, add it and add them to the
// worklist.
@@ -417,6 +403,7 @@ public:
Value *V = WorkList.back();
WorkList.pop_back();
size_t BeforeProcessingV = WorkList.size();
if (!isa<Instruction>(V))
continue;
@@ -429,21 +416,21 @@ public:
m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
m_Value(N), m_Value(K)))) {
if (setShapeInfo(MatrixA, {M, N}))
WorkList.push_back(MatrixA);
pushInstruction(MatrixA, WorkList);
if (setShapeInfo(MatrixB, {N, K}))
WorkList.push_back(MatrixB);
pushInstruction(MatrixB, WorkList);
} else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(MatrixA), m_Value(M), m_Value(N)))) {
// Flip dimensions.
if (setShapeInfo(MatrixA, {M, N}))
WorkList.push_back(MatrixA);
pushInstruction(MatrixA, WorkList);
} else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
m_Value(MatrixA), m_Value(), m_Value(),
m_Value(M), m_Value(N)))) {
if (setShapeInfo(MatrixA, {M, N})) {
WorkList.push_back(MatrixA);
pushInstruction(MatrixA, WorkList);
}
} else if (isa<LoadInst>(V) ||
match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
@@ -456,16 +443,48 @@ public:
ShapeInfo Shape = ShapeMap[V];
for (Use &U : cast<Instruction>(V)->operands()) {
if (setShapeInfo(U.get(), Shape))
WorkList.push_back(U.get());
pushInstruction(U.get(), WorkList);
}
}
// After we discovered new shape info for new instructions in the
// worklist, we use their users as seeds for the next round of forward
// propagation.
for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
for (User *U : WorkList[I]->users())
if (isa<Instruction>(U) && V != U)
NewWorkList.push_back(cast<Instruction>(U));
}
return NewWorkList;
}
bool Visit() {
if (EnableShapePropagation) {
propagateShapeForward();
propagateShapeBackward();
SmallVector<Instruction *, 32> WorkList;
// Initially only the shape of matrix intrinsics is known.
// Initialize the work list with ops carrying shape information.
for (BasicBlock &BB : Func)
for (Instruction &Inst : BB) {
IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
if (!II)
continue;
switch (II->getIntrinsicID()) {
case Intrinsic::matrix_multiply:
case Intrinsic::matrix_transpose:
case Intrinsic::matrix_columnwise_load:
case Intrinsic::matrix_columnwise_store:
WorkList.push_back(&Inst);
break;
default:
break;
}
}
// Propagate shapes until nothing changes any longer.
while (!WorkList.empty()) {
WorkList = propagateShapeForward(WorkList);
WorkList = propagateShapeBackward(WorkList);
}
}
ReversePostOrderTraversal<Function *> RPOT(&Func);


@@ -0,0 +1,84 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s
; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
; Make sure we propagate in multiple iterations. First, we back-propagate the
; shape information from the transpose to %A; in the next iteration we
; forward-propagate it to %Mul, and then back to %B.
define <16 x double> @backpropagation_iterations(<16 x double>* %A.Ptr, <16 x double>* %B.Ptr) {
; CHECK-LABEL: @backpropagation_iterations(
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <16 x double>* [[A_PTR:%.*]] to double*
; CHECK-NEXT: [[TMP2:%.*]] = bitcast double* [[TMP1]] to <4 x double>*
; CHECK-NEXT: [[TMP3:%.*]] = load <4 x double>, <4 x double>* [[TMP2]], align 8
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr double, double* [[TMP1]], i32 4
; CHECK-NEXT: [[TMP6:%.*]] = bitcast double* [[TMP5]] to <4 x double>*
; CHECK-NEXT: [[TMP7:%.*]] = load <4 x double>, <4 x double>* [[TMP6]], align 8
; CHECK-NEXT: [[TMP9:%.*]] = getelementptr double, double* [[TMP1]], i32 8
; CHECK-NEXT: [[TMP10:%.*]] = bitcast double* [[TMP9]] to <4 x double>*
; CHECK-NEXT: [[TMP11:%.*]] = load <4 x double>, <4 x double>* [[TMP10]], align 8
; CHECK-NEXT: [[TMP13:%.*]] = getelementptr double, double* [[TMP1]], i32 12
; CHECK-NEXT: [[TMP14:%.*]] = bitcast double* [[TMP13]] to <4 x double>*
; CHECK-NEXT: [[TMP15:%.*]] = load <4 x double>, <4 x double>* [[TMP14]], align 8
; CHECK-NEXT: [[TMP16:%.*]] = extractelement <4 x double> [[TMP3]], i64 0
; CHECK-NEXT: [[TMP17:%.*]] = insertelement <4 x double> undef, double [[TMP16]], i64 0
; CHECK-NEXT: [[TMP18:%.*]] = extractelement <4 x double> [[TMP7]], i64 0
; CHECK-NEXT: [[TMP19:%.*]] = insertelement <4 x double> [[TMP17]], double [[TMP18]], i64 1
; CHECK-NEXT: [[TMP20:%.*]] = extractelement <4 x double> [[TMP11]], i64 0
; CHECK-NEXT: [[TMP21:%.*]] = insertelement <4 x double> [[TMP19]], double [[TMP20]], i64 2
; CHECK-NEXT: [[TMP22:%.*]] = extractelement <4 x double> [[TMP15]], i64 0
; CHECK-NEXT: [[TMP23:%.*]] = insertelement <4 x double> [[TMP21]], double [[TMP22]], i64 3
; CHECK-NEXT: [[TMP24:%.*]] = extractelement <4 x double> [[TMP3]], i64 1
; CHECK-NEXT: [[TMP25:%.*]] = insertelement <4 x double> undef, double [[TMP24]], i64 0
; CHECK-NEXT: [[TMP26:%.*]] = extractelement <4 x double> [[TMP7]], i64 1
; CHECK-NEXT: [[TMP27:%.*]] = insertelement <4 x double> [[TMP25]], double [[TMP26]], i64 1
; CHECK-NEXT: [[TMP28:%.*]] = extractelement <4 x double> [[TMP11]], i64 1
; CHECK-NEXT: [[TMP29:%.*]] = insertelement <4 x double> [[TMP27]], double [[TMP28]], i64 2
; CHECK-NEXT: [[TMP30:%.*]] = extractelement <4 x double> [[TMP15]], i64 1
; CHECK-NEXT: [[TMP31:%.*]] = insertelement <4 x double> [[TMP29]], double [[TMP30]], i64 3
; CHECK-NEXT: [[TMP32:%.*]] = extractelement <4 x double> [[TMP3]], i64 2
; CHECK-NEXT: [[TMP33:%.*]] = insertelement <4 x double> undef, double [[TMP32]], i64 0
; CHECK-NEXT: [[TMP34:%.*]] = extractelement <4 x double> [[TMP7]], i64 2
; CHECK-NEXT: [[TMP35:%.*]] = insertelement <4 x double> [[TMP33]], double [[TMP34]], i64 1
; CHECK-NEXT: [[TMP36:%.*]] = extractelement <4 x double> [[TMP11]], i64 2
; CHECK-NEXT: [[TMP37:%.*]] = insertelement <4 x double> [[TMP35]], double [[TMP36]], i64 2
; CHECK-NEXT: [[TMP38:%.*]] = extractelement <4 x double> [[TMP15]], i64 2
; CHECK-NEXT: [[TMP39:%.*]] = insertelement <4 x double> [[TMP37]], double [[TMP38]], i64 3
; CHECK-NEXT: [[TMP40:%.*]] = extractelement <4 x double> [[TMP3]], i64 3
; CHECK-NEXT: [[TMP41:%.*]] = insertelement <4 x double> undef, double [[TMP40]], i64 0
; CHECK-NEXT: [[TMP42:%.*]] = extractelement <4 x double> [[TMP7]], i64 3
; CHECK-NEXT: [[TMP43:%.*]] = insertelement <4 x double> [[TMP41]], double [[TMP42]], i64 1
; CHECK-NEXT: [[TMP44:%.*]] = extractelement <4 x double> [[TMP11]], i64 3
; CHECK-NEXT: [[TMP45:%.*]] = insertelement <4 x double> [[TMP43]], double [[TMP44]], i64 2
; CHECK-NEXT: [[TMP46:%.*]] = extractelement <4 x double> [[TMP15]], i64 3
; CHECK-NEXT: [[TMP47:%.*]] = insertelement <4 x double> [[TMP45]], double [[TMP46]], i64 3
; CHECK-NEXT: [[TMP48:%.*]] = bitcast <16 x double>* [[B_PTR:%.*]] to double*
; CHECK-NEXT: [[TMP49:%.*]] = bitcast double* [[TMP48]] to <4 x double>*
; CHECK-NEXT: [[TMP50:%.*]] = load <4 x double>, <4 x double>* [[TMP49]], align 8
; CHECK-NEXT: [[TMP52:%.*]] = getelementptr double, double* [[TMP48]], i32 4
; CHECK-NEXT: [[TMP53:%.*]] = bitcast double* [[TMP52]] to <4 x double>*
; CHECK-NEXT: [[TMP54:%.*]] = load <4 x double>, <4 x double>* [[TMP53]], align 8
; CHECK-NEXT: [[TMP56:%.*]] = getelementptr double, double* [[TMP48]], i32 8
; CHECK-NEXT: [[TMP57:%.*]] = bitcast double* [[TMP56]] to <4 x double>*
; CHECK-NEXT: [[TMP58:%.*]] = load <4 x double>, <4 x double>* [[TMP57]], align 8
; CHECK-NEXT: [[TMP60:%.*]] = getelementptr double, double* [[TMP48]], i32 12
; CHECK-NEXT: [[TMP61:%.*]] = bitcast double* [[TMP60]] to <4 x double>*
; CHECK-NEXT: [[TMP62:%.*]] = load <4 x double>, <4 x double>* [[TMP61]], align 8
; CHECK-NEXT: [[TMP63:%.*]] = fmul <4 x double> [[TMP3]], [[TMP50]]
; CHECK-NEXT: [[TMP64:%.*]] = fmul <4 x double> [[TMP7]], [[TMP54]]
; CHECK-NEXT: [[TMP65:%.*]] = fmul <4 x double> [[TMP11]], [[TMP58]]
; CHECK-NEXT: [[TMP66:%.*]] = fmul <4 x double> [[TMP15]], [[TMP62]]
; CHECK-NEXT: [[TMP67:%.*]] = shufflevector <4 x double> [[TMP63]], <4 x double> [[TMP64]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[TMP68:%.*]] = shufflevector <4 x double> [[TMP65]], <4 x double> [[TMP66]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[TMP69:%.*]] = shufflevector <8 x double> [[TMP67]], <8 x double> [[TMP68]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
; CHECK-NEXT: ret <16 x double> [[TMP69]]
;
%A = load <16 x double>, <16 x double>* %A.Ptr
%A.trans = tail call <16 x double> @llvm.matrix.transpose.v16f64(<16 x double> %A, i32 4, i32 4)
%B = load <16 x double>, <16 x double>* %B.Ptr
%Mul = fmul <16 x double> %A, %B
ret <16 x double> %Mul
}
declare <16 x double> @llvm.matrix.multiply.v16f64.v16f64.v16f64(<16 x double>, <16 x double>, i32 immarg, i32 immarg, i32 immarg)
declare <16 x double> @llvm.matrix.transpose.v16f64(<16 x double>, i32 immarg, i32 immarg)