mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2025-01-31 12:41:49 +01:00
[Matrix] Propagate and use shape info for binary operators.
This patch extends the current shape propagation and shape aware lowering to also support binary operators. Those operators are uniform with respect to their shape (shape of the input operands is the same as the shape of their result). Reviewers: anemet, Gerolf, reames, hfinkel, andrew.w.kaylor Reviewed By: anemet Differential Revision: https://reviews.llvm.org/D70898
This commit is contained in:
parent
d78215b636
commit
8be26dcc2f
@ -281,6 +281,24 @@ public:
|
||||
return true;
|
||||
}
|
||||
|
||||
bool isUniformShape(Value *V) {
|
||||
Instruction *I = dyn_cast<Instruction>(V);
|
||||
if (!I)
|
||||
return true;
|
||||
|
||||
switch (I->getOpcode()) {
|
||||
case Instruction::FAdd:
|
||||
case Instruction::FSub:
|
||||
case Instruction::FMul: // Scalar multiply.
|
||||
case Instruction::Add:
|
||||
case Instruction::Mul:
|
||||
case Instruction::Sub:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if shape information can be used for \p V. The supported
|
||||
/// instructions must match the instructions that can be lowered by this pass.
|
||||
bool supportsShapeInfo(Value *V) {
|
||||
@ -299,7 +317,7 @@ public:
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
return isa<StoreInst>(Inst);
|
||||
return isUniformShape(V) || isa<StoreInst>(V);
|
||||
}
|
||||
|
||||
/// Propagate the shape information of instructions to their users.
|
||||
@ -366,6 +384,15 @@ public:
|
||||
if (OpShape != ShapeMap.end())
|
||||
setShapeInfo(Inst, OpShape->second);
|
||||
continue;
|
||||
} else if (isUniformShape(Inst)) {
|
||||
// Find the first operand that has a known shape and use that.
|
||||
for (auto &Op : Inst->operands()) {
|
||||
auto OpShape = ShapeMap.find(Op.get());
|
||||
if (OpShape != ShapeMap.end()) {
|
||||
Propagate |= setShapeInfo(Inst, OpShape->second);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (Propagate)
|
||||
@ -390,7 +417,9 @@ public:
|
||||
|
||||
Value *Op1;
|
||||
Value *Op2;
|
||||
if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
|
||||
if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
|
||||
Changed |= VisitBinaryOperator(BinOp);
|
||||
else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
|
||||
Changed |= VisitStore(&Inst, Op1, Op2, Builder);
|
||||
}
|
||||
}
|
||||
@ -673,6 +702,49 @@ public:
|
||||
LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows), I->second);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Lower binary operators, if shape information is available.
|
||||
bool VisitBinaryOperator(BinaryOperator *Inst) {
|
||||
auto I = ShapeMap.find(Inst);
|
||||
if (I == ShapeMap.end())
|
||||
return false;
|
||||
|
||||
Value *Lhs = Inst->getOperand(0);
|
||||
Value *Rhs = Inst->getOperand(1);
|
||||
|
||||
IRBuilder<> Builder(Inst);
|
||||
ShapeInfo &Shape = I->second;
|
||||
|
||||
ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
|
||||
ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);
|
||||
|
||||
// Add each column and store the result back into the opmapping
|
||||
ColumnMatrixTy Result;
|
||||
auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
|
||||
switch (Inst->getOpcode()) {
|
||||
case Instruction::Add:
|
||||
return Builder.CreateAdd(LHS, RHS);
|
||||
case Instruction::Mul:
|
||||
return Builder.CreateMul(LHS, RHS);
|
||||
case Instruction::Sub:
|
||||
return Builder.CreateSub(LHS, RHS);
|
||||
case Instruction::FAdd:
|
||||
return Builder.CreateFAdd(LHS, RHS);
|
||||
case Instruction::FMul:
|
||||
return Builder.CreateFMul(LHS, RHS);
|
||||
case Instruction::FSub:
|
||||
return Builder.CreateFSub(LHS, RHS);
|
||||
default:
|
||||
llvm_unreachable("Unsupported binary operator for matrix");
|
||||
}
|
||||
};
|
||||
for (unsigned C = 0; C < Shape.NumColumns; ++C)
|
||||
Result.addColumn(
|
||||
BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));
|
||||
|
||||
finalizeLowering(Inst, Result, Builder);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -462,15 +462,34 @@ define void @transpose_multiply_add(<9 x double>* %A.Ptr, <9 x double>* %B.Ptr,
|
||||
|
||||
; CHECK-NEXT: [[TMP106:%.*]] = shufflevector <1 x double> [[TMP105]], <1 x double> undef, <3 x i32> <i32 0, i32 undef, i32 undef>
|
||||
; CHECK-NEXT: [[TMP107:%.*]] = shufflevector <3 x double> [[TMP97]], <3 x double> [[TMP106]], <3 x i32> <i32 0, i32 1, i32 3>
|
||||
; CHECK-NEXT: [[TMP108:%.*]] = shufflevector <3 x double> [[TMP47]], <3 x double> [[TMP77]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
|
||||
; CHECK-NEXT: [[TMP109:%.*]] = shufflevector <3 x double> [[TMP107]], <3 x double> undef, <6 x i32> <i32 0, i32 1, i32 2, i32 undef, i32 undef, i32 undef>
|
||||
; CHECK-NEXT: [[TMP110:%.*]] = shufflevector <6 x double> [[TMP108]], <6 x double> [[TMP109]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
|
||||
|
||||
; Load %C and add result of multiply.
|
||||
; Load %C.
|
||||
|
||||
; CHECK-NEXT: [[C:%.*]] = load <9 x double>, <9 x double>* [[C_PTR:%.*]]
|
||||
; CHECK-NEXT: [[RES:%.*]] = fadd <9 x double> [[C]], [[TMP110]]
|
||||
; CHECK-NEXT: store <9 x double> [[RES]], <9 x double>* [[C_PTR]]
|
||||
|
||||
; Extract columns from %C.
|
||||
|
||||
; CHECK-NEXT: [[SPLIT84:%.*]] = shufflevector <9 x double> [[C]], <9 x double> undef, <3 x i32> <i32 0, i32 1, i32 2>
|
||||
; CHECK-NEXT: [[SPLIT85:%.*]] = shufflevector <9 x double> [[C]], <9 x double> undef, <3 x i32> <i32 3, i32 4, i32 5>
|
||||
; CHECK-NEXT: [[SPLIT86:%.*]] = shufflevector <9 x double> [[C]], <9 x double> undef, <3 x i32> <i32 6, i32 7, i32 8>
|
||||
|
||||
; Add column vectors.
|
||||
|
||||
; CHECK-NEXT: [[TMP108:%.*]] = fadd <3 x double> [[SPLIT84]], [[TMP47]]
|
||||
; CHECK-NEXT: [[TMP109:%.*]] = fadd <3 x double> [[SPLIT85]], [[TMP77]]
|
||||
; CHECK-NEXT: [[TMP110:%.*]] = fadd <3 x double> [[SPLIT86]], [[TMP107]]
|
||||
|
||||
; Store result columns.
|
||||
|
||||
; CHECK-NEXT: [[TMP111:%.*]] = bitcast <9 x double>* [[C_PTR]] to double*
|
||||
; CHECK-NEXT: [[TMP112:%.*]] = bitcast double* [[TMP111]] to <3 x double>*
|
||||
; CHECK-NEXT: store <3 x double> [[TMP108]], <3 x double>* [[TMP112]], align 8
|
||||
; CHECK-NEXT: [[TMP113:%.*]] = getelementptr double, double* [[TMP111]], i32 3
|
||||
; CHECK-NEXT: [[TMP114:%.*]] = bitcast double* [[TMP113]] to <3 x double>*
|
||||
; CHECK-NEXT: store <3 x double> [[TMP109]], <3 x double>* [[TMP114]], align 8
|
||||
; CHECK-NEXT: [[TMP115:%.*]] = getelementptr double, double* [[TMP111]], i32 6
|
||||
; CHECK-NEXT: [[TMP116:%.*]] = bitcast double* [[TMP115]] to <3 x double>*
|
||||
; CHECK-NEXT: store <3 x double> [[TMP110]], <3 x double>* [[TMP116]], align 8
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
entry:
|
||||
|
@ -42,3 +42,75 @@ entry:
|
||||
}
|
||||
|
||||
declare <8 x double> @llvm.matrix.transpose(<8 x double>, i32, i32)
|
||||
|
||||
define <8 x double> @transpose_fadd(<8 x double> %a) {
|
||||
; CHECK-LABEL: @transpose_fadd(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> undef, <2 x i32> <i32 0, i32 1>
|
||||
; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 2, i32 3>
|
||||
; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 4, i32 5>
|
||||
; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 6, i32 7>
|
||||
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x double> [[SPLIT]], i64 0
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x double> undef, double [[TMP0]], i64 0
|
||||
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 0
|
||||
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x double> [[TMP1]], double [[TMP2]], i64 1
|
||||
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
|
||||
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x double> [[TMP3]], double [[TMP4]], i64 2
|
||||
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
|
||||
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 3
|
||||
; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT]], i64 1
|
||||
; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x double> undef, double [[TMP8]], i64 0
|
||||
; CHECK-NEXT: [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 1
|
||||
; CHECK-NEXT: [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 1
|
||||
; CHECK-NEXT: [[TMP12:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
|
||||
; CHECK-NEXT: [[TMP13:%.*]] = insertelement <4 x double> [[TMP11]], double [[TMP12]], i64 2
|
||||
; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
|
||||
; CHECK-NEXT: [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 3
|
||||
; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
|
||||
; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
|
||||
; CHECK-NEXT: [[TMP16:%.*]] = fadd <4 x double> [[TMP7]], [[SPLIT4]]
|
||||
; CHECK-NEXT: [[TMP17:%.*]] = fadd <4 x double> [[TMP15]], [[SPLIT5]]
|
||||
; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <4 x double> [[TMP16]], <4 x double> [[TMP17]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
|
||||
; CHECK-NEXT: ret <8 x double> [[TMP18]]
|
||||
;
|
||||
entry:
|
||||
%c = call <8 x double> @llvm.matrix.transpose(<8 x double> %a, i32 2, i32 4)
|
||||
%res = fadd <8 x double> %c, %a
|
||||
ret <8 x double> %res
|
||||
}
|
||||
|
||||
define <8 x double> @transpose_fmul(<8 x double> %a) {
|
||||
; CHECK-LABEL: @transpose_fmul(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> undef, <2 x i32> <i32 0, i32 1>
|
||||
; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 2, i32 3>
|
||||
; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 4, i32 5>
|
||||
; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 6, i32 7>
|
||||
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x double> [[SPLIT]], i64 0
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x double> undef, double [[TMP0]], i64 0
|
||||
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 0
|
||||
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x double> [[TMP1]], double [[TMP2]], i64 1
|
||||
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
|
||||
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x double> [[TMP3]], double [[TMP4]], i64 2
|
||||
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
|
||||
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 3
|
||||
; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT]], i64 1
|
||||
; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x double> undef, double [[TMP8]], i64 0
|
||||
; CHECK-NEXT: [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 1
|
||||
; CHECK-NEXT: [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 1
|
||||
; CHECK-NEXT: [[TMP12:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
|
||||
; CHECK-NEXT: [[TMP13:%.*]] = insertelement <4 x double> [[TMP11]], double [[TMP12]], i64 2
|
||||
; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
|
||||
; CHECK-NEXT: [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 3
|
||||
; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
|
||||
; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
|
||||
; CHECK-NEXT: [[TMP16:%.*]] = fmul <4 x double> [[TMP7]], [[SPLIT4]]
|
||||
; CHECK-NEXT: [[TMP17:%.*]] = fmul <4 x double> [[TMP15]], [[SPLIT5]]
|
||||
; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <4 x double> [[TMP16]], <4 x double> [[TMP17]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
|
||||
; CHECK-NEXT: ret <8 x double> [[TMP18]]
|
||||
;
|
||||
entry:
|
||||
%c = call <8 x double> @llvm.matrix.transpose(<8 x double> %a, i32 2, i32 4)
|
||||
%res = fmul <8 x double> %c, %a
|
||||
ret <8 x double> %res
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user