mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-10-19 19:12:56 +02:00
[Matrix] Add TileInfo abstraction for tiled matrix code-gen.
This patch adds a TileInfo abstraction and utilities to create a 3-level loop nest for tiling. Reviewers: anemet Reviewed By: anemet Differential Revision: https://reviews.llvm.org/D77550
This commit is contained in:
parent
5b0daad836
commit
992d1824c2
94
include/llvm/Transforms/Utils/MatrixUtils.h
Normal file
94
include/llvm/Transforms/Utils/MatrixUtils.h
Normal file
@ -0,0 +1,94 @@
|
||||
//===- MatrixUtils.h - Utilities to lower matrix intrinsics -----*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Utilities for generating tiled loops for matrix operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LLVM_TRANSFORMS_UTILS_MATRIXUTILS_H
|
||||
#define LLVM_TRANSFORMS_UTILS_MATRIXUTILS_H
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
namespace llvm {
|
||||
class DomTreeUpdater;
|
||||
class BasicBlock;
|
||||
class Value;
|
||||
class Loop;
|
||||
class LoopInfo;
|
||||
class IRBuilderBase;
|
||||
|
||||
/// A helper struct to create IR loop nests for tiling in IR of the following
|
||||
/// form:
|
||||
/// for CurrentColumn = 0..NumColumns
|
||||
/// for CurrentRow = 0..NumRows
|
||||
/// for CurrentInner = 0..NumInner
|
||||
struct TileInfo {
|
||||
/// Number of rows of the matrix.
|
||||
unsigned NumRows;
|
||||
|
||||
/// Number of columns of the matrix.
|
||||
unsigned NumColumns;
|
||||
|
||||
/// Number of columns of the first matrix of a multiply /
|
||||
/// number of rows of the second matrix of a multiply.
|
||||
unsigned NumInner;
|
||||
|
||||
/// Number of rows/columns in a tile.
|
||||
unsigned TileSize = -1;
|
||||
|
||||
/// Start row of the current tile to compute.
|
||||
Value *CurrentRow;
|
||||
|
||||
/// Start column of the current tile to compute.
|
||||
Value *CurrentCol;
|
||||
|
||||
/// Current tile offset during the tile computation.
|
||||
Value *CurrentK;
|
||||
|
||||
/// Header of the outermost loop iterating from 0..NumColumns.
|
||||
BasicBlock *ColumnLoopHeader = nullptr;
|
||||
|
||||
/// Header of the second loop iterating from 0..NumRows.
|
||||
BasicBlock *RowLoopHeader = nullptr;
|
||||
/// Latch of the second loop iterating from 0..NumRows.
|
||||
BasicBlock *RowLoopLatch = nullptr;
|
||||
/// Header of the innermost loop iterating from 0..NumInner.
|
||||
BasicBlock *InnerLoopHeader = nullptr;
|
||||
/// Latch of the innermost loop iterating from 0..NumInner.
|
||||
BasicBlock *InnerLoopLatch = nullptr;
|
||||
|
||||
TileInfo(unsigned NumRows, unsigned NumColumns, unsigned NumInner,
|
||||
unsigned TileSize)
|
||||
: NumRows(NumRows), NumColumns(NumColumns), NumInner(NumInner),
|
||||
TileSize(TileSize) {}
|
||||
|
||||
/// Creates an IR loop nests for tiling of the form below. Returns the block
|
||||
/// for the inner loop body and sets {Column,Row,Inner}LoopHeader/Latch
|
||||
/// fields.
|
||||
///
|
||||
/// for CurrentColumn = 0..NumColumns
|
||||
/// for CurrentRow = 0..NumRows
|
||||
/// for CurrentInner = 0..NumInner
|
||||
BasicBlock *CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
|
||||
IRBuilderBase &B, DomTreeUpdater &DTU,
|
||||
LoopInfo &LI);
|
||||
|
||||
private:
|
||||
/// Creates a new loop with header, body and latch blocks that iterates from
|
||||
/// [0, Bound). Updates \p Preheader to branch to the new header and uses \p
|
||||
/// Exit as exit block. Adds the new loop blocks to \L and applies dominator
|
||||
/// tree updates to \p DTU.
|
||||
static BasicBlock *CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
|
||||
Value *Bound, Value *Step, StringRef Name,
|
||||
IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
|
||||
LoopInfo &LI);
|
||||
};
|
||||
} // namespace llvm
|
||||
|
||||
#endif
|
@ -46,6 +46,7 @@ add_llvm_component_library(LLVMTransformUtils
|
||||
LowerInvoke.cpp
|
||||
LowerMemIntrinsics.cpp
|
||||
LowerSwitch.cpp
|
||||
MatrixUtils.cpp
|
||||
Mem2Reg.cpp
|
||||
MetaRenamer.cpp
|
||||
MisExpect.cpp
|
||||
|
104
lib/Transforms/Utils/MatrixUtils.cpp
Normal file
104
lib/Transforms/Utils/MatrixUtils.cpp
Normal file
@ -0,0 +1,104 @@
|
||||
//===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Utilities for generating tiled loops for matrix operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "llvm/Transforms/Utils/MatrixUtils.h"
|
||||
#include "llvm/Analysis/DomTreeUpdater.h"
|
||||
#include "llvm/Analysis/LoopInfo.h"
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
#include "llvm/IR/Dominators.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Type.h"
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
|
||||
Value *Bound, Value *Step, StringRef Name,
|
||||
IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
|
||||
LoopInfo &LI) {
|
||||
LLVMContext &Ctx = Preheader->getContext();
|
||||
BasicBlock *Header = BasicBlock::Create(
|
||||
Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit);
|
||||
BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body",
|
||||
Header->getParent(), Exit);
|
||||
BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch",
|
||||
Header->getParent(), Exit);
|
||||
|
||||
Type *I32Ty = Type::getInt64Ty(Ctx);
|
||||
BranchInst::Create(Body, Header);
|
||||
BranchInst::Create(Latch, Body);
|
||||
PHINode *IV =
|
||||
PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator());
|
||||
IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader);
|
||||
|
||||
B.SetInsertPoint(Latch);
|
||||
Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
|
||||
Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
|
||||
BranchInst::Create(Header, Exit, Cond, Latch);
|
||||
IV->addIncoming(Inc, Latch);
|
||||
|
||||
BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
|
||||
BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
|
||||
PreheaderBr->setSuccessor(0, Header);
|
||||
DTU.applyUpdatesPermissive({
|
||||
{DominatorTree::Delete, Preheader, Tmp},
|
||||
{DominatorTree::Insert, Header, Body},
|
||||
{DominatorTree::Insert, Body, Latch},
|
||||
{DominatorTree::Insert, Latch, Header},
|
||||
{DominatorTree::Insert, Latch, Exit},
|
||||
{DominatorTree::Insert, Preheader, Header},
|
||||
});
|
||||
|
||||
L->addBasicBlockToLoop(Header, LI);
|
||||
L->addBasicBlockToLoop(Body, LI);
|
||||
L->addBasicBlockToLoop(Latch, LI);
|
||||
return Body;
|
||||
}
|
||||
|
||||
// Creates the following loop nest skeleton:
|
||||
// for C = 0; C < NumColumns; C += TileSize
|
||||
// for R = 0; R < NumRows; R += TileSize
|
||||
// for K = 0; K < Inner ; K += TileSize
|
||||
BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
|
||||
IRBuilderBase &B, DomTreeUpdater &DTU,
|
||||
LoopInfo &LI) {
|
||||
Loop *ColLoop = LI.AllocateLoop();
|
||||
Loop *RowLoop = LI.AllocateLoop();
|
||||
Loop *InnerLoop = LI.AllocateLoop();
|
||||
RowLoop->addChildLoop(InnerLoop);
|
||||
ColLoop->addChildLoop(RowLoop);
|
||||
if (Loop *ParentL = LI.getLoopFor(Start))
|
||||
ParentL->addChildLoop(ColLoop);
|
||||
else
|
||||
LI.addTopLevelLoop(ColLoop);
|
||||
|
||||
BasicBlock *ColBody =
|
||||
CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
|
||||
"cols", B, DTU, ColLoop, LI);
|
||||
BasicBlock *ColLatch = ColBody->getSingleSuccessor();
|
||||
BasicBlock *RowBody =
|
||||
CreateLoop(ColBody, ColLatch, B.getInt64(NumRows), B.getInt64(TileSize),
|
||||
"rows", B, DTU, RowLoop, LI);
|
||||
RowLoopLatch = RowBody->getSingleSuccessor();
|
||||
|
||||
BasicBlock *InnerBody =
|
||||
CreateLoop(RowBody, RowLoopLatch, B.getInt64(NumInner),
|
||||
B.getInt64(TileSize), "inner", B, DTU, InnerLoop, LI);
|
||||
InnerLoopLatch = InnerBody->getSingleSuccessor();
|
||||
ColumnLoopHeader = ColBody->getSinglePredecessor();
|
||||
RowLoopHeader = RowBody->getSinglePredecessor();
|
||||
InnerLoopHeader = InnerBody->getSinglePredecessor();
|
||||
CurrentRow = &*RowLoopHeader->begin();
|
||||
CurrentCol = &*ColumnLoopHeader->begin();
|
||||
CurrentK = &*InnerLoopHeader->begin();
|
||||
|
||||
return InnerBody;
|
||||
}
|
Loading…
Reference in New Issue
Block a user