//===- llvm/ADT/CoalescingBitVector.h - A coalescing bitvector --*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// /// \file A bitvector that uses an IntervalMap to coalesce adjacent elements /// into intervals. /// //===----------------------------------------------------------------------===// #ifndef LLVM_ADT_COALESCINGBITVECTOR_H #define LLVM_ADT_COALESCINGBITVECTOR_H #include "llvm/ADT/IntervalMap.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include #include namespace llvm { /// A bitvector that, under the hood, relies on an IntervalMap to coalesce /// elements into intervals. Good for representing sets which predominantly /// contain contiguous ranges. Bad for representing sets with lots of gaps /// between elements. /// /// Compared to SparseBitVector, CoalescingBitVector offers more predictable /// performance for non-sequential find() operations. /// /// \tparam IndexT - The type of the index into the bitvector. template class CoalescingBitVector { static_assert(std::is_unsigned::value, "Index must be an unsigned integer."); using ThisT = CoalescingBitVector; /// An interval map for closed integer ranges. The mapped values are unused. using MapT = IntervalMap; using UnderlyingIterator = typename MapT::const_iterator; using IntervalT = std::pair; public: using Allocator = typename MapT::Allocator; /// Construct by passing in a CoalescingBitVector::Allocator /// reference. CoalescingBitVector(Allocator &Alloc) : Alloc(&Alloc), Intervals(Alloc) {} /// \name Copy/move constructors and assignment operators. /// @{ CoalescingBitVector(const ThisT &Other) : Alloc(Other.Alloc), Intervals(*Other.Alloc) { set(Other); } ThisT &operator=(const ThisT &Other) { clear(); set(Other); return *this; } CoalescingBitVector(ThisT &&Other) = delete; ThisT &operator=(ThisT &&Other) = delete; /// @} /// Clear all the bits. void clear() { Intervals.clear(); } /// Check whether no bits are set. bool empty() const { return Intervals.empty(); } /// Count the number of set bits. unsigned count() const { unsigned Bits = 0; for (auto It = Intervals.begin(), End = Intervals.end(); It != End; ++It) Bits += 1 + It.stop() - It.start(); return Bits; } /// Set the bit at \p Index. /// /// This method does /not/ support setting a bit that has already been set, /// for efficiency reasons. If possible, restructure your code to not set the /// same bit multiple times, or use \ref test_and_set. void set(IndexT Index) { assert(!test(Index) && "Setting already-set bits not supported/efficient, " "IntervalMap will assert"); insert(Index, Index); } /// Set the bits set in \p Other. /// /// This method does /not/ support setting already-set bits, see \ref set /// for the rationale. For a safe set union operation, use \ref operator|=. void set(const ThisT &Other) { for (auto It = Other.Intervals.begin(), End = Other.Intervals.end(); It != End; ++It) insert(It.start(), It.stop()); } /// Set the bits at \p Indices. Used for testing, primarily. void set(std::initializer_list Indices) { for (IndexT Index : Indices) set(Index); } /// Check whether the bit at \p Index is set. bool test(IndexT Index) const { const auto It = Intervals.find(Index); if (It == Intervals.end()) return false; assert(It.stop() >= Index && "Interval must end after Index"); return It.start() <= Index; } /// Set the bit at \p Index. Supports setting an already-set bit. void test_and_set(IndexT Index) { if (!test(Index)) set(Index); } /// Reset the bit at \p Index. Supports resetting an already-unset bit. void reset(IndexT Index) { auto It = Intervals.find(Index); if (It == Intervals.end()) return; // Split the interval containing Index into up to two parts: one from // [Start, Index-1] and another from [Index+1, Stop]. If Index is equal to // either Start or Stop, we create one new interval. If Index is equal to // both Start and Stop, we simply erase the existing interval. IndexT Start = It.start(); if (Index < Start) // The index was not set. return; IndexT Stop = It.stop(); assert(Index <= Stop && "Wrong interval for index"); It.erase(); if (Start < Index) insert(Start, Index - 1); if (Index < Stop) insert(Index + 1, Stop); } /// Set union. If \p RHS is guaranteed to not overlap with this, \ref set may /// be a faster alternative. void operator|=(const ThisT &RHS) { // Get the overlaps between the two interval maps. SmallVector Overlaps; getOverlaps(RHS, Overlaps); // Insert the non-overlapping parts of all the intervals from RHS. for (auto It = RHS.Intervals.begin(), End = RHS.Intervals.end(); It != End; ++It) { IndexT Start = It.start(); IndexT Stop = It.stop(); SmallVector NonOverlappingParts; getNonOverlappingParts(Start, Stop, Overlaps, NonOverlappingParts); for (IntervalT AdditivePortion : NonOverlappingParts) insert(AdditivePortion.first, AdditivePortion.second); } } /// Set intersection. void operator&=(const ThisT &RHS) { // Get the overlaps between the two interval maps (i.e. the intersection). SmallVector Overlaps; getOverlaps(RHS, Overlaps); // Rebuild the interval map, including only the overlaps. clear(); for (IntervalT Overlap : Overlaps) insert(Overlap.first, Overlap.second); } /// Reset all bits present in \p Other. void intersectWithComplement(const ThisT &Other) { SmallVector Overlaps; if (!getOverlaps(Other, Overlaps)) { // If there is no overlap with Other, the intersection is empty. return; } // Delete the overlapping intervals. Split up intervals that only partially // intersect an overlap. for (IntervalT Overlap : Overlaps) { IndexT OlapStart, OlapStop; std::tie(OlapStart, OlapStop) = Overlap; auto It = Intervals.find(OlapStart); IndexT CurrStart = It.start(); IndexT CurrStop = It.stop(); assert(CurrStart <= OlapStart && OlapStop <= CurrStop && "Expected some intersection!"); // Split the overlap interval into up to two parts: one from [CurrStart, // OlapStart-1] and another from [OlapStop+1, CurrStop]. If OlapStart is // equal to CurrStart, the first split interval is unnecessary. Ditto for // when OlapStop is equal to CurrStop, we omit the second split interval. It.erase(); if (CurrStart < OlapStart) insert(CurrStart, OlapStart - 1); if (OlapStop < CurrStop) insert(OlapStop + 1, CurrStop); } } bool operator==(const ThisT &RHS) const { // We cannot just use std::equal because it checks the dereferenced values // of an iterator pair for equality, not the iterators themselves. In our // case that results in comparison of the (unused) IntervalMap values. auto ItL = Intervals.begin(); auto ItR = RHS.Intervals.begin(); while (ItL != Intervals.end() && ItR != RHS.Intervals.end() && ItL.start() == ItR.start() && ItL.stop() == ItR.stop()) { ++ItL; ++ItR; } return ItL == Intervals.end() && ItR == RHS.Intervals.end(); } bool operator!=(const ThisT &RHS) const { return !operator==(RHS); } class const_iterator { friend class CoalescingBitVector; public: using iterator_category = std::forward_iterator_tag; using value_type = IndexT; using difference_type = std::ptrdiff_t; using pointer = value_type *; using reference = value_type &; private: // For performance reasons, make the offset at the end different than the // one used in \ref begin, to optimize the common `It == end()` pattern. static constexpr unsigned kIteratorAtTheEndOffset = ~0u; UnderlyingIterator MapIterator; unsigned OffsetIntoMapIterator = 0; // Querying the start/stop of an IntervalMap iterator can be very expensive. // Cache these values for performance reasons. IndexT CachedStart = IndexT(); IndexT CachedStop = IndexT(); void setToEnd() { OffsetIntoMapIterator = kIteratorAtTheEndOffset; CachedStart = IndexT(); CachedStop = IndexT(); } /// MapIterator has just changed, reset the cached state to point to the /// start of the new underlying iterator. void resetCache() { if (MapIterator.valid()) { OffsetIntoMapIterator = 0; CachedStart = MapIterator.start(); CachedStop = MapIterator.stop(); } else { setToEnd(); } } /// Advance the iterator to \p Index, if it is contained within the current /// interval. The public-facing method which supports advancing past the /// current interval is \ref advanceToLowerBound. void advanceTo(IndexT Index) { assert(Index <= CachedStop && "Cannot advance to OOB index"); if (Index < CachedStart) // We're already past this index. return; OffsetIntoMapIterator = Index - CachedStart; } const_iterator(UnderlyingIterator MapIt) : MapIterator(MapIt) { resetCache(); } public: const_iterator() { setToEnd(); } bool operator==(const const_iterator &RHS) const { // Do /not/ compare MapIterator for equality, as this is very expensive. // The cached start/stop values make that check unnecessary. return std::tie(OffsetIntoMapIterator, CachedStart, CachedStop) == std::tie(RHS.OffsetIntoMapIterator, RHS.CachedStart, RHS.CachedStop); } bool operator!=(const const_iterator &RHS) const { return !operator==(RHS); } IndexT operator*() const { return CachedStart + OffsetIntoMapIterator; } const_iterator &operator++() { // Pre-increment (++It). if (CachedStart + OffsetIntoMapIterator < CachedStop) { // Keep going within the current interval. ++OffsetIntoMapIterator; } else { // We reached the end of the current interval: advance. ++MapIterator; resetCache(); } return *this; } const_iterator operator++(int) { // Post-increment (It++). const_iterator tmp = *this; operator++(); return tmp; } /// Advance the iterator to the first set bit AT, OR AFTER, \p Index. If /// no such set bit exists, advance to end(). This is like std::lower_bound. /// This is useful if \p Index is close to the current iterator position. /// However, unlike \ref find(), this has worst-case O(n) performance. void advanceToLowerBound(IndexT Index) { if (OffsetIntoMapIterator == kIteratorAtTheEndOffset) return; // Advance to the first interval containing (or past) Index, or to end(). while (Index > CachedStop) { ++MapIterator; resetCache(); if (OffsetIntoMapIterator == kIteratorAtTheEndOffset) return; } advanceTo(Index); } }; const_iterator begin() const { return const_iterator(Intervals.begin()); } const_iterator end() const { return const_iterator(); } /// Return an iterator pointing to the first set bit AT, OR AFTER, \p Index. /// If no such set bit exists, return end(). This is like std::lower_bound. /// This has worst-case logarithmic performance (roughly O(log(gaps between /// contiguous ranges))). const_iterator find(IndexT Index) const { auto UnderlyingIt = Intervals.find(Index); if (UnderlyingIt == Intervals.end()) return end(); auto It = const_iterator(UnderlyingIt); It.advanceTo(Index); return It; } /// Return a range iterator which iterates over all of the set bits in the /// half-open range [Start, End). iterator_range half_open_range(IndexT Start, IndexT End) const { assert(Start < End && "Not a valid range"); auto StartIt = find(Start); if (StartIt == end() || *StartIt >= End) return {end(), end()}; auto EndIt = StartIt; EndIt.advanceToLowerBound(End); return {StartIt, EndIt}; } void print(raw_ostream &OS) const { OS << "{"; for (auto It = Intervals.begin(), End = Intervals.end(); It != End; ++It) { OS << "[" << It.start(); if (It.start() != It.stop()) OS << ", " << It.stop(); OS << "]"; } OS << "}"; } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) LLVM_DUMP_METHOD void dump() const { // LLDB swallows the first line of output after callling dump(). Add // newlines before/after the braces to work around this. dbgs() << "\n"; print(dbgs()); dbgs() << "\n"; } #endif private: void insert(IndexT Start, IndexT End) { Intervals.insert(Start, End, 0); } /// Record the overlaps between \p this and \p Other in \p Overlaps. Return /// true if there is any overlap. bool getOverlaps(const ThisT &Other, SmallVectorImpl &Overlaps) const { for (IntervalMapOverlaps I(Intervals, Other.Intervals); I.valid(); ++I) Overlaps.emplace_back(I.start(), I.stop()); assert(llvm::is_sorted(Overlaps, [](IntervalT LHS, IntervalT RHS) { return LHS.second < RHS.first; }) && "Overlaps must be sorted"); return !Overlaps.empty(); } /// Given the set of overlaps between this and some other bitvector, and an /// interval [Start, Stop] from that bitvector, determine the portions of the /// interval which do not overlap with this. void getNonOverlappingParts(IndexT Start, IndexT Stop, const SmallVectorImpl &Overlaps, SmallVectorImpl &NonOverlappingParts) { IndexT NextUncoveredBit = Start; for (IntervalT Overlap : Overlaps) { IndexT OlapStart, OlapStop; std::tie(OlapStart, OlapStop) = Overlap; // [Start;Stop] and [OlapStart;OlapStop] overlap iff OlapStart <= Stop // and Start <= OlapStop. bool DoesOverlap = OlapStart <= Stop && Start <= OlapStop; if (!DoesOverlap) continue; // Cover the range [NextUncoveredBit, OlapStart). This puts the start of // the next uncovered range at OlapStop+1. if (NextUncoveredBit < OlapStart) NonOverlappingParts.emplace_back(NextUncoveredBit, OlapStart - 1); NextUncoveredBit = OlapStop + 1; if (NextUncoveredBit > Stop) break; } if (NextUncoveredBit <= Stop) NonOverlappingParts.emplace_back(NextUncoveredBit, Stop); } Allocator *Alloc; MapT Intervals; }; } // namespace llvm #endif // LLVM_ADT_COALESCINGBITVECTOR_H