1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-10-19 11:02:59 +02:00

[Coroutines] Part 7: Split coroutine into subfunctions

Summary:
This patch adds simple coroutine splitting logic to CoroSplit pass.

Documentation and overview is here: http://llvm.org/docs/Coroutines.html.

Upstreaming sequence (rough plan)
1.Add documentation. (https://reviews.llvm.org/D22603)
2.Add coroutine intrinsics. (https://reviews.llvm.org/D22659)
...
7. Split coroutine into subfunctions <= we are here
8. Coroutine Frame Building algorithm
9. Handle coroutine with unwinds
10+. The rest of the logic

Reviewers: majnemer

Subscribers: llvm-commits, mehdi_amini

Differential Revision: https://reviews.llvm.org/D23461

llvm-svn: 278830
This commit is contained in:
Gor Nishanov 2016-08-16 18:04:14 +00:00
parent acac95736d
commit 906e6267f2
11 changed files with 936 additions and 28 deletions

View File

@ -3,7 +3,8 @@ add_llvm_library(LLVMCoroutines
CoroCleanup.cpp CoroCleanup.cpp
CoroEarly.cpp CoroEarly.cpp
CoroElide.cpp CoroElide.cpp
CoroSplit.cpp CoroFrame.cpp
CoroSplit.cpp
) )
add_dependencies(LLVMCoroutines intrinsics_gen) add_dependencies(LLVMCoroutines intrinsics_gen)

View File

@ -10,12 +10,66 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "CoroInternal.h" #include "CoroInternal.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Pass.h" #include "llvm/Pass.h"
#include "llvm/Transforms/Scalar.h"
using namespace llvm; using namespace llvm;
#define DEBUG_TYPE "coro-cleanup" #define DEBUG_TYPE "coro-cleanup"
namespace {
// Created on demand if CoroCleanup pass has work to do.
struct Lowerer : coro::LowererBase {
Lowerer(Module &M) : LowererBase(M) {}
bool lowerRemainingCoroIntrinsics(Function &F);
};
}
static void simplifyCFG(Function &F) {
llvm::legacy::FunctionPassManager FPM(F.getParent());
FPM.add(createCFGSimplificationPass());
FPM.doInitialization();
FPM.run(F);
FPM.doFinalization();
}
bool Lowerer::lowerRemainingCoroIntrinsics(Function &F) {
bool Changed = false;
for (auto IB = inst_begin(F), E = inst_end(F); IB != E;) {
Instruction &I = *IB++;
if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
switch (II->getIntrinsicID()) {
default:
continue;
case Intrinsic::coro_begin:
II->replaceAllUsesWith(II->getArgOperand(1));
break;
case Intrinsic::coro_free:
II->replaceAllUsesWith(II->getArgOperand(0));
break;
case Intrinsic::coro_alloc:
II->replaceAllUsesWith(ConstantInt::getTrue(Context));
break;
case Intrinsic::coro_id:
II->replaceAllUsesWith(ConstantTokenNone::get(Context));
break;
}
II->eraseFromParent();
Changed = true;
}
}
if (Changed) {
// After replacement were made we can cleanup the function body a little.
simplifyCFG(F);
}
return Changed;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Top Level Driver // Top Level Driver
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -27,12 +81,27 @@ struct CoroCleanup : FunctionPass {
CoroCleanup() : FunctionPass(ID) {} CoroCleanup() : FunctionPass(ID) {}
bool runOnFunction(Function &F) override { return false; } std::unique_ptr<Lowerer> L;
// This pass has work to do only if we find intrinsics we are going to lower
// in the module.
bool doInitialization(Module &M) override {
if (coro::declaresIntrinsics(M, {"llvm.coro.alloc", "llvm.coro.begin",
"llvm.coro.free", "llvm.coro.id"}))
L = llvm::make_unique<Lowerer>(M);
return false;
}
bool runOnFunction(Function &F) override {
if (L)
return L->lowerRemainingCoroIntrinsics(F);
return false;
}
void getAnalysisUsage(AnalysisUsage &AU) const override { void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesAll(); if (!L)
AU.setPreservesAll();
} }
}; };
} }
char CoroCleanup::ID = 0; char CoroCleanup::ID = 0;

View File

@ -44,6 +44,16 @@ void Lowerer::lowerResumeOrDestroy(CallSite CS,
CS.setCallingConv(CallingConv::Fast); CS.setCallingConv(CallingConv::Fast);
} }
// Prior to CoroSplit, calls to coro.begin needs to be marked as NoDuplicate,
// as CoroSplit assumes there is exactly one coro.begin. After CoroSplit,
// NoDuplicate attribute will be removed from coro.begin otherwise, it will
// interfere with inlining.
static void setCannotDuplicate(CoroIdInst *CoroId) {
for (User *U : CoroId->users())
if (auto *CB = dyn_cast<CoroBeginInst>(U))
CB->setCannotDuplicate();
}
bool Lowerer::lowerEarlyIntrinsics(Function &F) { bool Lowerer::lowerEarlyIntrinsics(Function &F) {
bool Changed = false; bool Changed = false;
for (auto IB = inst_begin(F), IE = inst_end(F); IB != IE;) { for (auto IB = inst_begin(F), IE = inst_end(F); IB != IE;) {
@ -52,12 +62,26 @@ bool Lowerer::lowerEarlyIntrinsics(Function &F) {
switch (CS.getIntrinsicID()) { switch (CS.getIntrinsicID()) {
default: default:
continue; continue;
case Intrinsic::coro_suspend:
// Make sure that final suspend point is not duplicated as CoroSplit
// pass expects that there is at most one final suspend point.
if (cast<CoroSuspendInst>(&I)->isFinal())
CS.setCannotDuplicate();
break;
case Intrinsic::coro_end:
// Make sure that fallthrough coro.end is not duplicated as CoroSplit
// pass expects that there is at most one fallthrough coro.end.
if (cast<CoroEndInst>(&I)->isFallthrough())
CS.setCannotDuplicate();
break;
case Intrinsic::coro_id: case Intrinsic::coro_id:
// Mark a function that comes out of the frontend that has a coro.begin // Mark a function that comes out of the frontend that has a coro.id
// with a coroutine attribute. // with a coroutine attribute.
if (auto *CII = cast<CoroIdInst>(&I)) { if (auto *CII = cast<CoroIdInst>(&I)) {
if (CII->getInfo().isPreSplit()) if (CII->getInfo().isPreSplit()) {
F.addFnAttr(CORO_PRESPLIT_ATTR, UNPREPARED_FOR_SPLIT); F.addFnAttr(CORO_PRESPLIT_ATTR, UNPREPARED_FOR_SPLIT);
setCannotDuplicate(CII);
}
} }
break; break;
case Intrinsic::coro_resume: case Intrinsic::coro_resume:
@ -88,8 +112,9 @@ struct CoroEarly : public FunctionPass {
// This pass has work to do only if we find intrinsics we are going to lower // This pass has work to do only if we find intrinsics we are going to lower
// in the module. // in the module.
bool doInitialization(Module &M) override { bool doInitialization(Module &M) override {
if (coro::declaresIntrinsics( if (coro::declaresIntrinsics(M, {"llvm.coro.begin", "llvm.coro.resume",
M, {"llvm.coro.begin", "llvm.coro.resume", "llvm.coro.destroy"})) "llvm.coro.destroy", "llvm.coro.suspend",
"llvm.coro.end"}))
L = llvm::make_unique<Lowerer>(M); L = llvm::make_unique<Lowerer>(M);
return false; return false;
} }

View File

@ -0,0 +1,104 @@
//===- CoroFrame.cpp - Builds and manipulates coroutine frame -------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
// This file contains classes used to discover if for a particular value
// there from sue to definition that crosses a suspend block.
//
// Using the information discovered we form a Coroutine Frame structure to
// contain those values. All uses of those values are replaced with appropriate
// GEP + load from the coroutine frame. At the point of the definition we spill
// the value into the coroutine frame.
//
// TODO: pack values tightly using liveness info.
//===----------------------------------------------------------------------===//
#include "CoroInternal.h"
#include "llvm/IR/IRBuilder.h"
using namespace llvm;
// TODO: Implement in future patches.
struct SpillInfo {};
// Build a struct that will keep state for an active coroutine.
// struct f.frame {
// ResumeFnTy ResumeFnAddr;
// ResumeFnTy DestroyFnAddr;
// int ResumeIndex;
// ... promise (if present) ...
// ... spills ...
// };
static StructType *buildFrameType(Function &F, coro::Shape &Shape,
SpillInfo &Spills) {
LLVMContext &C = F.getContext();
SmallString<32> Name(F.getName());
Name.append(".Frame");
StructType *FrameTy = StructType::create(C, Name);
auto *FramePtrTy = FrameTy->getPointerTo();
auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy,
/*IsVarArgs=*/false);
auto *FnPtrTy = FnTy->getPointerTo();
if (Shape.CoroSuspends.size() > UINT32_MAX)
report_fatal_error("Cannot handle coroutine with this many suspend points");
SmallVector<Type *, 8> Types{FnPtrTy, FnPtrTy, Type::getInt32Ty(C)};
// TODO: Populate from Spills.
FrameTy->setBody(Types);
return FrameTy;
}
// Replace all alloca and SSA values that are accessed across suspend points
// with GetElementPointer from coroutine frame + loads and stores. Create an
// AllocaSpillBB that will become the new entry block for the resume parts of
// the coroutine:
//
// %hdl = coro.begin(...)
// whatever
//
// becomes:
//
// %hdl = coro.begin(...)
// %FramePtr = bitcast i8* hdl to %f.frame*
// br label %AllocaSpillBB
//
// AllocaSpillBB:
// ; geps corresponding to allocas that were moved to coroutine frame
// br label PostSpill
//
// PostSpill:
// whatever
//
//
static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) {
auto *CB = Shape.CoroBegin;
IRBuilder<> Builder(CB->getNextNode());
PointerType *FramePtrTy = Shape.FrameTy->getPointerTo();
auto *FramePtr =
cast<Instruction>(Builder.CreateBitCast(CB, FramePtrTy, "FramePtr"));
// TODO: Insert Spills.
auto *FramePtrBB = FramePtr->getParent();
Shape.AllocaSpillBlock =
FramePtrBB->splitBasicBlock(FramePtr->getNextNode(), "AllocaSpillBB");
Shape.AllocaSpillBlock->splitBasicBlock(&Shape.AllocaSpillBlock->front(),
"PostSpill");
// TODO: Insert geps for alloca moved to coroutine frame.
return FramePtr;
}
void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
SpillInfo Spills;
// TODO: Compute Spills (incoming in later patches).
Shape.FrameTy = buildFrameType(F, Shape, Spills);
Shape.FramePtr = insertSpills(Spills, Shape);
}

View File

@ -172,4 +172,72 @@ public:
} }
}; };
/// This represents the llvm.coro.save instruction.
class LLVM_LIBRARY_VISIBILITY CoroSaveInst : public IntrinsicInst {
public:
// Methods to support type inquiry through isa, cast, and dyn_cast:
static inline bool classof(const IntrinsicInst *I) {
return I->getIntrinsicID() == Intrinsic::coro_save;
}
static inline bool classof(const Value *V) {
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
}
};
/// This represents the llvm.coro.suspend instruction.
class LLVM_LIBRARY_VISIBILITY CoroSuspendInst : public IntrinsicInst {
enum { SaveArg, FinalArg };
public:
CoroSaveInst *getCoroSave() const {
Value *Arg = getArgOperand(SaveArg);
if (auto *SI = dyn_cast<CoroSaveInst>(Arg))
return SI;
assert(isa<ConstantTokenNone>(Arg));
return nullptr;
}
bool isFinal() const {
return cast<Constant>(getArgOperand(FinalArg))->isOneValue();
}
// Methods to support type inquiry through isa, cast, and dyn_cast:
static inline bool classof(const IntrinsicInst *I) {
return I->getIntrinsicID() == Intrinsic::coro_suspend;
}
static inline bool classof(const Value *V) {
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
}
};
/// This represents the llvm.coro.size instruction.
class LLVM_LIBRARY_VISIBILITY CoroSizeInst : public IntrinsicInst {
public:
// Methods to support type inquiry through isa, cast, and dyn_cast:
static inline bool classof(const IntrinsicInst *I) {
return I->getIntrinsicID() == Intrinsic::coro_size;
}
static inline bool classof(const Value *V) {
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
}
};
/// This represents the llvm.coro.end instruction.
class LLVM_LIBRARY_VISIBILITY CoroEndInst : public IntrinsicInst {
enum { FrameArg, UnwindArg };
public:
bool isFallthrough() const { return !isUnwind(); }
bool isUnwind() const {
return cast<Constant>(getArgOperand(UnwindArg))->isOneValue();
}
// Methods to support type inquiry through isa, cast, and dyn_cast:
static inline bool classof(const IntrinsicInst *I) {
return I->getIntrinsicID() == Intrinsic::coro_end;
}
static inline bool classof(const Value *V) {
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
}
};
} // End namespace llvm. } // End namespace llvm.

View File

@ -17,6 +17,8 @@
namespace llvm { namespace llvm {
class CallGraph;
class CallGraphSCC;
class PassRegistry; class PassRegistry;
void initializeCoroEarlyPass(PassRegistry &); void initializeCoroEarlyPass(PassRegistry &);
@ -44,17 +46,49 @@ namespace coro {
bool declaresIntrinsics(Module &M, std::initializer_list<StringRef>); bool declaresIntrinsics(Module &M, std::initializer_list<StringRef>);
void replaceAllCoroAllocs(CoroBeginInst *CB, bool Replacement); void replaceAllCoroAllocs(CoroBeginInst *CB, bool Replacement);
void replaceAllCoroFrees(CoroBeginInst *CB, Value *Replacement); void replaceAllCoroFrees(CoroBeginInst *CB, Value *Replacement);
void updateCallGraph(Function &Caller, ArrayRef<Function *> Funcs,
CallGraph &CG, CallGraphSCC &SCC);
// Keeps data and helper functions for lowering coroutine intrinsics. // Keeps data and helper functions for lowering coroutine intrinsics.
struct LowererBase { struct LowererBase {
Module &TheModule; Module &TheModule;
LLVMContext &Context; LLVMContext &Context;
PointerType *const Int8Ptr;
FunctionType *const ResumeFnType; FunctionType *const ResumeFnType;
ConstantPointerNull *const NullPtr;
LowererBase(Module &M); LowererBase(Module &M);
Value *makeSubFnCall(Value *Arg, int Index, Instruction *InsertPt); Value *makeSubFnCall(Value *Arg, int Index, Instruction *InsertPt);
}; };
// Holds structural Coroutine Intrinsics for a particular function and other
// values used during CoroSplit pass.
struct LLVM_LIBRARY_VISIBILITY Shape {
CoroBeginInst *CoroBegin;
SmallVector<CoroEndInst *, 4> CoroEnds;
SmallVector<CoroSizeInst *, 2> CoroSizes;
SmallVector<CoroSuspendInst *, 4> CoroSuspends;
// Field Indexes for known coroutine frame fields.
enum {
ResumeField = 0,
DestroyField = 1,
IndexField = 2,
};
StructType *FrameTy;
Instruction *FramePtr;
BasicBlock* AllocaSpillBlock;
SwitchInst* ResumeSwitch;
bool HasFinalSuspend;
Shape() = default;
explicit Shape(Function &F) { buildFrom(F); }
void buildFrom(Function &F);
};
void buildCoroutineFrame(Function& F, Shape& Shape);
} // End namespace coro. } // End namespace coro.
} // End namespace llvm } // End namespace llvm

View File

@ -8,15 +8,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// This pass builds the coroutine frame and outlines resume and destroy parts // This pass builds the coroutine frame and outlines resume and destroy parts
// of the coroutine into separate functions. // of the coroutine into separate functions.
//===----------------------------------------------------------------------===// //
#include "CoroInternal.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
using namespace llvm;
#define DEBUG_TYPE "coro-split"
// We present a coroutine to an LLVM as an ordinary function with suspension // We present a coroutine to an LLVM as an ordinary function with suspension
// points marked up with intrinsics. We let the optimizer party on the coroutine // points marked up with intrinsics. We let the optimizer party on the coroutine
// as a single function for as long as possible. Shortly before the coroutine is // as a single function for as long as possible. Shortly before the coroutine is
@ -25,6 +17,309 @@ using namespace llvm;
// add them to the current SCC and restart the IPO pipeline to optimize the // add them to the current SCC and restart the IPO pipeline to optimize the
// coroutine subfunctions we extracted before proceeding to the caller of the // coroutine subfunctions we extracted before proceeding to the caller of the
// coroutine. // coroutine.
//===----------------------------------------------------------------------===//
#include "CoroInternal.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
using namespace llvm;
#define DEBUG_TYPE "coro-split"
// Create an entry block for a resume function with a switch that will jump to
// suspend points.
static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) {
LLVMContext &C = F.getContext();
// resume.entry:
// %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
// i32 2
// % index = load i32, i32* %index.addr
// switch i32 %index, label %unreachable [
// i32 0, label %resume.0
// i32 1, label %resume.1
// ...
// ]
auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);
IRBuilder<> Builder(NewEntry);
auto *FramePtr = Shape.FramePtr;
auto *FrameTy = Shape.FrameTy;
auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
auto *Index = Builder.CreateLoad(GepIndex, "index");
auto *Switch =
Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
Shape.ResumeSwitch = Switch;
uint32_t SuspendIndex = 0;
for (auto S : Shape.CoroSuspends) {
ConstantInt *IndexVal = Builder.getInt32(SuspendIndex);
// Replace CoroSave with a store to Index:
// %index.addr = getelementptr %f.frame... (index field number)
// store i32 0, i32* %index.addr1
auto *Save = S->getCoroSave();
Builder.SetInsertPoint(Save);
auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
Builder.CreateStore(IndexVal, GepIndex);
Save->replaceAllUsesWith(ConstantTokenNone::get(C));
Save->eraseFromParent();
// Split block before and after coro.suspend and add a jump from an entry
// switch:
//
// whateverBB:
// whatever
// %0 = call i8 @llvm.coro.suspend(token none, i1 false)
// switch i8 %0, label %suspend[i8 0, label %resume
// i8 1, label %cleanup]
// becomes:
//
// whateverBB:
// whatever
// br label %resume.0.landing
//
// resume.0: ; <--- jump from the switch in the resume.entry
// %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
// br label %resume.0.landing
//
// resume.0.landing:
// %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
// switch i8 % 1, label %suspend [i8 0, label %resume
// i8 1, label %cleanup]
auto *SuspendBB = S->getParent();
auto *ResumeBB =
SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
auto *LandingBB = ResumeBB->splitBasicBlock(
S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
Switch->addCase(IndexVal, ResumeBB);
cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front());
S->replaceAllUsesWith(PN);
PN->addIncoming(Builder.getInt8(-1), SuspendBB);
PN->addIncoming(S, ResumeBB);
++SuspendIndex;
}
Builder.SetInsertPoint(UnreachBB);
Builder.CreateUnreachable();
return NewEntry;
}
// In Resumers, we replace fallthrough coro.end with ret void and delete the
// rest of the block.
static void replaceFallthroughCoroEnd(IntrinsicInst *End,
ValueToValueMapTy &VMap) {
auto *NewE = cast<IntrinsicInst>(VMap[End]);
ReturnInst::Create(NewE->getContext(), nullptr, NewE);
// Remove the rest of the block, by splitting it into an unreachable block.
auto *BB = NewE->getParent();
BB->splitBasicBlock(NewE);
BB->getTerminator()->eraseFromParent();
}
// Create a resume clone by cloning the body of the original function, setting
// new entry block and replacing coro.suspend an appropriate value to force
// resume or cleanup pass for every suspend point.
static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
BasicBlock *ResumeEntry, int8_t FnIndex) {
Module *M = F.getParent();
auto *FrameTy = Shape.FrameTy;
auto *FnPtrTy = cast<PointerType>(FrameTy->getElementType(0));
auto *FnTy = cast<FunctionType>(FnPtrTy->getElementType());
Function *NewF =
Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage,
F.getName() + Suffix, M);
NewF->addAttribute(1, Attribute::NonNull);
NewF->addAttribute(1, Attribute::NoAlias);
ValueToValueMapTy VMap;
// Replace all args with undefs. The buildCoroutineFrame algorithm already
// rewritten access to the args that occurs after suspend points with loads
// and stores to/from the coroutine frame.
for (Argument &A : F.getArgumentList())
VMap[&A] = UndefValue::get(A.getType());
SmallVector<ReturnInst *, 4> Returns;
CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns);
// If we have debug info, update it. ModuleLevelChanges = true above, does
// the heavy lifting, we just need to repoint subprogram at the same
// DICompileUnit as the original function F.
if (DISubprogram *SP = F.getSubprogram())
NewF->getSubprogram()->replaceUnit(SP->getUnit());
// Remove old returns.
for (ReturnInst *Return : Returns)
changeToUnreachable(Return, /*UseLLVMTrap=*/false);
// Remove old return attributes.
NewF->removeAttributes(
AttributeSet::ReturnIndex,
AttributeSet::get(
NewF->getContext(), AttributeSet::ReturnIndex,
AttributeFuncs::typeIncompatible(NewF->getReturnType())));
// Make AllocaSpillBlock the new entry block.
auto *SwitchBB = cast<BasicBlock>(VMap[ResumeEntry]);
auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
Entry->moveBefore(&NewF->getEntryBlock());
Entry->getTerminator()->eraseFromParent();
BranchInst::Create(SwitchBB, Entry);
Entry->setName("entry" + Suffix);
// Clear all predecessors of the new entry block.
auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
Entry->replaceAllUsesWith(Switch->getDefaultDest());
IRBuilder<> Builder(&NewF->getEntryBlock().front());
// Remap frame pointer.
Argument *NewFramePtr = &NewF->getArgumentList().front();
Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]);
NewFramePtr->takeName(OldFramePtr);
OldFramePtr->replaceAllUsesWith(NewFramePtr);
// Remap vFrame pointer.
auto *NewVFrame = Builder.CreateBitCast(
NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame");
Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
OldVFrame->replaceAllUsesWith(NewVFrame);
// Replace coro suspend with the appropriate resume index.
auto *NewValue = Builder.getInt8(FnIndex);
for (CoroSuspendInst *CS : Shape.CoroSuspends) {
auto *MappedCS = cast<CoroSuspendInst>(VMap[CS]);
MappedCS->replaceAllUsesWith(NewValue);
MappedCS->eraseFromParent();
}
// Remove coro.end intrinsics.
replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap);
// FIXME: coming in upcoming patches:
// replaceUnwindCoroEnds(Shape.CoroEnds, VMap);
// Store the address of this clone in the coroutine frame.
Builder.SetInsertPoint(Shape.FramePtr->getNextNode());
auto *G = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, Shape.FramePtr, 0,
FnIndex, "fn.addr");
Builder.CreateStore(NewF, G);
NewF->setCallingConv(CallingConv::Fast);
return NewF;
}
static void removeCoroEnds(coro::Shape &Shape) {
for (CoroEndInst *CE : Shape.CoroEnds)
CE->eraseFromParent();
}
static void replaceFrameSize(coro::Shape &Shape) {
if (Shape.CoroSizes.empty())
return;
// In the same function all coro.sizes should have the same result type.
auto *SizeIntrin = Shape.CoroSizes.back();
Module *M = SizeIntrin->getModule();
const DataLayout &DL = M->getDataLayout();
auto Size = DL.getTypeAllocSize(Shape.FrameTy);
auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size);
for (CoroSizeInst *CS : Shape.CoroSizes) {
CS->replaceAllUsesWith(SizeConstant);
CS->eraseFromParent();
}
}
// Create a global constant array containing pointers to functions provided and
// set Info parameter of CoroBegin to point at this constant. Example:
//
// @f.resumers = internal constant [2 x void(%f.frame*)*]
// [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
// define void @f() {
// ...
// call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
// i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
//
// Assumes that all the functions have the same signature.
static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin,
std::initializer_list<Function *> Fns) {
SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
assert(!Args.empty());
Function *Part = *Fns.begin();
Module *M = Part->getParent();
auto *ArrTy = ArrayType::get(Part->getType(), Args.size());
auto *ConstVal = ConstantArray::get(ArrTy, Args);
auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
GlobalVariable::PrivateLinkage, ConstVal,
F.getName() + Twine(".resumers"));
// Update coro.begin instruction to refer to this constant.
LLVMContext &C = F.getContext();
auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C));
CoroBegin->getId()->setInfo(BC);
}
static void postSplitCleanup(Function &F) {
removeUnreachableBlocks(F);
llvm::legacy::FunctionPassManager FPM(F.getParent());
FPM.add(createVerifierPass());
FPM.add(createSCCPPass());
FPM.add(createCFGSimplificationPass());
FPM.add(createEarlyCSEPass());
FPM.add(createCFGSimplificationPass());
FPM.doInitialization();
FPM.run(F);
FPM.doFinalization();
}
static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
coro::Shape Shape(F);
if (!Shape.CoroBegin)
return;
buildCoroutineFrame(F, Shape);
auto *ResumeEntry = createResumeEntryBlock(F, Shape);
auto *ResumeClone = createClone(F, ".resume", Shape, ResumeEntry, 0);
auto *DestroyClone = createClone(F, ".destroy", Shape, ResumeEntry, 1);
// We no longer need coro.end in F.
removeCoroEnds(Shape);
postSplitCleanup(F);
postSplitCleanup(*ResumeClone);
postSplitCleanup(*DestroyClone);
replaceFrameSize(Shape);
setCoroInfo(F, Shape.CoroBegin, {ResumeClone, DestroyClone});
coro::updateCallGraph(F, {ResumeClone, DestroyClone}, CG, SCC);
}
// When we see the coroutine the first time, we insert an indirect call to a // When we see the coroutine the first time, we insert an indirect call to a
// devirt trigger function and mark the coroutine that it is now ready for // devirt trigger function and mark the coroutine that it is now ready for
@ -64,7 +359,7 @@ static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) {
LLVMContext &C = M.getContext(); LLVMContext &C = M.getContext();
auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C), auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C),
/*IsVarArgs=*/false); /*IsVarArgs=*/false);
Function *DevirtFn = Function *DevirtFn =
Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage, Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
CORO_DEVIRT_TRIGGER_FN, &M); CORO_DEVIRT_TRIGGER_FN, &M);
@ -125,12 +420,12 @@ struct CoroSplit : public CallGraphSCCPass {
continue; continue;
} }
F->removeFnAttr(CORO_PRESPLIT_ATTR); F->removeFnAttr(CORO_PRESPLIT_ATTR);
splitCoroutine(*F, CG, SCC);
} }
return true; return true;
} }
void getAnalysisUsage(AnalysisUsage &AU) const override { void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesAll();
CallGraphSCCPass::getAnalysisUsage(AU); CallGraphSCCPass::getAnalysisUsage(AU);
} }
}; };

View File

@ -10,11 +10,14 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "CoroInternal.h" #include "CoroInternal.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h" #include "llvm/IR/Verifier.h"
#include "llvm/InitializePasses.h" #include "llvm/InitializePasses.h"
#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm; using namespace llvm;
@ -70,9 +73,10 @@ void llvm::addCoroutinePassesToExtensionPoints(PassManagerBuilder &Builder) {
// Construct the lowerer base class and initialize its members. // Construct the lowerer base class and initialize its members.
coro::LowererBase::LowererBase(Module &M) coro::LowererBase::LowererBase(Module &M)
: TheModule(M), Context(M.getContext()), : TheModule(M), Context(M.getContext()),
ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr(Type::getInt8PtrTy(Context)),
Type::getInt8PtrTy(Context), ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
/*isVarArg=*/false)) {} /*isVarArg=*/false)),
NullPtr(ConstantPointerNull::get(Int8Ptr)) {}
// Creates a sequence of instructions to obtain a resume function address using // Creates a sequence of instructions to obtain a resume function address using
// llvm.coro.subfn.addr. It generates the following sequence: // llvm.coro.subfn.addr. It generates the following sequence:
@ -122,3 +126,165 @@ bool coro::declaresIntrinsics(Module &M,
return false; return false;
} }
// FIXME: This code is stolen from CallGraph::addToCallGraph(Function *F), which
// happens to be private. It is better for this functionality exposed by the
// CallGraph.
static void buildCGN(CallGraph &CG, CallGraphNode *Node) {
Function *F = Node->getFunction();
// Look for calls by this function.
for (Instruction &I : instructions(F))
if (CallSite CS = CallSite(cast<Value>(&I))) {
const Function *Callee = CS.getCalledFunction();
if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID()))
// Indirect calls of intrinsics are not allowed so no need to check.
// We can be more precise here by using TargetArg returned by
// Intrinsic::isLeaf.
Node->addCalledFunction(CS, CG.getCallsExternalNode());
else if (!Callee->isIntrinsic())
Node->addCalledFunction(CS, CG.getOrInsertFunction(Callee));
}
}
// Rebuild CGN after we extracted parts of the code from ParentFunc into
// NewFuncs. Builds CGNs for the NewFuncs and adds them to the current SCC.
void coro::updateCallGraph(Function &ParentFunc, ArrayRef<Function *> NewFuncs,
CallGraph &CG, CallGraphSCC &SCC) {
// Rebuild CGN from scratch for the ParentFunc
auto *ParentNode = CG[&ParentFunc];
ParentNode->removeAllCalledFunctions();
buildCGN(CG, ParentNode);
SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
for (Function *F : NewFuncs) {
CallGraphNode *Callee = CG.getOrInsertFunction(F);
Nodes.push_back(Callee);
buildCGN(CG, Callee);
}
SCC.initialize(Nodes);
}
static void clear(coro::Shape &Shape) {
Shape.CoroBegin = nullptr;
Shape.CoroEnds.clear();
Shape.CoroSizes.clear();
Shape.CoroSuspends.clear();
Shape.FrameTy = nullptr;
Shape.FramePtr = nullptr;
Shape.AllocaSpillBlock = nullptr;
Shape.ResumeSwitch = nullptr;
Shape.HasFinalSuspend = false;
}
static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
CoroSuspendInst *SuspendInst) {
Module *M = SuspendInst->getModule();
auto *Fn = Intrinsic::getDeclaration(M, Intrinsic::coro_save);
auto *SaveInst =
cast<CoroSaveInst>(CallInst::Create(Fn, CoroBegin, "", SuspendInst));
assert(!SuspendInst->getCoroSave());
SuspendInst->setArgOperand(0, SaveInst);
return SaveInst;
}
// Collect "interesting" coroutine intrinsics.
void coro::Shape::buildFrom(Function &F) {
clear(*this);
SmallVector<CoroFrameInst *, 8> CoroFrames;
for (Instruction &I : instructions(F)) {
if (auto II = dyn_cast<IntrinsicInst>(&I)) {
switch (II->getIntrinsicID()) {
default:
continue;
case Intrinsic::coro_size:
CoroSizes.push_back(cast<CoroSizeInst>(II));
break;
case Intrinsic::coro_frame:
CoroFrames.push_back(cast<CoroFrameInst>(II));
break;
case Intrinsic::coro_suspend:
CoroSuspends.push_back(cast<CoroSuspendInst>(II));
// Make sure that the final suspend is the first suspend point in the
// CoroSuspends vector.
if (CoroSuspends.back()->isFinal()) {
HasFinalSuspend = true;
if (CoroSuspends.size() > 1) {
if (CoroSuspends.front()->isFinal())
report_fatal_error(
"Only one suspend point can be marked as final");
std::swap(CoroSuspends.front(), CoroSuspends.back());
}
}
break;
case Intrinsic::coro_begin: {
auto CB = cast<CoroBeginInst>(II);
if (CB->getId()->getInfo().isPreSplit()) {
if (CoroBegin)
report_fatal_error(
"coroutine should have exactly one defining @llvm.coro.begin");
CB->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull);
CB->addAttribute(AttributeSet::ReturnIndex, Attribute::NoAlias);
CB->removeAttribute(AttributeSet::FunctionIndex,
Attribute::NoDuplicate);
CoroBegin = CB;
}
break;
}
case Intrinsic::coro_end:
CoroEnds.push_back(cast<CoroEndInst>(II));
if (CoroEnds.back()->isFallthrough()) {
// Make sure that the fallthrough coro.end is the first element in the
// CoroEnds vector.
if (CoroEnds.size() > 1) {
if (CoroEnds.front()->isFallthrough())
report_fatal_error(
"Only one coro.end can be marked as fallthrough");
std::swap(CoroEnds.front(), CoroEnds.back());
}
}
break;
}
}
}
// If for some reason, we were not able to find coro.begin, bailout.
if (!CoroBegin) {
// Replace coro.frame which are supposed to be lowered to the result of
// coro.begin with undef.
auto *Undef = UndefValue::get(Type::getInt8PtrTy(F.getContext()));
for (CoroFrameInst *CF : CoroFrames) {
CF->replaceAllUsesWith(Undef);
CF->eraseFromParent();
}
// Replace all coro.suspend with undef and remove related coro.saves if
// present.
for (CoroSuspendInst *CS : CoroSuspends) {
CS->replaceAllUsesWith(UndefValue::get(CS->getType()));
CS->eraseFromParent();
if (auto *CoroSave = CS->getCoroSave())
CoroSave->eraseFromParent();
}
// Replace all coro.ends with unreachable instruction.
for (CoroEndInst *CE : CoroEnds)
changeToUnreachable(CE, /*UseLLVMTrap=*/false);
return;
}
// The coro.free intrinsic is always lowered to the result of coro.begin.
for (CoroFrameInst *CF : CoroFrames) {
CF->replaceAllUsesWith(CoroBegin);
CF->eraseFromParent();
}
// Canonicalize coro.suspend by inserting a coro.save if needed.
for (CoroSuspendInst *CS : CoroSuspends)
if (!CS->getCoroSave())
createCoroSave(CoroBegin, CoroSuspends.back());
}

View File

@ -0,0 +1,61 @@
; Tests that coro-split pass splits the coroutine into f, f.resume and f.destroy
; RUN: opt < %s -coro-split -S | FileCheck %s
define i8* @f() "coroutine.presplit"="1" {
entry:
%id = call token @llvm.coro.id(i32 0, i8* null, i8* null)
%size = call i32 @llvm.coro.size.i32()
%alloc = call i8* @malloc(i32 %size)
%hdl = call i8* @llvm.coro.begin(token %id, i8* %alloc)
call void @print(i32 0)
%0 = call i8 @llvm.coro.suspend(token none, i1 false)
switch i8 %0, label %suspend [i8 0, label %resume
i8 1, label %cleanup]
resume:
call void @print(i32 1)
br label %cleanup
cleanup:
%mem = call i8* @llvm.coro.free(i8* %hdl)
call void @free(i8* %mem)
br label %suspend
suspend:
call void @llvm.coro.end(i8* %hdl, i1 0)
ret i8* %hdl
}
; CHECK-LABEL: @f(
; CHECK: call i8* @malloc
; CHECK: call void @print(i32 0)
; CHECK-NOT: call void @print(i32 1)
; CHECK-NOT: call void @free(
; CHECK: ret i8* %hdl
; CHECK-LABEL: @f.resume(
; CHECK-NOT: call i8* @malloc
; CHECK-NOT: call void @print(i32 0)
; CHECK: call void @print(i32 1)
; CHECK-NOT: call void @print(i32 0)
; CHECK: call void @free(
; CHECK: ret void
; CHECK-LABEL: @f.destroy(
; CHECK-NOT: call i8* @malloc
; CHECK-NOT: call void @print(
; CHECK: call void @free(
; CHECK: ret void
declare i8* @llvm.coro.free(i8*)
declare i32 @llvm.coro.size.i32()
declare i8 @llvm.coro.suspend(token, i1)
declare void @llvm.coro.resume(i8*)
declare void @llvm.coro.destroy(i8*)
declare token @llvm.coro.id(i32, i8*, i8*)
declare i8* @llvm.coro.alloc(token)
declare i8* @llvm.coro.begin(token, i8*)
declare void @llvm.coro.end(i8*, i1)
declare noalias i8* @malloc(i32)
declare void @print(i32)
declare void @free(i8*)

View File

@ -0,0 +1,60 @@
; Tests that a coroutine is split, inlined into the caller and devirtualized.
; RUN: opt < %s -S -enable-coroutines -O2 | FileCheck %s
define i8* @f() {
entry:
%id = call token @llvm.coro.id(i32 0, i8* null, i8* null)
%need.dyn.alloc = call i1 @llvm.coro.alloc(token %id)
br i1 %need.dyn.alloc, label %dyn.alloc, label %coro.begin
dyn.alloc:
%size = call i32 @llvm.coro.size.i32()
%alloc = call i8* @malloc(i32 %size)
br label %coro.begin
coro.begin:
%phi = phi i8* [ null, %entry ], [ %alloc, %dyn.alloc ]
%hdl = call i8* @llvm.coro.begin(token %id, i8* %phi)
call void @print(i32 0)
%0 = call i8 @llvm.coro.suspend(token none, i1 false)
switch i8 %0, label %suspend [i8 0, label %resume
i8 1, label %cleanup]
resume:
call void @print(i32 1)
br label %cleanup
cleanup:
%mem = call i8* @llvm.coro.free(i8* %hdl)
call void @free(i8* %mem)
br label %suspend
suspend:
call void @llvm.coro.end(i8* %hdl, i1 0)
ret i8* %hdl
}
define i32 @main() {
entry:
%hdl = call i8* @f()
call void @llvm.coro.resume(i8* %hdl)
ret i32 0
; CHECK-LABEL: @main(
; CHECK: call i8* @malloc
; CHECK-NOT: call void @free
; CHECK: call void @print(i32 0)
; CHECK-NOT: call void @free
; CHECK: call void @print(i32 1)
; CHECK: call void @free
; CHECK: ret i32 0
}
declare i8* @llvm.coro.free(i8*)
declare i32 @llvm.coro.size.i32()
declare i8 @llvm.coro.suspend(token, i1)
declare void @llvm.coro.resume(i8*)
declare void @llvm.coro.destroy(i8*)
declare token @llvm.coro.id(i32, i8*, i8*)
declare i1 @llvm.coro.alloc(token)
declare i8* @llvm.coro.begin(token, i8*)
declare void @llvm.coro.end(i8*, i1)
declare noalias i8* @malloc(i32)
declare void @print(i32)
declare void @free(i8*)

View File

@ -7,12 +7,37 @@
; CHECK: CoroSplit: Processing coroutine 'f' state: 0 ; CHECK: CoroSplit: Processing coroutine 'f' state: 0
; CHECK-NEXT: CoroSplit: Processing coroutine 'f' state: 1 ; CHECK-NEXT: CoroSplit: Processing coroutine 'f' state: 1
declare token @llvm.coro.id(i32, i8*, i8*)
declare i8* @llvm.coro.begin(token, i8*)
; a coroutine start function
define void @f() { define void @f() {
%id = call token @llvm.coro.id(i32 0, i8* null, i8* null) %id = call token @llvm.coro.id(i32 0, i8* null, i8* null)
call i8* @llvm.coro.begin(token %id, i8* null) %size = call i32 @llvm.coro.size.i32()
ret void %alloc = call i8* @malloc(i32 %size)
%hdl = call i8* @llvm.coro.begin(token %id, i8* %alloc)
call void @print(i32 0)
%s1 = call i8 @llvm.coro.suspend(token none, i1 false)
switch i8 %s1, label %suspend [i8 0, label %resume
i8 1, label %cleanup]
resume:
call void @print(i32 1)
br label %cleanup
cleanup:
%mem = call i8* @llvm.coro.free(i8* %hdl)
call void @free(i8* %mem)
br label %suspend
suspend:
call void @llvm.coro.end(i8* %hdl, i1 0)
ret void
} }
declare token @llvm.coro.id(i32, i8*, i8*)
declare i8* @llvm.coro.begin(token, i8*)
declare i8* @llvm.coro.free(i8*)
declare i32 @llvm.coro.size.i32()
declare i8 @llvm.coro.suspend(token, i1)
declare void @llvm.coro.resume(i8*)
declare void @llvm.coro.destroy(i8*)
declare void @llvm.coro.end(i8*, i1)
declare noalias i8* @malloc(i32)
declare void @print(i32)
declare void @free(i8*)