1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-21 18:22:53 +01:00

[CallGraph] Preserve call records vector when replacing call edge

Summary:
Try not to resize vector of call records in a call graph node when
replacing call edge. That would prevent invalidation of iterators
stored in the CG SCC pass manager's scc_iterator.

Reviewers: jdoerfert

Reviewed By: jdoerfert

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D84295
This commit is contained in:
Sergey Dmitriev 2020-07-27 06:02:06 -07:00
parent bb96b301c9
commit 02505ee23a
2 changed files with 115 additions and 5 deletions

View File

@ -281,13 +281,37 @@ void CallGraphNode::replaceCallEdge(CallBase &Call, CallBase &NewCall,
I->second = NewNode;
NewNode->AddRef();
// Refresh callback references.
forEachCallbackFunction(Call, [=](Function *CB) {
removeOneAbstractEdgeTo(CG->getOrInsertFunction(CB));
// Refresh callback references. Do not resize CalledFunctions if the
// number of callbacks is the same for new and old call sites.
SmallVector<CallGraphNode *, 4u> OldCBs;
SmallVector<CallGraphNode *, 4u> NewCBs;
forEachCallbackFunction(Call, [this, &OldCBs](Function *CB) {
OldCBs.push_back(CG->getOrInsertFunction(CB));
});
forEachCallbackFunction(NewCall, [=](Function *CB) {
addCalledFunction(nullptr, CG->getOrInsertFunction(CB));
forEachCallbackFunction(NewCall, [this, &NewCBs](Function *CB) {
NewCBs.push_back(CG->getOrInsertFunction(CB));
});
if (OldCBs.size() == NewCBs.size()) {
for (unsigned N = 0; N < OldCBs.size(); ++N) {
CallGraphNode *OldNode = OldCBs[N];
CallGraphNode *NewNode = NewCBs[N];
for (auto J = CalledFunctions.begin();; ++J) {
assert(J != CalledFunctions.end() &&
"Cannot find callsite to update!");
if (!J->first && J->second == OldNode) {
J->second = NewNode;
OldNode->DropRef();
NewNode->AddRef();
break;
}
}
}
} else {
for (auto *CGN : OldCBs)
removeOneAbstractEdgeTo(CGN);
for (auto *CGN : NewCBs)
addCalledFunction(nullptr, CGN);
}
return;
}
}

View File

@ -16,6 +16,8 @@
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/AbstractCallSite.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/DataLayout.h"
@ -28,6 +30,7 @@
#include "llvm/IR/OptBisect.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
#include "gtest/gtest.h"
@ -694,6 +697,89 @@ namespace llvm {
ASSERT_EQ(P->NumExtCalledBefore, /* test1, 2a, 2b, 3, 4 */ 5U);
ASSERT_EQ(P->NumExtCalledAfter, /* test1, 3repl, 4 */ 3U);
}
// Test for call graph SCC pass that replaces all callback call instructions
// with clones and updates CallGraph by calling CallGraph::replaceCallEdge()
// method. Test is expected to complete successfully after running pass on
// all SCCs in the test module.
struct CallbackCallsModifierPass : public CGPass {
bool runOnSCC(CallGraphSCC &SCC) override {
CGPass::run();
CallGraph &CG = const_cast<CallGraph &>(SCC.getCallGraph());
bool Changed = false;
for (CallGraphNode *CGN : SCC) {
Function *F = CGN->getFunction();
if (!F || F->isDeclaration())
continue;
SmallVector<CallBase *, 4u> Calls;
for (Use &U : F->uses()) {
AbstractCallSite ACS(&U);
if (!ACS || !ACS.isCallbackCall() || !ACS.isCallee(&U))
continue;
Calls.push_back(cast<CallBase>(ACS.getInstruction()));
}
if (Calls.empty())
continue;
for (CallBase *OldCB : Calls) {
CallGraphNode *CallerCGN = CG[OldCB->getParent()->getParent()];
assert(any_of(*CallerCGN,
[CGN](const CallGraphNode::CallRecord &CallRecord) {
return CallRecord.second == CGN;
}) &&
"function is not a callee");
CallBase *NewCB = cast<CallBase>(OldCB->clone());
NewCB->insertBefore(OldCB);
NewCB->takeName(OldCB);
CallerCGN->replaceCallEdge(*OldCB, *NewCB, CG[F]);
OldCB->replaceAllUsesWith(NewCB);
OldCB->eraseFromParent();
}
Changed = true;
}
return Changed;
}
};
TEST(PassManager, CallbackCallsModifier0) {
LLVMContext Context;
const char *IR = "define void @foo() {\n"
" call void @broker(void (i8*)* @callback0, i8* null)\n"
" call void @broker(void (i8*)* @callback1, i8* null)\n"
" ret void\n"
"}\n"
"\n"
"declare !callback !0 void @broker(void (i8*)*, i8*)\n"
"\n"
"define internal void @callback0(i8* %arg) {\n"
" ret void\n"
"}\n"
"\n"
"define internal void @callback1(i8* %arg) {\n"
" ret void\n"
"}\n"
"\n"
"!0 = !{!1}\n"
"!1 = !{i64 0, i64 1, i1 false}";
SMDiagnostic Err;
std::unique_ptr<Module> M = parseAssemblyString(IR, Err, Context);
if (!M)
Err.print("LegacyPassManagerTest", errs());
CallbackCallsModifierPass *P = new CallbackCallsModifierPass();
legacy::PassManager Passes;
Passes.add(P);
Passes.run(*M);
}
}
}