From d59032fedcfe9c7991cabc6154c271200bcbd526 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 14 Apr 2020 14:53:07 -0700 Subject: [PATCH] [llvm][STLExtras] Move various iterator/range utilities from MLIR to LLVM This revision moves the various range utilities present in MLIR to LLVM to enable greater reuse. This revision moves the following utilities: * indexed_accessor_* This is set of utility iterator/range base classes that allow for building a range class where the iterators are represented by an object+index pair. * make_second_range Given a range of pairs, returns a range iterating over the `second` elements. * hasSingleElement Returns if the given range has 1 element. size() == 1 checks end up being very common, but size() is not always O(1) (e.g., ilist). This method provides O(1) checks for those cases. Differential Revision: https://reviews.llvm.org/D78064 --- include/llvm/ADT/STLExtras.h | 213 ++++++++++++++++++++++ unittests/Support/CMakeLists.txt | 1 + unittests/Support/IndexedAccessorTest.cpp | 49 +++++ 3 files changed, 263 insertions(+) create mode 100644 unittests/Support/IndexedAccessorTest.cpp diff --git a/include/llvm/ADT/STLExtras.h b/include/llvm/ADT/STLExtras.h index 53457f02b88..ff53a3c7a65 100644 --- a/include/llvm/ADT/STLExtras.h +++ b/include/llvm/ADT/STLExtras.h @@ -263,6 +263,12 @@ constexpr bool empty(const T &RangeOrContainer) { return adl_begin(RangeOrContainer) == adl_end(RangeOrContainer); } +/// Returns true of the given range only contains a single element. +template bool hasSingleElement(ContainerTy &&c) { + auto it = std::begin(c), e = std::end(c); + return it != e && std::next(it) == e; +} + /// Return a range covering \p RangeOrContainer with the first N elements /// excluded. template auto drop_begin(T &&RangeOrContainer, size_t N) { @@ -1017,6 +1023,213 @@ detail::concat_range concat(RangeTs &&... Ranges) { std::forward(Ranges)...); } +/// A utility class used to implement an iterator that contains some base object +/// and an index. The iterator moves the index but keeps the base constant. +template +class indexed_accessor_iterator + : public llvm::iterator_facade_base { +public: + ptrdiff_t operator-(const indexed_accessor_iterator &rhs) const { + assert(base == rhs.base && "incompatible iterators"); + return index - rhs.index; + } + bool operator==(const indexed_accessor_iterator &rhs) const { + return base == rhs.base && index == rhs.index; + } + bool operator<(const indexed_accessor_iterator &rhs) const { + assert(base == rhs.base && "incompatible iterators"); + return index < rhs.index; + } + + DerivedT &operator+=(ptrdiff_t offset) { + this->index += offset; + return static_cast(*this); + } + DerivedT &operator-=(ptrdiff_t offset) { + this->index -= offset; + return static_cast(*this); + } + + /// Returns the current index of the iterator. + ptrdiff_t getIndex() const { return index; } + + /// Returns the current base of the iterator. + const BaseT &getBase() const { return base; } + +protected: + indexed_accessor_iterator(BaseT base, ptrdiff_t index) + : base(base), index(index) {} + BaseT base; + ptrdiff_t index; +}; + +namespace detail { +/// The class represents the base of a range of indexed_accessor_iterators. It +/// provides support for many different range functionalities, e.g. +/// drop_front/slice/etc.. Derived range classes must implement the following +/// static methods: +/// * ReferenceT dereference_iterator(const BaseT &base, ptrdiff_t index) +/// - Dereference an iterator pointing to the base object at the given +/// index. +/// * BaseT offset_base(const BaseT &base, ptrdiff_t index) +/// - Return a new base that is offset from the provide base by 'index' +/// elements. +template +class indexed_accessor_range_base { +public: + using RangeBaseT = + indexed_accessor_range_base; + + /// An iterator element of this range. + class iterator : public indexed_accessor_iterator { + public: + // Index into this iterator, invoking a static method on the derived type. + ReferenceT operator*() const { + return DerivedT::dereference_iterator(this->getBase(), this->getIndex()); + } + + private: + iterator(BaseT owner, ptrdiff_t curIndex) + : indexed_accessor_iterator( + owner, curIndex) {} + + /// Allow access to the constructor. + friend indexed_accessor_range_base; + }; + + indexed_accessor_range_base(iterator begin, iterator end) + : base(DerivedT::offset_base(begin.getBase(), begin.getIndex())), + count(end.getIndex() - begin.getIndex()) {} + indexed_accessor_range_base(const iterator_range &range) + : indexed_accessor_range_base(range.begin(), range.end()) {} + indexed_accessor_range_base(BaseT base, ptrdiff_t count) + : base(base), count(count) {} + + iterator begin() const { return iterator(base, 0); } + iterator end() const { return iterator(base, count); } + ReferenceT operator[](unsigned index) const { + assert(index < size() && "invalid index for value range"); + return DerivedT::dereference_iterator(base, index); + } + + /// Compare this range with another. + template bool operator==(const OtherT &other) { + return size() == std::distance(other.begin(), other.end()) && + std::equal(begin(), end(), other.begin()); + } + + /// Return the size of this range. + size_t size() const { return count; } + + /// Return if the range is empty. + bool empty() const { return size() == 0; } + + /// Drop the first N elements, and keep M elements. + DerivedT slice(size_t n, size_t m) const { + assert(n + m <= size() && "invalid size specifiers"); + return DerivedT(DerivedT::offset_base(base, n), m); + } + + /// Drop the first n elements. + DerivedT drop_front(size_t n = 1) const { + assert(size() >= n && "Dropping more elements than exist"); + return slice(n, size() - n); + } + /// Drop the last n elements. + DerivedT drop_back(size_t n = 1) const { + assert(size() >= n && "Dropping more elements than exist"); + return DerivedT(base, size() - n); + } + + /// Take the first n elements. + DerivedT take_front(size_t n = 1) const { + return n < size() ? drop_back(size() - n) + : static_cast(*this); + } + + /// Take the last n elements. + DerivedT take_back(size_t n = 1) const { + return n < size() ? drop_front(size() - n) + : static_cast(*this); + } + + /// Allow conversion to any type accepting an iterator_range. + template >::value>> + operator RangeT() const { + return RangeT(iterator_range(*this)); + } + +protected: + indexed_accessor_range_base(const indexed_accessor_range_base &) = default; + indexed_accessor_range_base(indexed_accessor_range_base &&) = default; + indexed_accessor_range_base & + operator=(const indexed_accessor_range_base &) = default; + + /// The base that owns the provided range of values. + BaseT base; + /// The size from the owning range. + ptrdiff_t count; +}; +} // end namespace detail + +/// This class provides an implementation of a range of +/// indexed_accessor_iterators where the base is not indexable. Ranges with +/// bases that are offsetable should derive from indexed_accessor_range_base +/// instead. Derived range classes are expected to implement the following +/// static method: +/// * ReferenceT dereference(const BaseT &base, ptrdiff_t index) +/// - Dereference an iterator pointing to a parent base at the given index. +template +class indexed_accessor_range + : public detail::indexed_accessor_range_base< + DerivedT, std::pair, T, PointerT, ReferenceT> { +public: + indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count) + : detail::indexed_accessor_range_base< + DerivedT, std::pair, T, PointerT, ReferenceT>( + std::make_pair(base, startIndex), count) {} + using detail::indexed_accessor_range_base< + DerivedT, std::pair, T, PointerT, + ReferenceT>::indexed_accessor_range_base; + + /// Returns the current base of the range. + const BaseT &getBase() const { return this->base.first; } + + /// Returns the current start index of the range. + ptrdiff_t getStartIndex() const { return this->base.second; } + + /// See `detail::indexed_accessor_range_base` for details. + static std::pair + offset_base(const std::pair &base, ptrdiff_t index) { + // We encode the internal base as a pair of the derived base and a start + // index into the derived base. + return std::make_pair(base.first, base.second + index); + } + /// See `detail::indexed_accessor_range_base` for details. + static ReferenceT + dereference_iterator(const std::pair &base, + ptrdiff_t index) { + return DerivedT::dereference(base.first, base.second + index); + } +}; + +/// Given a container of pairs, return a range over the second elements. +template auto make_second_range(ContainerTy &&c) { + return llvm::map_range( + std::forward(c), + [](decltype((*std::begin(c))) elt) -> decltype((elt.second)) { + return elt.second; + }); +} + //===----------------------------------------------------------------------===// // Extra additions to //===----------------------------------------------------------------------===// diff --git a/unittests/Support/CMakeLists.txt b/unittests/Support/CMakeLists.txt index 0c321133383..b9eeba165c9 100644 --- a/unittests/Support/CMakeLists.txt +++ b/unittests/Support/CMakeLists.txt @@ -40,6 +40,7 @@ add_llvm_unittest(SupportTests FormatVariadicTest.cpp GlobPatternTest.cpp Host.cpp + IndexedAccessorTest.cpp ItaniumManglingCanonicalizerTest.cpp JSONTest.cpp KnownBitsTest.cpp diff --git a/unittests/Support/IndexedAccessorTest.cpp b/unittests/Support/IndexedAccessorTest.cpp new file mode 100644 index 00000000000..9981e91df10 --- /dev/null +++ b/unittests/Support/IndexedAccessorTest.cpp @@ -0,0 +1,49 @@ +//===- IndexedAccessorTest.cpp - Indexed Accessor Tests -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "gmock/gmock.h" + +using namespace llvm; +using namespace llvm::detail; + +namespace { +/// Simple indexed accessor range that wraps an array. +template +struct ArrayIndexedAccessorRange + : public indexed_accessor_range, T *, T> { + ArrayIndexedAccessorRange(T *data, ptrdiff_t start, ptrdiff_t numElements) + : indexed_accessor_range, T *, T>( + data, start, numElements) {} + using indexed_accessor_range, T *, + T>::indexed_accessor_range; + + /// See `llvm::indexed_accessor_range` for details. + static T &dereference(T *data, ptrdiff_t index) { return data[index]; } +}; +} // end anonymous namespace + +template +static void compareData(ArrayIndexedAccessorRange range, + ArrayRef referenceData) { + ASSERT_TRUE(referenceData.size() == range.size()); + ASSERT_TRUE(std::equal(range.begin(), range.end(), referenceData.begin())); +} + +namespace { +TEST(AccessorRange, SliceTest) { + int rawData[] = {0, 1, 2, 3, 4}; + ArrayRef data = llvm::makeArrayRef(rawData); + + ArrayIndexedAccessorRange range(rawData, /*start=*/0, /*numElements=*/5); + compareData(range, data); + compareData(range.slice(2, 3), data.slice(2, 3)); + compareData(range.slice(0, 5), data.slice(0, 5)); +} +} // end anonymous namespace