diff --git a/include/llvm/ADT/DirectedGraph.h b/include/llvm/ADT/DirectedGraph.h new file mode 100644 index 00000000000..f6a358d99cd --- /dev/null +++ b/include/llvm/ADT/DirectedGraph.h @@ -0,0 +1,270 @@ +//===- llvm/ADT/DirectedGraph.h - Directed Graph ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the interface and a base class implementation for a +// directed graph. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ADT_DIRECTEDGRAPH_H +#define LLVM_ADT_DIRECTEDGRAPH_H + +#include "llvm/ADT/GraphTraits.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +namespace llvm { + +/// Represent an edge in the directed graph. +/// The edge contains the target node it connects to. +template class DGEdge { +public: + DGEdge() = delete; + /// Create an edge pointing to the given node \p N. + explicit DGEdge(NodeType &N) : TargetNode(N) {} + explicit DGEdge(const DGEdge &E) + : TargetNode(E.TargetNode) {} + DGEdge &operator=(const DGEdge &E) { + TargetNode = E.TargetNode; + return *this; + } + + /// Static polymorphism: delegate implementation (via isEqualTo) to the + /// derived class. + bool operator==(const EdgeType &E) const { return getDerived().isEqualTo(E); } + bool operator!=(const EdgeType &E) const { return !operator==(E); } + + /// Retrieve the target node this edge connects to. + const NodeType &getTargetNode() const { return TargetNode; } + NodeType &getTargetNode() { + return const_cast( + static_cast &>(*this).getTargetNode()); + } + +protected: + // As the default implementation use address comparison for equality. + bool isEqualTo(const EdgeType &E) const { return this == &E; } + + // Cast the 'this' pointer to the derived type and return a reference. + EdgeType &getDerived() { return *static_cast(this); } + const EdgeType &getDerived() const { + return *static_cast(this); + } + + // The target node this edge connects to. + NodeType &TargetNode; +}; + +/// Represent a node in the directed graph. +/// The node has a (possibly empty) list of outgoing edges. +template class DGNode { +public: + using EdgeListTy = SetVector; + using iterator = typename EdgeListTy::iterator; + using const_iterator = typename EdgeListTy::const_iterator; + + /// Create a node with a single outgoing edge \p E. + explicit DGNode(EdgeType &E) : Edges() { Edges.insert(&E); } + DGNode() = default; + + explicit DGNode(const DGNode &N) : Edges(N.Edges) {} + DGNode(DGNode &&N) : Edges(std::move(N.Edges)) {} + + DGNode &operator=(const DGNode &N) { + Edges = N.Edges; + return *this; + } + DGNode &operator=(const DGNode &&N) { + Edges = std::move(N.Edges); + return *this; + } + + /// Static polymorphism: delegate implementation (via isEqualTo) to the + /// derived class. + bool operator==(const NodeType &N) const { return getDerived().isEqualTo(N); } + bool operator!=(const NodeType &N) const { return !operator==(N); } + + const_iterator begin() const { return Edges.begin(); } + const_iterator end() const { return Edges.end(); } + iterator begin() { return Edges.begin(); } + iterator end() { return Edges.end(); } + const EdgeType &front() const { return *Edges.front(); } + EdgeType &front() { return *Edges.front(); } + const EdgeType &back() const { return *Edges.back(); } + EdgeType &back() { return *Edges.back(); } + + /// Collect in \p EL, all the edges from this node to \p N. + /// Return true if at least one edge was found, and false otherwise. + /// Note that this implementation allows more than one edge to connect + /// a given pair of nodes. + bool findEdgesTo(const NodeType &N, SmallVectorImpl &EL) const { + assert(EL.empty() && "Expected the list of edges to be empty."); + for (auto *E : Edges) + if (E->getTargetNode() == N) + EL.push_back(E); + return !EL.empty(); + } + + /// Add the given edge \p E to this node, if it doesn't exist already. Returns + /// true if the edge is added and false otherwise. + bool addEdge(EdgeType &E) { return Edges.insert(&E); } + + /// Remove the given edge \p E from this node, if it exists. + void removeEdge(EdgeType &E) { Edges.remove(&E); } + + /// Test whether there is an edge that goes from this node to \p N. + bool hasEdgeTo(const NodeType &N) const { + return (findEdgeTo(N) != Edges.end()); + } + + /// Retrieve the outgoing edges for the node. + const EdgeListTy &getEdges() const { return Edges; } + EdgeListTy &getEdges() { + return const_cast( + static_cast &>(*this).Edges); + } + + /// Clear the outgoing edges. + void clear() { Edges.clear(); } + +protected: + // As the default implementation use address comparison for equality. + bool isEqualTo(const NodeType &N) const { return this == &N; } + + // Cast the 'this' pointer to the derived type and return a reference. + NodeType &getDerived() { return *static_cast(this); } + const NodeType &getDerived() const { + return *static_cast(this); + } + + /// Find an edge to \p N. If more than one edge exists, this will return + /// the first one in the list of edges. + const_iterator findEdgeTo(const NodeType &N) const { + return llvm::find_if( + Edges, [&N](const EdgeType *E) { return E->getTargetNode() == N; }); + } + + // The list of outgoing edges. + EdgeListTy Edges; +}; + +/// Directed graph +/// +/// The graph is represented by a table of nodes. +/// Each node contains a (possibly empty) list of outgoing edges. +/// Each edge contains the target node it connects to. +template class DirectedGraph { +protected: + using NodeListTy = SmallVector; + using EdgeListTy = SmallVector; +public: + using iterator = typename NodeListTy::iterator; + using const_iterator = typename NodeListTy::const_iterator; + using DGraphType = DirectedGraph; + + DirectedGraph() = default; + explicit DirectedGraph(NodeType &N) : Nodes() { addNode(N); } + DirectedGraph(const DGraphType &G) : Nodes(G.Nodes) {} + DirectedGraph(DGraphType &&RHS) : Nodes(std::move(RHS.Nodes)) {} + DGraphType &operator=(const DGraphType &G) { + Nodes = G.Nodes; + return *this; + } + DGraphType &operator=(const DGraphType &&G) { + Nodes = std::move(G.Nodes); + return *this; + } + + const_iterator begin() const { return Nodes.begin(); } + const_iterator end() const { return Nodes.end(); } + iterator begin() { return Nodes.begin(); } + iterator end() { return Nodes.end(); } + const NodeType &front() const { return *Nodes.front(); } + NodeType &front() { return *Nodes.front(); } + const NodeType &back() const { return *Nodes.back(); } + NodeType &back() { return *Nodes.back(); } + + size_t size() const { return Nodes.size(); } + + /// Find the given node \p N in the table. + const_iterator findNode(const NodeType &N) const { + return llvm::find_if(Nodes, + [&N](const NodeType *Node) { return *Node == N; }); + } + iterator findNode(const NodeType &N) { + return const_cast( + static_cast(*this).findNode(N)); + } + + /// Add the given node \p N to the graph if it is not already present. + bool addNode(NodeType &N) { + if (findNode(N) != Nodes.end()) + return false; + Nodes.push_back(&N); + return true; + } + + /// Collect in \p EL all edges that are coming into node \p N. Return true + /// if at least one edge was found, and false otherwise. + bool findIncomingEdgesToNode(const NodeType &N, SmallVectorImpl &EL) const { + assert(EL.empty() && "Expected the list of edges to be empty."); + EdgeListTy TempList; + for (auto *Node : Nodes) { + if (*Node == N) + continue; + Node->findEdgesTo(N, TempList); + EL.insert(EL.end(), TempList.begin(), TempList.end()); + TempList.clear(); + } + return !EL.empty(); + } + + /// Remove the given node \p N from the graph. If the node has incoming or + /// outgoing edges, they are also removed. Return true if the node was found + /// and then removed, and false if the node was not found in the graph to + /// begin with. + bool removeNode(NodeType &N) { + iterator IT = findNode(N); + if (IT == Nodes.end()) + return false; + // Remove incoming edges. + EdgeListTy EL; + for (auto *Node : Nodes) { + if (*Node == N) + continue; + Node->findEdgesTo(N, EL); + for (auto *E : EL) + Node->removeEdge(*E); + EL.clear(); + } + N.clear(); + Nodes.erase(IT); + return true; + } + + /// Assuming nodes \p Src and \p Dst are already in the graph, connect node \p + /// Src to node \p Dst using the provided edge \p E. Return true if \p Src is + /// not already connected to \p Dst via \p E, and false otherwise. + bool connect(NodeType &Src, NodeType &Dst, EdgeType &E) { + assert(findNode(Src) != Nodes.end() && "Src node should be present."); + assert(findNode(Dst) != Nodes.end() && "Dst node should be present."); + assert((E.getTargetNode() == Dst) && + "Target of the given edge does not match Dst."); + return Src.addEdge(E); + } + +protected: + // The list of nodes in the graph. + NodeListTy Nodes; +}; + +} // namespace llvm + +#endif // LLVM_ADT_DIRECTEDGRAPH_H diff --git a/unittests/ADT/CMakeLists.txt b/unittests/ADT/CMakeLists.txt index 676ce181871..3a7be5b5522 100644 --- a/unittests/ADT/CMakeLists.txt +++ b/unittests/ADT/CMakeLists.txt @@ -17,6 +17,7 @@ add_llvm_unittest(ADTTests DenseMapTest.cpp DenseSetTest.cpp DepthFirstIteratorTest.cpp + DirectedGraphTest.cpp EquivalenceClassesTest.cpp FallibleIteratorTest.cpp FoldingSet.cpp diff --git a/unittests/ADT/DirectedGraphTest.cpp b/unittests/ADT/DirectedGraphTest.cpp new file mode 100644 index 00000000000..ae1f6b01ef2 --- /dev/null +++ b/unittests/ADT/DirectedGraphTest.cpp @@ -0,0 +1,295 @@ +//===- llvm/unittest/ADT/DirectedGraphTest.cpp ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines concrete derivations of the directed-graph base classes +// for testing purposes. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/DirectedGraph.h" +#include "llvm/ADT/GraphTraits.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "gtest/gtest.h" + +namespace llvm { + +//===--------------------------------------------------------------------===// +// Derived nodes, edges and graph types based on DirectedGraph. +//===--------------------------------------------------------------------===// + +class DGTestNode; +class DGTestEdge; +using DGTestNodeBase = DGNode; +using DGTestEdgeBase = DGEdge; +using DGTestBase = DirectedGraph; + +class DGTestNode : public DGTestNodeBase { +public: + DGTestNode() = default; +}; + +class DGTestEdge : public DGTestEdgeBase { +public: + DGTestEdge() = delete; + DGTestEdge(DGTestNode &N) : DGTestEdgeBase(N) {} +}; + +class DGTestGraph : public DGTestBase { +public: + DGTestGraph() = default; + ~DGTestGraph(){}; +}; + +using EdgeListTy = SmallVector; + +//===--------------------------------------------------------------------===// +// GraphTraits specializations for the DGTest +//===--------------------------------------------------------------------===// + +template <> struct GraphTraits { + using NodeRef = DGTestNode *; + + static DGTestNode *DGTestGetTargetNode(DGEdge *P) { + return &P->getTargetNode(); + } + + // Provide a mapped iterator so that the GraphTrait-based implementations can + // find the target nodes without having to explicitly go through the edges. + using ChildIteratorType = + mapped_iterator; + using ChildEdgeIteratorType = DGTestNode::iterator; + + static NodeRef getEntryNode(NodeRef N) { return N; } + static ChildIteratorType child_begin(NodeRef N) { + return ChildIteratorType(N->begin(), &DGTestGetTargetNode); + } + static ChildIteratorType child_end(NodeRef N) { + return ChildIteratorType(N->end(), &DGTestGetTargetNode); + } + + static ChildEdgeIteratorType child_edge_begin(NodeRef N) { + return N->begin(); + } + static ChildEdgeIteratorType child_edge_end(NodeRef N) { return N->end(); } +}; + +template <> +struct GraphTraits : public GraphTraits { + using nodes_iterator = DGTestGraph::iterator; + static NodeRef getEntryNode(DGTestGraph *DG) { return *DG->begin(); } + static nodes_iterator nodes_begin(DGTestGraph *DG) { return DG->begin(); } + static nodes_iterator nodes_end(DGTestGraph *DG) { return DG->end(); } +}; + +//===--------------------------------------------------------------------===// +// Test various modification and query functions. +//===--------------------------------------------------------------------===// + +TEST(DirectedGraphTest, AddAndConnectNodes) { + DGTestGraph DG; + DGTestNode N1, N2, N3; + DGTestEdge E1(N1), E2(N2), E3(N3); + + // Check that new nodes can be added successfully. + EXPECT_TRUE(DG.addNode(N1)); + EXPECT_TRUE(DG.addNode(N2)); + EXPECT_TRUE(DG.addNode(N3)); + + // Check that duplicate nodes are not added to the graph. + EXPECT_FALSE(DG.addNode(N1)); + + // Check that nodes can be connected using valid edges with no errors. + EXPECT_TRUE(DG.connect(N1, N2, E2)); + EXPECT_TRUE(DG.connect(N2, N3, E3)); + EXPECT_TRUE(DG.connect(N3, N1, E1)); + + // The graph looks like this now: + // + // +---------------+ + // v | + // N1 -> N2 -> N3 -+ + + // Check that already connected nodes with the given edge are not connected + // again (ie. edges are between nodes are not duplicated). + EXPECT_FALSE(DG.connect(N3, N1, E1)); + + // Check that there are 3 nodes in the graph. + EXPECT_TRUE(DG.size() == 3); + + // Check that the added nodes can be found in the graph. + EXPECT_NE(DG.findNode(N3), DG.end()); + + // Check that nodes that are not part of the graph are not found. + DGTestNode N4; + EXPECT_EQ(DG.findNode(N4), DG.end()); + + // Check that findIncommingEdgesToNode works correctly. + EdgeListTy EL; + EXPECT_TRUE(DG.findIncomingEdgesToNode(N1, EL)); + EXPECT_TRUE(EL.size() == 1); + EXPECT_EQ(*EL[0], E1); +} + +TEST(DirectedGraphTest, AddRemoveEdge) { + DGTestGraph DG; + DGTestNode N1, N2, N3; + DGTestEdge E1(N1), E2(N2), E3(N3); + DG.addNode(N1); + DG.addNode(N2); + DG.addNode(N3); + DG.connect(N1, N2, E2); + DG.connect(N2, N3, E3); + DG.connect(N3, N1, E1); + + // The graph looks like this now: + // + // +---------------+ + // v | + // N1 -> N2 -> N3 -+ + + // Check that there are 3 nodes in the graph. + EXPECT_TRUE(DG.size() == 3); + + // Check that the target nodes of the edges are correct. + EXPECT_EQ(E1.getTargetNode(), N1); + EXPECT_EQ(E2.getTargetNode(), N2); + EXPECT_EQ(E3.getTargetNode(), N3); + + // Remove the edge from N1 to N2. + N1.removeEdge(E2); + + // The graph looks like this now: + // + // N2 -> N3 -> N1 + + // Check that there are no incoming edges to N2. + EdgeListTy EL; + EXPECT_FALSE(DG.findIncomingEdgesToNode(N2, EL)); + EXPECT_TRUE(EL.empty()); + + // Put the edge from N1 to N2 back in place. + N1.addEdge(E2); + + // Check that E2 is the only incoming edge to N2. + EL.clear(); + EXPECT_TRUE(DG.findIncomingEdgesToNode(N2, EL)); + EXPECT_EQ(*EL[0], E2); +} + +TEST(DirectedGraphTest, hasEdgeTo) { + DGTestGraph DG; + DGTestNode N1, N2, N3; + DGTestEdge E1(N1), E2(N2), E3(N3), E4(N1); + DG.addNode(N1); + DG.addNode(N2); + DG.addNode(N3); + DG.connect(N1, N2, E2); + DG.connect(N2, N3, E3); + DG.connect(N3, N1, E1); + DG.connect(N2, N1, E4); + + // The graph looks like this now: + // + // +-----+ + // v | + // N1 -> N2 -> N3 + // ^ | + // +-----------+ + + EXPECT_TRUE(N2.hasEdgeTo(N1)); + EXPECT_TRUE(N3.hasEdgeTo(N1)); +} + +TEST(DirectedGraphTest, AddRemoveNode) { + DGTestGraph DG; + DGTestNode N1, N2, N3; + DGTestEdge E1(N1), E2(N2), E3(N3); + DG.addNode(N1); + DG.addNode(N2); + DG.addNode(N3); + DG.connect(N1, N2, E2); + DG.connect(N2, N3, E3); + DG.connect(N3, N1, E1); + + // The graph looks like this now: + // + // +---------------+ + // v | + // N1 -> N2 -> N3 -+ + + // Check that there are 3 nodes in the graph. + EXPECT_TRUE(DG.size() == 3); + + // Check that a node in the graph can be removed, but not more than once. + EXPECT_TRUE(DG.removeNode(N1)); + EXPECT_EQ(DG.findNode(N1), DG.end()); + EXPECT_FALSE(DG.removeNode(N1)); + + // The graph looks like this now: + // + // N2 -> N3 + + // Check that there are 2 nodes in the graph and only N2 is connected to N3. + EXPECT_TRUE(DG.size() == 2); + EXPECT_TRUE(N3.getEdges().empty()); + EdgeListTy EL; + EXPECT_FALSE(DG.findIncomingEdgesToNode(N2, EL)); + EXPECT_TRUE(EL.empty()); +} + +TEST(DirectedGraphTest, SCC) { + + DGTestGraph DG; + DGTestNode N1, N2, N3, N4; + DGTestEdge E1(N1), E2(N2), E3(N3), E4(N4); + DG.addNode(N1); + DG.addNode(N2); + DG.addNode(N3); + DG.addNode(N4); + DG.connect(N1, N2, E2); + DG.connect(N2, N3, E3); + DG.connect(N3, N1, E1); + DG.connect(N3, N4, E4); + + // The graph looks like this now: + // + // +---------------+ + // v | + // N1 -> N2 -> N3 -+ N4 + // | ^ + // +--------+ + + // Test that there are two SCCs: + // 1. {N1, N2, N3} + // 2. {N4} + using NodeListTy = SmallPtrSet; + SmallVector ListOfSCCs; + for (auto &SCC : make_range(scc_begin(&DG), scc_end(&DG))) + ListOfSCCs.push_back(NodeListTy(SCC.begin(), SCC.end())); + + EXPECT_TRUE(ListOfSCCs.size() == 2); + + for (auto &SCC : ListOfSCCs) { + if (SCC.size() > 1) + continue; + EXPECT_TRUE(SCC.size() == 1); + EXPECT_TRUE(SCC.count(&N4) == 1); + } + for (auto &SCC : ListOfSCCs) { + if (SCC.size() <= 1) + continue; + EXPECT_TRUE(SCC.size() == 3); + EXPECT_TRUE(SCC.count(&N1) == 1); + EXPECT_TRUE(SCC.count(&N2) == 1); + EXPECT_TRUE(SCC.count(&N3) == 1); + EXPECT_TRUE(SCC.count(&N4) == 0); + } +} + +} // namespace llvm