diff --git a/include/llvm/Analysis/CallGraph.h b/include/llvm/Analysis/CallGraph.h index c5687def3eb..6c9e56e2208 100644 --- a/include/llvm/Analysis/CallGraph.h +++ b/include/llvm/Analysis/CallGraph.h @@ -495,6 +495,56 @@ struct GraphTraits : public GraphTraits< } }; +// FIXME: The traits here are not limited to callgraphs and can be moved +// elsewhere including GraphTraits. They are left here because only algorithms +// that operate on Callgraphs currently use them. If other algorithms operating +// on a general graph need edge traversals, these can be moved. +template +struct CallGraphTraits : public GraphTraits { + // Elements to provide: + + // typedef EdgeRef - Type of Edge token in the graph, which should + // be cheap to copy. + // typedef CallEdgeIteratorType - Type used to iterate over children edges in + // graph, dereference to a EdgeRef. + + // static CallEdgeIteratorType call_edge_begin(NodeRef) + // static CallEdgeIteratorType call_edge_end (NodeRef) + // Return iterators that point to the beginning and ending of the call + // edges list for the given callgraph node. + // + // static NodeRef edge_dest(EdgeRef) + // Return the destination node of an edge. + + // If anyone tries to use this class without having an appropriate + // specialization, make an error. If you get this error, it's because you + // need to include the appropriate specialization of GraphTraits<> for your + // graph, or you need to define it for a new graph type. Either that or + // your argument to XXX_begin(...) is unknown or needs to have the proper .h + // file #include'd. + using CallEdgeIteratorType = + typename CallGraphType::UnknownCallGraphTypeError; +}; + +template +iterator_range::CallEdgeIteratorType> +call_edges(const typename CallGraphTraits::NodeRef &G) { + return make_range(CallGraphTraits::call_edge_begin(G), + CallGraphTraits::call_edge_end(G)); +} + +template <> +struct CallGraphTraits + : public GraphTraits { + using EdgeRef = const CallGraphNode::CallRecord &; + using CallEdgeIteratorType = CallGraphNode::const_iterator; + + static CallEdgeIteratorType call_edge_begin(NodeRef N) { return N->begin(); } + static CallEdgeIteratorType call_edge_end(NodeRef N) { return N->end(); } + + static NodeRef edge_dest(EdgeRef E) { return E.second; } +}; + } // end namespace llvm #endif // LLVM_ANALYSIS_CALLGRAPH_H diff --git a/include/llvm/Analysis/SyntheticCountsUtils.h b/include/llvm/Analysis/SyntheticCountsUtils.h index b0848eaee43..7b633c0b53f 100644 --- a/include/llvm/Analysis/SyntheticCountsUtils.h +++ b/include/llvm/Analysis/SyntheticCountsUtils.h @@ -15,7 +15,7 @@ #define LLVM_ANALYSIS_SYNTHETIC_COUNTS_UTILS_H #include "llvm/ADT/STLExtras.h" -#include "llvm/IR/CallSite.h" +#include "llvm/Analysis/CallGraph.h" #include "llvm/Support/ScaledNumber.h" namespace llvm { @@ -23,11 +23,30 @@ namespace llvm { class CallGraph; class Function; -using Scaled64 = ScaledNumber; -void propagateSyntheticCounts( - const CallGraph &CG, function_ref GetCallSiteRelFreq, - function_ref GetCount, - function_ref AddToCount); +/// Class with methods to propagate synthetic entry counts. +/// +/// This class is templated on the type of the call graph and designed to work +/// with the traditional per-module callgraph and the summary callgraphs used in +/// ThinLTO. This contains only static methods and alias templates. +template class SyntheticCountsUtils { +public: + using Scaled64 = ScaledNumber; + using CGT = CallGraphTraits; + using NodeRef = typename CGT::NodeRef; + using EdgeRef = typename CGT::EdgeRef; + using SccTy = std::vector; + + using GetRelBBFreqTy = function_ref(EdgeRef)>; + using GetCountTy = function_ref; + using AddCountTy = function_ref; + + static void propagate(const CallGraphType &CG, GetRelBBFreqTy GetRelBBFreq, + GetCountTy GetCount, AddCountTy AddCount); + +private: + static void propagateFromSCC(const SccTy &SCC, GetRelBBFreqTy GetRelBBFreq, + GetCountTy GetCount, AddCountTy AddCount); +}; } // namespace llvm #endif diff --git a/lib/Analysis/SyntheticCountsUtils.cpp b/lib/Analysis/SyntheticCountsUtils.cpp index 262299c5f3b..22b402b21e6 100644 --- a/lib/Analysis/SyntheticCountsUtils.cpp +++ b/lib/Analysis/SyntheticCountsUtils.cpp @@ -23,100 +23,91 @@ using namespace llvm; -// Given a set of functions in an SCC, propagate entry counts to functions -// called by the SCC. -static void -propagateFromSCC(const SmallPtrSetImpl &SCCFunctions, - function_ref GetCallSiteRelFreq, - function_ref GetCount, - function_ref AddToCount) { +// Given an SCC, propagate entry counts along the edge of the SCC nodes. +template +void SyntheticCountsUtils::propagateFromSCC( + const SccTy &SCC, GetRelBBFreqTy GetRelBBFreq, GetCountTy GetCount, + AddCountTy AddCount) { - SmallVector CallSites; + SmallPtrSet SCCNodes; + SmallVector, 8> SCCEdges, NonSCCEdges; - // Gather all callsites in the SCC. - auto GatherCallSites = [&]() { - for (auto *F : SCCFunctions) { - assert(F && !F->isDeclaration()); - for (auto &I : instructions(F)) { - if (auto CS = CallSite(&I)) { - CallSites.push_back(CS); - } - } + for (auto &Node : SCC) + SCCNodes.insert(Node); + + // Partition the edges coming out of the SCC into those whose destination is + // in the SCC and the rest. + for (const auto &Node : SCCNodes) { + for (auto &E : call_edges(Node)) { + if (SCCNodes.count(CGT::edge_dest(E))) + SCCEdges.emplace_back(Node, E); + else + NonSCCEdges.emplace_back(Node, E); } - }; + } - GatherCallSites(); - - // Partition callsites so that the callsites that call functions in the same - // SCC come first. - auto Mid = partition(CallSites, [&](CallSite &CS) { - auto *Callee = CS.getCalledFunction(); - if (Callee) - return SCCFunctions.count(Callee); - // FIXME: Use the !callees metadata to propagate counts through indirect - // calls. - return 0U; - }); - - // For functions in the same SCC, update the counts in two steps: - // 1. Compute the additional count for each function by propagating the counts - // along all incoming edges to the function that originate from the same SCC - // and summing them up. - // 2. Add the additional counts to the functions in the SCC. + // For nodes in the same SCC, update the counts in two steps: + // 1. Compute the additional count for each node by propagating the counts + // along all incoming edges to the node that originate from within the same + // SCC and summing them up. + // 2. Add the additional counts to the nodes in the SCC. // This ensures that the order of - // traversal of functions within the SCC doesn't change the final result. + // traversal of nodes within the SCC doesn't affect the final result. - DenseMap AdditionalCounts; - for (auto It = CallSites.begin(); It != Mid; It++) { - auto &CS = *It; - auto RelFreq = GetCallSiteRelFreq(CS); - Function *Callee = CS.getCalledFunction(); - Function *Caller = CS.getCaller(); + DenseMap AdditionalCounts; + for (auto &E : SCCEdges) { + auto OptRelFreq = GetRelBBFreq(E.second); + if (!OptRelFreq) + continue; + Scaled64 RelFreq = OptRelFreq.getValue(); + auto Caller = E.first; + auto Callee = CGT::edge_dest(E.second); RelFreq *= Scaled64(GetCount(Caller), 0); uint64_t AdditionalCount = RelFreq.toInt(); AdditionalCounts[Callee] += AdditionalCount; } - // Update the counts for the functions in the SCC. + // Update the counts for the nodes in the SCC. for (auto &Entry : AdditionalCounts) - AddToCount(Entry.first, Entry.second); + AddCount(Entry.first, Entry.second); - // Now update the counts for functions not in SCC. - for (auto It = Mid; It != CallSites.end(); It++) { - auto &CS = *It; - auto Weight = GetCallSiteRelFreq(CS); - Function *Callee = CS.getCalledFunction(); - Function *Caller = CS.getCaller(); - Weight *= Scaled64(GetCount(Caller), 0); - AddToCount(Callee, Weight.toInt()); + // Now update the counts for nodes outside the SCC. + for (auto &E : NonSCCEdges) { + auto OptRelFreq = GetRelBBFreq(E.second); + if (!OptRelFreq) + continue; + Scaled64 RelFreq = OptRelFreq.getValue(); + auto Caller = E.first; + auto Callee = CGT::edge_dest(E.second); + RelFreq *= Scaled64(GetCount(Caller), 0); + AddCount(Callee, RelFreq.toInt()); } } -/// Propgate synthetic entry counts on a callgraph. +/// Propgate synthetic entry counts on a callgraph \p CG. /// /// This performs a reverse post-order traversal of the callgraph SCC. For each -/// SCC, it first propagates the entry counts to the functions within the SCC +/// SCC, it first propagates the entry counts to the nodes within the SCC /// through call edges and updates them in one shot. Then the entry counts are -/// propagated to functions outside the SCC. -void llvm::propagateSyntheticCounts( - const CallGraph &CG, function_ref GetCallSiteRelFreq, - function_ref GetCount, - function_ref AddToCount) { +/// propagated to nodes outside the SCC. This requires \p CallGraphTraits +/// to have a specialization for \p CallGraphType. - SmallVector, 16> SCCs; - for (auto I = scc_begin(&CG); !I.isAtEnd(); ++I) { - auto SCC = *I; +template +void SyntheticCountsUtils::propagate(const CallGraphType &CG, + GetRelBBFreqTy GetRelBBFreq, + GetCountTy GetCount, + AddCountTy AddCount) { + std::vector SCCs; - SmallPtrSet SCCFunctions; - for (auto *Node : SCC) { - Function *F = Node->getFunction(); - if (F && !F->isDeclaration()) { - SCCFunctions.insert(F); - } - } - SCCs.push_back(SCCFunctions); - } + // Collect all the SCCs. + for (auto I = scc_begin(CG); !I.isAtEnd(); ++I) + SCCs.push_back(*I); - for (auto &SCCFunctions : reverse(SCCs)) - propagateFromSCC(SCCFunctions, GetCallSiteRelFreq, GetCount, AddToCount); + // The callgraph-scc needs to be visited in top-down order for propagation. + // The scc iterator returns the scc in bottom-up order, so reverse the SCCs + // and call propagateFromSCC. + for (auto &SCC : reverse(SCCs)) + propagateFromSCC(SCC, GetRelBBFreq, GetCount, AddCount); } + +template class llvm::SyntheticCountsUtils; diff --git a/lib/Transforms/IPO/SyntheticCountsPropagation.cpp b/lib/Transforms/IPO/SyntheticCountsPropagation.cpp index f599adfe779..3c5ad37bced 100644 --- a/lib/Transforms/IPO/SyntheticCountsPropagation.cpp +++ b/lib/Transforms/IPO/SyntheticCountsPropagation.cpp @@ -102,23 +102,34 @@ PreservedAnalyses SyntheticCountsPropagation::run(Module &M, // Set initial entry counts. initializeCounts(M, [&](Function *F, uint64_t Count) { Counts[F] = Count; }); - // Compute the relative block frequency for a callsite. Use scaled numbers + // Compute the relative block frequency for a call edge. Use scaled numbers // and not integers since the relative block frequency could be less than 1. - auto GetCallSiteRelFreq = [&](CallSite CS) { + auto GetCallSiteRelFreq = [&](const CallGraphNode::CallRecord &Edge) { + Optional Res = None; + if (!Edge.first) + return Res; + assert(isa(Edge.first)); + CallSite CS(cast(Edge.first)); Function *Caller = CS.getCaller(); auto &BFI = FAM.getResult(*Caller); BasicBlock *CSBB = CS.getInstruction()->getParent(); Scaled64 EntryFreq(BFI.getEntryFreq(), 0); Scaled64 BBFreq(BFI.getBlockFreq(CSBB).getFrequency(), 0); BBFreq /= EntryFreq; - return BBFreq; + return Optional(BBFreq); }; CallGraph CG(M); // Propgate the entry counts on the callgraph. - propagateSyntheticCounts( - CG, GetCallSiteRelFreq, [&](Function *F) { return Counts[F]; }, - [&](Function *F, uint64_t New) { Counts[F] += New; }); + SyntheticCountsUtils::propagate( + &CG, GetCallSiteRelFreq, + [&](const CallGraphNode *N) { return Counts[N->getFunction()]; }, + [&](const CallGraphNode *N, uint64_t New) { + auto F = N->getFunction(); + if (!F || F->isDeclaration()) + return; + Counts[F] += New; + }); // Set the counts as metadata. for (auto Entry : Counts)