mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-11-22 02:33:06 +01:00
[CallGraph] Preserve call records vector when replacing call edge
Summary: Try not to resize vector of call records in a call graph node when replacing call edge. That would prevent invalidation of iterators stored in the CG SCC pass manager's scc_iterator. Reviewers: jdoerfert Reviewed By: jdoerfert Subscribers: hiraditya, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D84295
This commit is contained in:
parent
bb96b301c9
commit
02505ee23a
@ -281,13 +281,37 @@ void CallGraphNode::replaceCallEdge(CallBase &Call, CallBase &NewCall,
|
|||||||
I->second = NewNode;
|
I->second = NewNode;
|
||||||
NewNode->AddRef();
|
NewNode->AddRef();
|
||||||
|
|
||||||
// Refresh callback references.
|
// Refresh callback references. Do not resize CalledFunctions if the
|
||||||
forEachCallbackFunction(Call, [=](Function *CB) {
|
// number of callbacks is the same for new and old call sites.
|
||||||
removeOneAbstractEdgeTo(CG->getOrInsertFunction(CB));
|
SmallVector<CallGraphNode *, 4u> OldCBs;
|
||||||
|
SmallVector<CallGraphNode *, 4u> NewCBs;
|
||||||
|
forEachCallbackFunction(Call, [this, &OldCBs](Function *CB) {
|
||||||
|
OldCBs.push_back(CG->getOrInsertFunction(CB));
|
||||||
});
|
});
|
||||||
forEachCallbackFunction(NewCall, [=](Function *CB) {
|
forEachCallbackFunction(NewCall, [this, &NewCBs](Function *CB) {
|
||||||
addCalledFunction(nullptr, CG->getOrInsertFunction(CB));
|
NewCBs.push_back(CG->getOrInsertFunction(CB));
|
||||||
});
|
});
|
||||||
|
if (OldCBs.size() == NewCBs.size()) {
|
||||||
|
for (unsigned N = 0; N < OldCBs.size(); ++N) {
|
||||||
|
CallGraphNode *OldNode = OldCBs[N];
|
||||||
|
CallGraphNode *NewNode = NewCBs[N];
|
||||||
|
for (auto J = CalledFunctions.begin();; ++J) {
|
||||||
|
assert(J != CalledFunctions.end() &&
|
||||||
|
"Cannot find callsite to update!");
|
||||||
|
if (!J->first && J->second == OldNode) {
|
||||||
|
J->second = NewNode;
|
||||||
|
OldNode->DropRef();
|
||||||
|
NewNode->AddRef();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (auto *CGN : OldCBs)
|
||||||
|
removeOneAbstractEdgeTo(CGN);
|
||||||
|
for (auto *CGN : NewCBs)
|
||||||
|
addCalledFunction(nullptr, CGN);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -16,6 +16,8 @@
|
|||||||
#include "llvm/Analysis/CallGraphSCCPass.h"
|
#include "llvm/Analysis/CallGraphSCCPass.h"
|
||||||
#include "llvm/Analysis/LoopInfo.h"
|
#include "llvm/Analysis/LoopInfo.h"
|
||||||
#include "llvm/Analysis/LoopPass.h"
|
#include "llvm/Analysis/LoopPass.h"
|
||||||
|
#include "llvm/AsmParser/Parser.h"
|
||||||
|
#include "llvm/IR/AbstractCallSite.h"
|
||||||
#include "llvm/IR/BasicBlock.h"
|
#include "llvm/IR/BasicBlock.h"
|
||||||
#include "llvm/IR/CallingConv.h"
|
#include "llvm/IR/CallingConv.h"
|
||||||
#include "llvm/IR/DataLayout.h"
|
#include "llvm/IR/DataLayout.h"
|
||||||
@ -28,6 +30,7 @@
|
|||||||
#include "llvm/IR/OptBisect.h"
|
#include "llvm/IR/OptBisect.h"
|
||||||
#include "llvm/InitializePasses.h"
|
#include "llvm/InitializePasses.h"
|
||||||
#include "llvm/Support/MathExtras.h"
|
#include "llvm/Support/MathExtras.h"
|
||||||
|
#include "llvm/Support/SourceMgr.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
|
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
@ -694,6 +697,89 @@ namespace llvm {
|
|||||||
ASSERT_EQ(P->NumExtCalledBefore, /* test1, 2a, 2b, 3, 4 */ 5U);
|
ASSERT_EQ(P->NumExtCalledBefore, /* test1, 2a, 2b, 3, 4 */ 5U);
|
||||||
ASSERT_EQ(P->NumExtCalledAfter, /* test1, 3repl, 4 */ 3U);
|
ASSERT_EQ(P->NumExtCalledAfter, /* test1, 3repl, 4 */ 3U);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test for call graph SCC pass that replaces all callback call instructions
|
||||||
|
// with clones and updates CallGraph by calling CallGraph::replaceCallEdge()
|
||||||
|
// method. Test is expected to complete successfully after running pass on
|
||||||
|
// all SCCs in the test module.
|
||||||
|
struct CallbackCallsModifierPass : public CGPass {
|
||||||
|
bool runOnSCC(CallGraphSCC &SCC) override {
|
||||||
|
CGPass::run();
|
||||||
|
|
||||||
|
CallGraph &CG = const_cast<CallGraph &>(SCC.getCallGraph());
|
||||||
|
|
||||||
|
bool Changed = false;
|
||||||
|
for (CallGraphNode *CGN : SCC) {
|
||||||
|
Function *F = CGN->getFunction();
|
||||||
|
if (!F || F->isDeclaration())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
SmallVector<CallBase *, 4u> Calls;
|
||||||
|
for (Use &U : F->uses()) {
|
||||||
|
AbstractCallSite ACS(&U);
|
||||||
|
if (!ACS || !ACS.isCallbackCall() || !ACS.isCallee(&U))
|
||||||
|
continue;
|
||||||
|
Calls.push_back(cast<CallBase>(ACS.getInstruction()));
|
||||||
|
}
|
||||||
|
if (Calls.empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
for (CallBase *OldCB : Calls) {
|
||||||
|
CallGraphNode *CallerCGN = CG[OldCB->getParent()->getParent()];
|
||||||
|
assert(any_of(*CallerCGN,
|
||||||
|
[CGN](const CallGraphNode::CallRecord &CallRecord) {
|
||||||
|
return CallRecord.second == CGN;
|
||||||
|
}) &&
|
||||||
|
"function is not a callee");
|
||||||
|
|
||||||
|
CallBase *NewCB = cast<CallBase>(OldCB->clone());
|
||||||
|
|
||||||
|
NewCB->insertBefore(OldCB);
|
||||||
|
NewCB->takeName(OldCB);
|
||||||
|
|
||||||
|
CallerCGN->replaceCallEdge(*OldCB, *NewCB, CG[F]);
|
||||||
|
|
||||||
|
OldCB->replaceAllUsesWith(NewCB);
|
||||||
|
OldCB->eraseFromParent();
|
||||||
|
}
|
||||||
|
Changed = true;
|
||||||
|
}
|
||||||
|
return Changed;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(PassManager, CallbackCallsModifier0) {
|
||||||
|
LLVMContext Context;
|
||||||
|
|
||||||
|
const char *IR = "define void @foo() {\n"
|
||||||
|
" call void @broker(void (i8*)* @callback0, i8* null)\n"
|
||||||
|
" call void @broker(void (i8*)* @callback1, i8* null)\n"
|
||||||
|
" ret void\n"
|
||||||
|
"}\n"
|
||||||
|
"\n"
|
||||||
|
"declare !callback !0 void @broker(void (i8*)*, i8*)\n"
|
||||||
|
"\n"
|
||||||
|
"define internal void @callback0(i8* %arg) {\n"
|
||||||
|
" ret void\n"
|
||||||
|
"}\n"
|
||||||
|
"\n"
|
||||||
|
"define internal void @callback1(i8* %arg) {\n"
|
||||||
|
" ret void\n"
|
||||||
|
"}\n"
|
||||||
|
"\n"
|
||||||
|
"!0 = !{!1}\n"
|
||||||
|
"!1 = !{i64 0, i64 1, i1 false}";
|
||||||
|
|
||||||
|
SMDiagnostic Err;
|
||||||
|
std::unique_ptr<Module> M = parseAssemblyString(IR, Err, Context);
|
||||||
|
if (!M)
|
||||||
|
Err.print("LegacyPassManagerTest", errs());
|
||||||
|
|
||||||
|
CallbackCallsModifierPass *P = new CallbackCallsModifierPass();
|
||||||
|
legacy::PassManager Passes;
|
||||||
|
Passes.add(P);
|
||||||
|
Passes.run(*M);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user