From d9613eb43c81b17bb6811064b97b0f3d667d5a2c Mon Sep 17 00:00:00 2001 From: Adam Nemet Date: Fri, 23 Jul 2021 15:02:44 -0700 Subject: [PATCH] [Matrix] Fix shape for factored transpose The shape of the input is C x R. Differential Revision: https://reviews.llvm.org/D106722 --- .../Scalar/LowerMatrixIntrinsics.cpp | 5 +- .../LowerMatrixIntrinsics/transpose-opts.ll | 120 ++++++++++++++++++ 2 files changed, 123 insertions(+), 2 deletions(-) diff --git a/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index ab75cd3f566..42c183a6408 100644 --- a/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -774,6 +774,7 @@ public: ++II; Value *A, *B, *AT, *BT; ConstantInt *R, *K, *C; + // A^t * B ^t -> (B * A)^t if (match(&*I, m_Intrinsic( m_Value(A), m_Value(B), m_ConstantInt(R), m_ConstantInt(K), m_ConstantInt(C))) && @@ -784,8 +785,8 @@ public: Value *M = Builder.CreateMatrixMultiply( BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); setShapeInfo(M, {C, R}); - Value *NewInst = Builder.CreateMatrixTranspose(M, R->getZExtValue(), - C->getZExtValue()); + Instruction *NewInst = Builder.CreateMatrixTranspose( + M, C->getZExtValue(), R->getZExtValue()); ReplaceAllUsesWith(*I, NewInst); if (I->use_empty()) I->eraseFromParent(); diff --git a/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll b/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll index 2a7ac8278d5..d7c89ae5029 100644 --- a/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll +++ b/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll @@ -1012,6 +1012,126 @@ define <6 x double> @transpose_of_transpose_of_non_matrix_op(double* %a) { ret <6 x double> %tt } +define <12 x double> @factor_transpose(<6 x double> %a, <8 x double> %b) { +; CHECK-LABEL: @factor_transpose( +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <4 x i32> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <4 x i32> +; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <6 x double> [[A:%.*]], <6 x double> poison, <2 x i32> +; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <6 x double> [[A]], <6 x double> poison, <2 x i32> +; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <6 x double> [[A]], <6 x double> poison, <2 x i32> +; CHECK-NEXT: [[BLOCK:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <2 x double> poison, double [[TMP1]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT]], <2 x double> poison, <2 x i32> zeroinitializer +; CHECK-NEXT: [[TMP2:%.*]] = fmul <2 x double> [[BLOCK]], [[SPLAT_SPLAT]] +; CHECK-NEXT: [[BLOCK5:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT6:%.*]] = insertelement <2 x double> poison, double [[TMP3]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT7:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT6]], <2 x double> poison, <2 x i32> zeroinitializer +; CHECK-NEXT: [[TMP4:%.*]] = fmul <2 x double> [[BLOCK5]], [[SPLAT_SPLAT7]] +; CHECK-NEXT: [[TMP5:%.*]] = fadd <2 x double> [[TMP2]], [[TMP4]] +; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <2 x double> [[TMP5]], <2 x double> poison, <4 x i32> +; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <4 x double> undef, <4 x double> [[TMP6]], <4 x i32> +; CHECK-NEXT: [[BLOCK8:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT9:%.*]] = insertelement <2 x double> poison, double [[TMP8]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT10:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT9]], <2 x double> poison, <2 x i32> zeroinitializer +; CHECK-NEXT: [[TMP9:%.*]] = fmul <2 x double> [[BLOCK8]], [[SPLAT_SPLAT10]] +; CHECK-NEXT: [[BLOCK11:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> +; CHECK-NEXT: [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT12:%.*]] = insertelement <2 x double> poison, double [[TMP10]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT13:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT12]], <2 x double> poison, <2 x i32> zeroinitializer +; CHECK-NEXT: [[TMP11:%.*]] = fmul <2 x double> [[BLOCK11]], [[SPLAT_SPLAT13]] +; CHECK-NEXT: [[TMP12:%.*]] = fadd <2 x double> [[TMP9]], [[TMP11]] +; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x double> [[TMP12]], <2 x double> poison, <4 x i32> +; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x double> [[TMP7]], <4 x double> [[TMP13]], <4 x i32> +; CHECK-NEXT: [[BLOCK14:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> +; CHECK-NEXT: [[TMP15:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT15:%.*]] = insertelement <2 x double> poison, double [[TMP15]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT16:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT15]], <2 x double> poison, <2 x i32> zeroinitializer +; CHECK-NEXT: [[TMP16:%.*]] = fmul <2 x double> [[BLOCK14]], [[SPLAT_SPLAT16]] +; CHECK-NEXT: [[BLOCK17:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> +; CHECK-NEXT: [[TMP17:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT18:%.*]] = insertelement <2 x double> poison, double [[TMP17]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT19:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT18]], <2 x double> poison, <2 x i32> zeroinitializer +; CHECK-NEXT: [[TMP18:%.*]] = fmul <2 x double> [[BLOCK17]], [[SPLAT_SPLAT19]] +; CHECK-NEXT: [[TMP19:%.*]] = fadd <2 x double> [[TMP16]], [[TMP18]] +; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <2 x double> [[TMP19]], <2 x double> poison, <4 x i32> +; CHECK-NEXT: [[TMP21:%.*]] = shufflevector <4 x double> undef, <4 x double> [[TMP20]], <4 x i32> +; CHECK-NEXT: [[BLOCK20:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> +; CHECK-NEXT: [[TMP22:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT21:%.*]] = insertelement <2 x double> poison, double [[TMP22]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT22:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT21]], <2 x double> poison, <2 x i32> zeroinitializer +; CHECK-NEXT: [[TMP23:%.*]] = fmul <2 x double> [[BLOCK20]], [[SPLAT_SPLAT22]] +; CHECK-NEXT: [[BLOCK23:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> +; CHECK-NEXT: [[TMP24:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT24:%.*]] = insertelement <2 x double> poison, double [[TMP24]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT25:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT24]], <2 x double> poison, <2 x i32> zeroinitializer +; CHECK-NEXT: [[TMP25:%.*]] = fmul <2 x double> [[BLOCK23]], [[SPLAT_SPLAT25]] +; CHECK-NEXT: [[TMP26:%.*]] = fadd <2 x double> [[TMP23]], [[TMP25]] +; CHECK-NEXT: [[TMP27:%.*]] = shufflevector <2 x double> [[TMP26]], <2 x double> poison, <4 x i32> +; CHECK-NEXT: [[TMP28:%.*]] = shufflevector <4 x double> [[TMP21]], <4 x double> [[TMP27]], <4 x i32> +; CHECK-NEXT: [[BLOCK26:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> +; CHECK-NEXT: [[TMP29:%.*]] = extractelement <2 x double> [[SPLIT4]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT27:%.*]] = insertelement <2 x double> poison, double [[TMP29]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT28:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT27]], <2 x double> poison, <2 x i32> zeroinitializer +; CHECK-NEXT: [[TMP30:%.*]] = fmul <2 x double> [[BLOCK26]], [[SPLAT_SPLAT28]] +; CHECK-NEXT: [[BLOCK29:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> +; CHECK-NEXT: [[TMP31:%.*]] = extractelement <2 x double> [[SPLIT4]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT30:%.*]] = insertelement <2 x double> poison, double [[TMP31]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT31:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT30]], <2 x double> poison, <2 x i32> zeroinitializer +; CHECK-NEXT: [[TMP32:%.*]] = fmul <2 x double> [[BLOCK29]], [[SPLAT_SPLAT31]] +; CHECK-NEXT: [[TMP33:%.*]] = fadd <2 x double> [[TMP30]], [[TMP32]] +; CHECK-NEXT: [[TMP34:%.*]] = shufflevector <2 x double> [[TMP33]], <2 x double> poison, <4 x i32> +; CHECK-NEXT: [[TMP35:%.*]] = shufflevector <4 x double> undef, <4 x double> [[TMP34]], <4 x i32> +; CHECK-NEXT: [[BLOCK32:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> +; CHECK-NEXT: [[TMP36:%.*]] = extractelement <2 x double> [[SPLIT4]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT33:%.*]] = insertelement <2 x double> poison, double [[TMP36]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT34:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT33]], <2 x double> poison, <2 x i32> zeroinitializer +; CHECK-NEXT: [[TMP37:%.*]] = fmul <2 x double> [[BLOCK32]], [[SPLAT_SPLAT34]] +; CHECK-NEXT: [[BLOCK35:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> +; CHECK-NEXT: [[TMP38:%.*]] = extractelement <2 x double> [[SPLIT4]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT36:%.*]] = insertelement <2 x double> poison, double [[TMP38]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT37:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT36]], <2 x double> poison, <2 x i32> zeroinitializer +; CHECK-NEXT: [[TMP39:%.*]] = fmul <2 x double> [[BLOCK35]], [[SPLAT_SPLAT37]] +; CHECK-NEXT: [[TMP40:%.*]] = fadd <2 x double> [[TMP37]], [[TMP39]] +; CHECK-NEXT: [[TMP41:%.*]] = shufflevector <2 x double> [[TMP40]], <2 x double> poison, <4 x i32> +; CHECK-NEXT: [[TMP42:%.*]] = shufflevector <4 x double> [[TMP35]], <4 x double> [[TMP41]], <4 x i32> +; CHECK-NEXT: [[TMP43:%.*]] = extractelement <4 x double> [[TMP14]], i64 0 +; CHECK-NEXT: [[TMP44:%.*]] = insertelement <3 x double> undef, double [[TMP43]], i64 0 +; CHECK-NEXT: [[TMP45:%.*]] = extractelement <4 x double> [[TMP28]], i64 0 +; CHECK-NEXT: [[TMP46:%.*]] = insertelement <3 x double> [[TMP44]], double [[TMP45]], i64 1 +; CHECK-NEXT: [[TMP47:%.*]] = extractelement <4 x double> [[TMP42]], i64 0 +; CHECK-NEXT: [[TMP48:%.*]] = insertelement <3 x double> [[TMP46]], double [[TMP47]], i64 2 +; CHECK-NEXT: [[TMP49:%.*]] = extractelement <4 x double> [[TMP14]], i64 1 +; CHECK-NEXT: [[TMP50:%.*]] = insertelement <3 x double> undef, double [[TMP49]], i64 0 +; CHECK-NEXT: [[TMP51:%.*]] = extractelement <4 x double> [[TMP28]], i64 1 +; CHECK-NEXT: [[TMP52:%.*]] = insertelement <3 x double> [[TMP50]], double [[TMP51]], i64 1 +; CHECK-NEXT: [[TMP53:%.*]] = extractelement <4 x double> [[TMP42]], i64 1 +; CHECK-NEXT: [[TMP54:%.*]] = insertelement <3 x double> [[TMP52]], double [[TMP53]], i64 2 +; CHECK-NEXT: [[TMP55:%.*]] = extractelement <4 x double> [[TMP14]], i64 2 +; CHECK-NEXT: [[TMP56:%.*]] = insertelement <3 x double> undef, double [[TMP55]], i64 0 +; CHECK-NEXT: [[TMP57:%.*]] = extractelement <4 x double> [[TMP28]], i64 2 +; CHECK-NEXT: [[TMP58:%.*]] = insertelement <3 x double> [[TMP56]], double [[TMP57]], i64 1 +; CHECK-NEXT: [[TMP59:%.*]] = extractelement <4 x double> [[TMP42]], i64 2 +; CHECK-NEXT: [[TMP60:%.*]] = insertelement <3 x double> [[TMP58]], double [[TMP59]], i64 2 +; CHECK-NEXT: [[TMP61:%.*]] = extractelement <4 x double> [[TMP14]], i64 3 +; CHECK-NEXT: [[TMP62:%.*]] = insertelement <3 x double> undef, double [[TMP61]], i64 0 +; CHECK-NEXT: [[TMP63:%.*]] = extractelement <4 x double> [[TMP28]], i64 3 +; CHECK-NEXT: [[TMP64:%.*]] = insertelement <3 x double> [[TMP62]], double [[TMP63]], i64 1 +; CHECK-NEXT: [[TMP65:%.*]] = extractelement <4 x double> [[TMP42]], i64 3 +; CHECK-NEXT: [[TMP66:%.*]] = insertelement <3 x double> [[TMP64]], double [[TMP65]], i64 2 +; CHECK-NEXT: [[TMP67:%.*]] = shufflevector <3 x double> [[TMP48]], <3 x double> [[TMP54]], <6 x i32> +; CHECK-NEXT: [[TMP68:%.*]] = shufflevector <3 x double> [[TMP60]], <3 x double> [[TMP66]], <6 x i32> +; CHECK-NEXT: [[TMP69:%.*]] = shufflevector <6 x double> [[TMP67]], <6 x double> [[TMP68]], <12 x i32> +; CHECK-NEXT: ret <12 x double> [[TMP69]] +; + %at = call <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double> %a, i32 2, i32 3) + %bt = call <8 x double> @llvm.matrix.transpose.v8f64.v8f64(<8 x double> %b, i32 4, i32 2) + %m = call <12 x double> @llvm.matrix.multiply.v12f64.v6f64.v8f64(<6 x double> %at, <8 x double> %bt, i32 3, i32 2, i32 4) + ret <12 x double> %m +} + declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32) declare <12 x double> @llvm.matrix.multiply.v12f64.v6f64.v8f64(<6 x double>, <8 x double>, i32, i32, i32) declare <8 x double> @llvm.matrix.multiply.v8f64.v6f64.v12f64(<6 x double> %a, <12 x double>, i32, i32, i32)