Program Listing for File set.hpp

Return to documentation for file (librapid/include/librapid/datastructures/set.hpp)

#ifndef LIBRAPID_SET_HPP
#define LIBRAPID_SET_HPP

namespace librapid {
    /// A collection of unique elements kept sorted in a contiguous std::vector.
    ///
    /// Invariant: `m_data` is sorted in ascending order (by `operator<`, via std::sort) and
    /// contains no duplicates. Membership tests use binary search; inserting or erasing a
    /// single element is linear because later elements must be shifted.
    ///
    /// \tparam ElementType_ Element type. Must be copyable, ordered by `operator<`, and
    ///                      comparable with `operator==` consistently with that ordering.
    template<typename ElementType_>
    class Set {
    public:
        using ElementType = ElementType_;

        using VectorType = std::vector<ElementType>;

        using VectorIterator = typename VectorType::iterator;

        using VectorConstIterator = typename VectorType::const_iterator;

        Set() = default;

        Set(const Set &other) = default;

        Set(Set &&other) = default;

        /// Build a set from an array container: copies `arr.storage()[0 .. arr.size())`,
        /// then sorts and removes duplicates.
        template<typename ShapeType, typename StorageType>
        Set(const array::ArrayContainer<ShapeType, StorageType> &arr) {
            reserve(arr.size());
            for (size_t i = 0; i < arr.size(); ++i) { pushBack(arr.storage()[i]); }
            sort();
            prune();
        }

        /// Build a set from a vector; the input may be unsorted and contain duplicates.
        Set(const std::vector<ElementType> &data) : m_data(data) {
            sort();
            prune();
        }

        /// Build a set from an initializer list; the input may be unsorted and contain
        /// duplicates.
        Set(const std::initializer_list<ElementType> &data) : m_data(data) {
            sort();
            prune();
        }

        Set &operator=(const Set &other) = default;

        Set &operator=(Set &&other) = default;

        /// Number of elements in the set.
        LIBRAPID_NODISCARD int64_t size() const { return static_cast<int64_t>(m_data.size()); }

        /// The `index`-th smallest element. The index is checked with LIBRAPID_ASSERT.
        LIBRAPID_NODISCARD const ElementType &operator[](int64_t index) const {
            LIBRAPID_ASSERT(index >= 0 && index < size(),
                            "Index out of bounds: {} (size: {})",
                            index,
                            m_data.size());
            return m_data[static_cast<size_t>(index)];
        }

        /// True if `val` is an element of the set (binary search, O(log n)).
        LIBRAPID_NODISCARD bool contains(const ElementType &val) const {
            return std::binary_search(m_data.begin(), m_data.end(), val);
        }

        /// Add `val`, keeping the storage sorted. No-op if `val` is already present.
        Set &insert(const ElementType &val) {
            // The first element not less than `val` is both where an equal element would be
            // and the position that keeps the vector sorted after insertion.
            const auto it = std::lower_bound(m_data.begin(), m_data.end(), val);
            if (it == m_data.end() || val < *it) { m_data.insert(it, val); }
            return *this;
        }

        /// Add every element of `data` (duplicates are ignored).
        Set &insert(const std::vector<ElementType> &data) {
            // Sort the input once and merge: O(n + m log m) rather than O(n * m) for
            // element-by-element insertion into the vector.
            *this = *this | Set(data);
            return *this;
        }

        /// Add every element of `data` (duplicates are ignored).
        Set &insert(const std::initializer_list<ElementType> &data) {
            *this = *this | Set(data);
            return *this;
        }

        Set &operator+=(const ElementType &val) { return insert(val); }

        Set &operator+=(const std::vector<ElementType> &data) { return insert(data); }

        Set &operator+=(const std::initializer_list<ElementType> &data) { return insert(data); }

        /// Copy of this set with `val` added.
        LIBRAPID_NODISCARD Set operator+(const ElementType &val) const {
            Set result = *this;
            result += val;
            return result;
        }

        /// Copy of this set with every element of `data` added.
        LIBRAPID_NODISCARD Set operator+(const std::vector<ElementType> &data) const {
            Set result = *this;
            result += data;
            return result;
        }

        /// Copy of this set with every element of `data` added.
        LIBRAPID_NODISCARD Set operator+(const std::initializer_list<ElementType> &data) const {
            Set result = *this;
            result += data;
            return result;
        }

        /// Remove `val` if present; no-op otherwise.
        Set &discard(const ElementType &val) {
            const auto it = std::lower_bound(m_data.begin(), m_data.end(), val);
            if (it != m_data.end() && !(val < *it)) { m_data.erase(it); }
            return *this;
        }

        /// Remove every element of `data` that is present; absent values are ignored.
        Set &discard(const std::vector<ElementType> &data) {
            *this = *this - Set(data);
            return *this;
        }

        /// Remove every element of `data` that is present; absent values are ignored.
        Set &discard(const std::initializer_list<ElementType> &data) {
            *this = *this - Set(data);
            return *this;
        }

        /// Remove `val`, asserting (via LIBRAPID_ASSERT) that it is present.
        Set &remove(const ElementType &val) {
            LIBRAPID_ASSERT(contains(val), "Set does not contain value: {}", val);
            return discard(val);
        }

        /// Remove every element of `data`, asserting that each one is present.
        Set &remove(const std::vector<ElementType> &data) {
            for (const auto &val : data) { remove(val); }
            return *this;
        }

        /// Remove every element of `data`, asserting that each one is present.
        Set &remove(const std::initializer_list<ElementType> &data) {
            for (const auto &val : data) { remove(val); }
            return *this;
        }

        Set &operator-=(const ElementType &val) { return discard(val); }

        Set &operator-=(const std::vector<ElementType> &data) { return discard(data); }

        Set &operator-=(const std::initializer_list<ElementType> &data) { return discard(data); }

        /// Copy of this set with `val` removed (if present).
        LIBRAPID_NODISCARD Set operator-(const ElementType &val) const {
            Set result = *this;
            result -= val;
            return result;
        }

        /// Copy of this set with every element of `data` removed.
        LIBRAPID_NODISCARD Set operator-(const std::vector<ElementType> &data) const {
            Set result = *this;
            result -= data;
            return result;
        }

        /// Copy of this set with every element of `data` removed.
        LIBRAPID_NODISCARD Set operator-(const std::initializer_list<ElementType> &data) const {
            Set result = *this;
            result -= data;
            return result;
        }

        /// Union: elements in either set.
        LIBRAPID_NODISCARD Set operator|(const Set &other) const {
            return combineSorted(other,
                                 /* keepOnlyThis  */ true,
                                 /* keepOnlyOther */ true,
                                 /* keepBoth      */ true,
                                 m_data.size() + other.m_data.size());
        }

        /// Intersection: elements in both sets.
        LIBRAPID_NODISCARD Set operator&(const Set &other) const {
            return combineSorted(other,
                                 /* keepOnlyThis  */ false,
                                 /* keepOnlyOther */ false,
                                 /* keepBoth      */ true,
                                 std::min(m_data.size(), other.m_data.size()));
        }

        /// Symmetric difference: elements in exactly one of the two sets.
        LIBRAPID_NODISCARD Set operator^(const Set &other) const {
            return combineSorted(other,
                                 /* keepOnlyThis  */ true,
                                 /* keepOnlyOther */ true,
                                 /* keepBoth      */ false,
                                 m_data.size() + other.m_data.size());
        }

        /// Difference: elements of this set that are not in `other`.
        LIBRAPID_NODISCARD Set operator-(const Set &other) const {
            return combineSorted(other,
                                 /* keepOnlyThis  */ true,
                                 /* keepOnlyOther */ false,
                                 /* keepBoth      */ false,
                                 m_data.size());
        }

        /// Defaulted comparison of the (sorted) underlying storage.
        LIBRAPID_NODISCARD auto operator<=>(const Set &other) const = default;

        LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto begin() const { return m_data.begin(); }

        LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto end() const { return m_data.end(); }

        /// Write the set as "(e0, e1, ...)", formatting each element with `format`.
        template<typename T, typename Char, typename Ctx>
        void str(const fmt::formatter<T, Char> &format, Ctx &ctx) const {
            fmt::format_to(ctx.out(), "(");
            for (size_t i = 0; i < m_data.size(); ++i) {
                if (i > 0) { fmt::format_to(ctx.out(), ", "); }
                format.format(m_data[i], ctx);
            }
            fmt::format_to(ctx.out(), ")");
        }

    protected:
        void reserve(size_t elements) { m_data.reserve(elements); }

        void sort() { std::sort(m_data.begin(), m_data.end()); }

        /// Remove adjacent duplicates in place; `m_data` must already be sorted.
        void prune() { m_data.erase(std::unique(m_data.begin(), m_data.end()), m_data.end()); }

        LIBRAPID_ALWAYS_INLINE void pushBack(const ElementType &val) { m_data.emplace_back(val); }

        LIBRAPID_ALWAYS_INLINE void insert(VectorConstIterator insertLocation,
                                           VectorConstIterator begin, VectorConstIterator end) {
            m_data.insert(insertLocation, begin, end);
        }

    private:
        /// One merge pass over `*this` and `other` (both sorted and duplicate-free).
        /// Every element is either only in `*this`, only in `other`, or in both; it is
        /// copied to the result iff the flag for its category is set, so the result is
        /// sorted and duplicate-free too. `reserveHint` is the capacity reserved up front.
        Set combineSorted(const Set &other, bool keepOnlyThis, bool keepOnlyOther, bool keepBoth,
                          size_t reserveHint) const {
            Set result;
            result.reserve(reserveHint);

            auto a          = m_data.begin();
            auto b          = other.m_data.begin();
            const auto aEnd = m_data.end();
            const auto bEnd = other.m_data.end();

            while (a != aEnd && b != bEnd) {
                if (*a < *b) {
                    if (keepOnlyThis) { result.pushBack(*a); }
                    ++a;
                } else if (*b < *a) {
                    if (keepOnlyOther) { result.pushBack(*b); }
                    ++b;
                } else {
                    if (keepBoth) { result.pushBack(*a); }
                    ++a;
                    ++b;
                }
            }

            // Leftover elements of either input have no counterpart in the other one.
            // At most one of these ranges is non-empty, so the result stays sorted.
            if (keepOnlyThis) { result.insert(result.end(), a, aEnd); }
            if (keepOnlyOther) { result.insert(result.end(), b, bEnd); }

            return result;
        }

        std::vector<ElementType> m_data;
    };
} // namespace librapid

/// fmt formatter for librapid::Set.
///
/// The format spec (e.g. the `.2f` in `{:.2f}`) is handed to the element type's formatter,
/// and that formatter is then used for every element when the set is written out by
/// Set::str.
template<typename ElementType, typename Char>
struct fmt::formatter<librapid::Set<ElementType>, Char> {
private:
    using SetType          = librapid::Set<ElementType>;
    using ElementFormatter = fmt::formatter<ElementType, Char>;

    // Holds the parsed element format spec between parse() and format().
    ElementFormatter m_elementFormatter;

public:
    template<typename ParseContext>
    FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const Char * {
        // Spec parsing is delegated entirely to the element formatter.
        return m_elementFormatter.parse(ctx);
    }

    template<typename FormatContext>
    FMT_CONSTEXPR auto format(const SetType &set, FormatContext &ctx) const -> decltype(ctx.out()) {
        set.str(m_elementFormatter, ctx);
        return ctx.out();
    }
};

// NOTE(review): presumably this stops fmt from treating Set as a generic range (Set exposes
// begin()/end()). That would make the dedicated fmt::formatter specialization above the one
// that gets used. Confirm against the macro's definition.
LIBRAPID_SIMPLE_IO_NORANGE(typename ElementType, librapid::Set<ElementType>)

// std ostream support
template<typename ElementType>
std::ostream &operator<<(std::ostream &os, const librapid::Set<ElementType> &set) {
    return os << fmt::format("{}", set);
}

#endif // LIBRAPID_SET_HPP