mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-11-22 02:33:06 +01:00
[Matrix] Fix shape for factored transpose
The shape of the input is C x R. Differential Revision: https://reviews.llvm.org/D106722
This commit is contained in:
parent
add64be20a
commit
d9613eb43c
@ -774,6 +774,7 @@ public:
|
|||||||
++II;
|
++II;
|
||||||
Value *A, *B, *AT, *BT;
|
Value *A, *B, *AT, *BT;
|
||||||
ConstantInt *R, *K, *C;
|
ConstantInt *R, *K, *C;
|
||||||
|
// A^t * B ^t -> (B * A)^t
|
||||||
if (match(&*I, m_Intrinsic<Intrinsic::matrix_multiply>(
|
if (match(&*I, m_Intrinsic<Intrinsic::matrix_multiply>(
|
||||||
m_Value(A), m_Value(B), m_ConstantInt(R),
|
m_Value(A), m_Value(B), m_ConstantInt(R),
|
||||||
m_ConstantInt(K), m_ConstantInt(C))) &&
|
m_ConstantInt(K), m_ConstantInt(C))) &&
|
||||||
@ -784,8 +785,8 @@ public:
|
|||||||
Value *M = Builder.CreateMatrixMultiply(
|
Value *M = Builder.CreateMatrixMultiply(
|
||||||
BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
|
BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
|
||||||
setShapeInfo(M, {C, R});
|
setShapeInfo(M, {C, R});
|
||||||
Value *NewInst = Builder.CreateMatrixTranspose(M, R->getZExtValue(),
|
Instruction *NewInst = Builder.CreateMatrixTranspose(
|
||||||
C->getZExtValue());
|
M, C->getZExtValue(), R->getZExtValue());
|
||||||
ReplaceAllUsesWith(*I, NewInst);
|
ReplaceAllUsesWith(*I, NewInst);
|
||||||
if (I->use_empty())
|
if (I->use_empty())
|
||||||
I->eraseFromParent();
|
I->eraseFromParent();
|
||||||
|
@ -1012,6 +1012,126 @@ define <6 x double> @transpose_of_transpose_of_non_matrix_op(double* %a) {
|
|||||||
ret <6 x double> %tt
|
ret <6 x double> %tt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
define <12 x double> @factor_transpose(<6 x double> %a, <8 x double> %b) {
|
||||||
|
; CHECK-LABEL: @factor_transpose(
|
||||||
|
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
|
||||||
|
; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
|
||||||
|
; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <6 x double> [[A:%.*]], <6 x double> poison, <2 x i32> <i32 0, i32 1>
|
||||||
|
; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <6 x double> [[A]], <6 x double> poison, <2 x i32> <i32 2, i32 3>
|
||||||
|
; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <6 x double> [[A]], <6 x double> poison, <2 x i32> <i32 4, i32 5>
|
||||||
|
; CHECK-NEXT: [[BLOCK:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
|
||||||
|
; CHECK-NEXT: [[TMP1:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <2 x double> poison, double [[TMP1]], i32 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT]], <2 x double> poison, <2 x i32> zeroinitializer
|
||||||
|
; CHECK-NEXT: [[TMP2:%.*]] = fmul <2 x double> [[BLOCK]], [[SPLAT_SPLAT]]
|
||||||
|
; CHECK-NEXT: [[BLOCK5:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
|
||||||
|
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLATINSERT6:%.*]] = insertelement <2 x double> poison, double [[TMP3]], i32 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLAT7:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT6]], <2 x double> poison, <2 x i32> zeroinitializer
|
||||||
|
; CHECK-NEXT: [[TMP4:%.*]] = fmul <2 x double> [[BLOCK5]], [[SPLAT_SPLAT7]]
|
||||||
|
; CHECK-NEXT: [[TMP5:%.*]] = fadd <2 x double> [[TMP2]], [[TMP4]]
|
||||||
|
; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <2 x double> [[TMP5]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
|
||||||
|
; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <4 x double> undef, <4 x double> [[TMP6]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
|
||||||
|
; CHECK-NEXT: [[BLOCK8:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
|
||||||
|
; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLATINSERT9:%.*]] = insertelement <2 x double> poison, double [[TMP8]], i32 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLAT10:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT9]], <2 x double> poison, <2 x i32> zeroinitializer
|
||||||
|
; CHECK-NEXT: [[TMP9:%.*]] = fmul <2 x double> [[BLOCK8]], [[SPLAT_SPLAT10]]
|
||||||
|
; CHECK-NEXT: [[BLOCK11:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
|
||||||
|
; CHECK-NEXT: [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLATINSERT12:%.*]] = insertelement <2 x double> poison, double [[TMP10]], i32 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLAT13:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT12]], <2 x double> poison, <2 x i32> zeroinitializer
|
||||||
|
; CHECK-NEXT: [[TMP11:%.*]] = fmul <2 x double> [[BLOCK11]], [[SPLAT_SPLAT13]]
|
||||||
|
; CHECK-NEXT: [[TMP12:%.*]] = fadd <2 x double> [[TMP9]], [[TMP11]]
|
||||||
|
; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x double> [[TMP12]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
|
||||||
|
; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x double> [[TMP7]], <4 x double> [[TMP13]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
|
||||||
|
; CHECK-NEXT: [[BLOCK14:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
|
||||||
|
; CHECK-NEXT: [[TMP15:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLATINSERT15:%.*]] = insertelement <2 x double> poison, double [[TMP15]], i32 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLAT16:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT15]], <2 x double> poison, <2 x i32> zeroinitializer
|
||||||
|
; CHECK-NEXT: [[TMP16:%.*]] = fmul <2 x double> [[BLOCK14]], [[SPLAT_SPLAT16]]
|
||||||
|
; CHECK-NEXT: [[BLOCK17:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
|
||||||
|
; CHECK-NEXT: [[TMP17:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLATINSERT18:%.*]] = insertelement <2 x double> poison, double [[TMP17]], i32 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLAT19:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT18]], <2 x double> poison, <2 x i32> zeroinitializer
|
||||||
|
; CHECK-NEXT: [[TMP18:%.*]] = fmul <2 x double> [[BLOCK17]], [[SPLAT_SPLAT19]]
|
||||||
|
; CHECK-NEXT: [[TMP19:%.*]] = fadd <2 x double> [[TMP16]], [[TMP18]]
|
||||||
|
; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <2 x double> [[TMP19]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
|
||||||
|
; CHECK-NEXT: [[TMP21:%.*]] = shufflevector <4 x double> undef, <4 x double> [[TMP20]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
|
||||||
|
; CHECK-NEXT: [[BLOCK20:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
|
||||||
|
; CHECK-NEXT: [[TMP22:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLATINSERT21:%.*]] = insertelement <2 x double> poison, double [[TMP22]], i32 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLAT22:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT21]], <2 x double> poison, <2 x i32> zeroinitializer
|
||||||
|
; CHECK-NEXT: [[TMP23:%.*]] = fmul <2 x double> [[BLOCK20]], [[SPLAT_SPLAT22]]
|
||||||
|
; CHECK-NEXT: [[BLOCK23:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
|
||||||
|
; CHECK-NEXT: [[TMP24:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLATINSERT24:%.*]] = insertelement <2 x double> poison, double [[TMP24]], i32 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLAT25:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT24]], <2 x double> poison, <2 x i32> zeroinitializer
|
||||||
|
; CHECK-NEXT: [[TMP25:%.*]] = fmul <2 x double> [[BLOCK23]], [[SPLAT_SPLAT25]]
|
||||||
|
; CHECK-NEXT: [[TMP26:%.*]] = fadd <2 x double> [[TMP23]], [[TMP25]]
|
||||||
|
; CHECK-NEXT: [[TMP27:%.*]] = shufflevector <2 x double> [[TMP26]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
|
||||||
|
; CHECK-NEXT: [[TMP28:%.*]] = shufflevector <4 x double> [[TMP21]], <4 x double> [[TMP27]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
|
||||||
|
; CHECK-NEXT: [[BLOCK26:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
|
||||||
|
; CHECK-NEXT: [[TMP29:%.*]] = extractelement <2 x double> [[SPLIT4]], i64 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLATINSERT27:%.*]] = insertelement <2 x double> poison, double [[TMP29]], i32 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLAT28:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT27]], <2 x double> poison, <2 x i32> zeroinitializer
|
||||||
|
; CHECK-NEXT: [[TMP30:%.*]] = fmul <2 x double> [[BLOCK26]], [[SPLAT_SPLAT28]]
|
||||||
|
; CHECK-NEXT: [[BLOCK29:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
|
||||||
|
; CHECK-NEXT: [[TMP31:%.*]] = extractelement <2 x double> [[SPLIT4]], i64 1
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLATINSERT30:%.*]] = insertelement <2 x double> poison, double [[TMP31]], i32 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLAT31:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT30]], <2 x double> poison, <2 x i32> zeroinitializer
|
||||||
|
; CHECK-NEXT: [[TMP32:%.*]] = fmul <2 x double> [[BLOCK29]], [[SPLAT_SPLAT31]]
|
||||||
|
; CHECK-NEXT: [[TMP33:%.*]] = fadd <2 x double> [[TMP30]], [[TMP32]]
|
||||||
|
; CHECK-NEXT: [[TMP34:%.*]] = shufflevector <2 x double> [[TMP33]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
|
||||||
|
; CHECK-NEXT: [[TMP35:%.*]] = shufflevector <4 x double> undef, <4 x double> [[TMP34]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
|
||||||
|
; CHECK-NEXT: [[BLOCK32:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
|
||||||
|
; CHECK-NEXT: [[TMP36:%.*]] = extractelement <2 x double> [[SPLIT4]], i64 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLATINSERT33:%.*]] = insertelement <2 x double> poison, double [[TMP36]], i32 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLAT34:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT33]], <2 x double> poison, <2 x i32> zeroinitializer
|
||||||
|
; CHECK-NEXT: [[TMP37:%.*]] = fmul <2 x double> [[BLOCK32]], [[SPLAT_SPLAT34]]
|
||||||
|
; CHECK-NEXT: [[BLOCK35:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
|
||||||
|
; CHECK-NEXT: [[TMP38:%.*]] = extractelement <2 x double> [[SPLIT4]], i64 1
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLATINSERT36:%.*]] = insertelement <2 x double> poison, double [[TMP38]], i32 0
|
||||||
|
; CHECK-NEXT: [[SPLAT_SPLAT37:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT36]], <2 x double> poison, <2 x i32> zeroinitializer
|
||||||
|
; CHECK-NEXT: [[TMP39:%.*]] = fmul <2 x double> [[BLOCK35]], [[SPLAT_SPLAT37]]
|
||||||
|
; CHECK-NEXT: [[TMP40:%.*]] = fadd <2 x double> [[TMP37]], [[TMP39]]
|
||||||
|
; CHECK-NEXT: [[TMP41:%.*]] = shufflevector <2 x double> [[TMP40]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
|
||||||
|
; CHECK-NEXT: [[TMP42:%.*]] = shufflevector <4 x double> [[TMP35]], <4 x double> [[TMP41]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
|
||||||
|
; CHECK-NEXT: [[TMP43:%.*]] = extractelement <4 x double> [[TMP14]], i64 0
|
||||||
|
; CHECK-NEXT: [[TMP44:%.*]] = insertelement <3 x double> undef, double [[TMP43]], i64 0
|
||||||
|
; CHECK-NEXT: [[TMP45:%.*]] = extractelement <4 x double> [[TMP28]], i64 0
|
||||||
|
; CHECK-NEXT: [[TMP46:%.*]] = insertelement <3 x double> [[TMP44]], double [[TMP45]], i64 1
|
||||||
|
; CHECK-NEXT: [[TMP47:%.*]] = extractelement <4 x double> [[TMP42]], i64 0
|
||||||
|
; CHECK-NEXT: [[TMP48:%.*]] = insertelement <3 x double> [[TMP46]], double [[TMP47]], i64 2
|
||||||
|
; CHECK-NEXT: [[TMP49:%.*]] = extractelement <4 x double> [[TMP14]], i64 1
|
||||||
|
; CHECK-NEXT: [[TMP50:%.*]] = insertelement <3 x double> undef, double [[TMP49]], i64 0
|
||||||
|
; CHECK-NEXT: [[TMP51:%.*]] = extractelement <4 x double> [[TMP28]], i64 1
|
||||||
|
; CHECK-NEXT: [[TMP52:%.*]] = insertelement <3 x double> [[TMP50]], double [[TMP51]], i64 1
|
||||||
|
; CHECK-NEXT: [[TMP53:%.*]] = extractelement <4 x double> [[TMP42]], i64 1
|
||||||
|
; CHECK-NEXT: [[TMP54:%.*]] = insertelement <3 x double> [[TMP52]], double [[TMP53]], i64 2
|
||||||
|
; CHECK-NEXT: [[TMP55:%.*]] = extractelement <4 x double> [[TMP14]], i64 2
|
||||||
|
; CHECK-NEXT: [[TMP56:%.*]] = insertelement <3 x double> undef, double [[TMP55]], i64 0
|
||||||
|
; CHECK-NEXT: [[TMP57:%.*]] = extractelement <4 x double> [[TMP28]], i64 2
|
||||||
|
; CHECK-NEXT: [[TMP58:%.*]] = insertelement <3 x double> [[TMP56]], double [[TMP57]], i64 1
|
||||||
|
; CHECK-NEXT: [[TMP59:%.*]] = extractelement <4 x double> [[TMP42]], i64 2
|
||||||
|
; CHECK-NEXT: [[TMP60:%.*]] = insertelement <3 x double> [[TMP58]], double [[TMP59]], i64 2
|
||||||
|
; CHECK-NEXT: [[TMP61:%.*]] = extractelement <4 x double> [[TMP14]], i64 3
|
||||||
|
; CHECK-NEXT: [[TMP62:%.*]] = insertelement <3 x double> undef, double [[TMP61]], i64 0
|
||||||
|
; CHECK-NEXT: [[TMP63:%.*]] = extractelement <4 x double> [[TMP28]], i64 3
|
||||||
|
; CHECK-NEXT: [[TMP64:%.*]] = insertelement <3 x double> [[TMP62]], double [[TMP63]], i64 1
|
||||||
|
; CHECK-NEXT: [[TMP65:%.*]] = extractelement <4 x double> [[TMP42]], i64 3
|
||||||
|
; CHECK-NEXT: [[TMP66:%.*]] = insertelement <3 x double> [[TMP64]], double [[TMP65]], i64 2
|
||||||
|
; CHECK-NEXT: [[TMP67:%.*]] = shufflevector <3 x double> [[TMP48]], <3 x double> [[TMP54]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
|
||||||
|
; CHECK-NEXT: [[TMP68:%.*]] = shufflevector <3 x double> [[TMP60]], <3 x double> [[TMP66]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
|
||||||
|
; CHECK-NEXT: [[TMP69:%.*]] = shufflevector <6 x double> [[TMP67]], <6 x double> [[TMP68]], <12 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11>
|
||||||
|
; CHECK-NEXT: ret <12 x double> [[TMP69]]
|
||||||
|
;
|
||||||
|
%at = call <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double> %a, i32 2, i32 3)
|
||||||
|
%bt = call <8 x double> @llvm.matrix.transpose.v8f64.v8f64(<8 x double> %b, i32 4, i32 2)
|
||||||
|
%m = call <12 x double> @llvm.matrix.multiply.v12f64.v6f64.v8f64(<6 x double> %at, <8 x double> %bt, i32 3, i32 2, i32 4)
|
||||||
|
ret <12 x double> %m
|
||||||
|
}
|
||||||
|
|
||||||
declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
|
declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
|
||||||
declare <12 x double> @llvm.matrix.multiply.v12f64.v6f64.v8f64(<6 x double>, <8 x double>, i32, i32, i32)
|
declare <12 x double> @llvm.matrix.multiply.v12f64.v6f64.v8f64(<6 x double>, <8 x double>, i32, i32, i32)
|
||||||
declare <8 x double> @llvm.matrix.multiply.v8f64.v6f64.v12f64(<6 x double> %a, <12 x double>, i32, i32, i32)
|
declare <8 x double> @llvm.matrix.multiply.v8f64.v6f64.v12f64(<6 x double> %a, <12 x double>, i32, i32, i32)
|
||||||
|
Loading…
Reference in New Issue
Block a user