Program Listing for File bitset.hpp

Return to documentation for file (librapid/include/librapid/datastructures/bitset.hpp)

#ifndef LIBRAPID_BITSET_HPP
#define LIBRAPID_BITSET_HPP

namespace librapid {
    /// Fixed-size set of `numBits_` bits packed into 64-bit words.
    ///
    /// When `stackAlloc_` is true the words live in a `std::array` inside the object; otherwise
    /// the object owns a heap buffer of `numElements` words. Storage bits at positions >= numBits
    /// in the last word are kept zero by every operation that could spill into them, so word-wise
    /// queries (any, all, none, popCount, ==) only see the logical bits.
    template<uint64_t numBits_ = 64, bool stackAlloc_ = true>
    class BitSet {
    public:
        /// Result type of a binary operation with BitSet<otherBits, otherStackAlloc>: as wide as
        /// the wider operand, and stack allocated only when both operands are.
        template<uint64_t otherBits, bool otherStackAlloc>
#ifndef LIBRAPID_DOXYGEN
        using BitSetMerger =
          BitSet<(numBits_ > otherBits ? numBits_ : otherBits), stackAlloc_ && otherStackAlloc>;
#else
        using BitSetMerger = BitSet;
#endif

        using ElementType                        = uint64_t;
        static constexpr bool stackAlloc         = stackAlloc_;
        static constexpr uint64_t bitsPerElement = sizeof(ElementType) * 8;
        static constexpr uint64_t numBits        = numBits_;
        static constexpr uint64_t numElements    = (numBits + bitsPerElement - 1) / bitsPerElement;
        using StorageType =
          std::conditional_t<stackAlloc, std::array<ElementType, numElements>, ElementType *>;

        /// Construct a BitSet with every bit cleared.
        BitSet() { emptyInit(); }

        /// Deep-copy another BitSet of the same size. A heap-allocated BitSet gets its own buffer.
        BitSet(const BitSet &other) {
            init();
            if constexpr (stackAlloc) {
                m_data = other.data();
            } else {
                for (uint64_t i = 0; i < numElements; ++i) { m_data[i] = other.data()[i]; }
            }
        }

        /// Move-construct. For heap storage the buffer is taken over and `other` is left holding
        /// nullptr, so a moved-from heap BitSet may only be destroyed or assigned to.
        BitSet(BitSet &&other) noexcept {
            if constexpr (stackAlloc) {
                m_data = other.m_data;
            } else {
                m_data       = other.m_data;
                other.m_data = nullptr; // prevent a second delete[] of the same buffer
            }
        }

        /// Construct from an integer: bit i of `value` becomes bit i of the set. Bits of `value`
        /// at or above numBits are discarded.
        constexpr BitSet(uint64_t value) {
            static_assert(numElements > 0, "Not enough bits in BitSet");
            emptyInit();
            m_data[0] = value;
            clearUnusedBits();
        }

        /// Construct from a string of `zero`/`one` characters, most significant bit first (the
        /// last character is bit 0). An invalid character, or a `one` at a position that does
        /// not fit in numBits, is reported through LIBRAPID_ERROR; if that returns, the
        /// character is skipped.
        constexpr BitSet(const std::string &str, char zero = '0', char one = '1') {
            emptyInit();
            const uint64_t stringLength = str.length();
            for (uint64_t i = 0; i < stringLength; ++i) {
                const char c            = str[i];
                const uint64_t bitIndex = stringLength - i - 1;
                if (c == zero) continue;
                if (c != one) {
                    LIBRAPID_ERROR("Invalid character in BitSet string: {}", c);
                    continue;
                }
                if (bitIndex >= numBits) {
                    // Setting this bit would index past the end of the storage
                    LIBRAPID_ERROR("BitSet string sets bit {} but the BitSet only holds {} bits",
                                   bitIndex,
                                   numBits);
                    continue;
                }
                set(bitIndex, true);
            }
        }

        /// Copy-assign, deep-copying the bits. A heap-allocated BitSet keeps its own buffer.
        BitSet &operator=(const BitSet &other) {
            if (this == &other) return *this;
            if constexpr (stackAlloc) {
                m_data = other.m_data;
            } else {
                if (m_data == nullptr) init(); // this object may have been moved from
                for (uint64_t i = 0; i < numElements; ++i) { m_data[i] = other.m_data[i]; }
            }
            return *this;
        }

        /// Move-assign. For heap storage the buffers are swapped, so `other` releases this
        /// object's previous buffer when it is destroyed.
        BitSet &operator=(BitSet &&other) noexcept {
            if constexpr (stackAlloc) {
                m_data = other.m_data;
            } else {
                std::swap(m_data, other.m_data);
            }
            return *this;
        }

        ~BitSet() {
            if constexpr (!stackAlloc) delete[] m_data;
        }

        /// Set bit `index` to `value`. `index` must be < numBits (not bounds-checked).
        BitSet &set(uint64_t index, bool value) {
            uint64_t element = index / bitsPerElement;
            uint64_t bit     = index % bitsPerElement;
            if (value)
                m_data[element] |= (1ULL << bit);
            else
                m_data[element] &= ~(1ULL << bit);
            return *this;
        }

        /// Set every bit in the half-open range [min(start, end), max(start, end)) to `value`.
        /// Both bounds must be <= numBits (not bounds-checked).
        BitSet &set(uint64_t start, uint64_t end, bool value) {
            if (start > end) std::swap(start, end);
            if (start == end) return *this;

            const uint64_t firstElement = start / bitsPerElement;
            const uint64_t lastElement  = (end - 1) / bitsPerElement;

            for (uint64_t e = firstElement; e <= lastElement; ++e) {
                // Inclusive bit range [lo, hi] touched within element e
                const uint64_t lo = (e == firstElement) ? start % bitsPerElement : 0;
                const uint64_t hi =
                  (e == lastElement) ? (end - 1) % bitsPerElement : bitsPerElement - 1;
                const ElementType mask = (~ElementType(0) >> (bitsPerElement - 1 - hi)) &
                                         (~ElementType(0) << lo);
                if (value)
                    m_data[e] |= mask;
                else
                    m_data[e] &= ~mask;
            }

            return *this;
        }

        /// Read bit `index`. `index` must be < numBits (not bounds-checked).
        LIBRAPID_NODISCARD bool get(uint64_t index) const {
            uint64_t element = index / bitsPerElement;
            uint64_t bit     = index % bitsPerElement;
            return m_data[element] & (1ULL << bit);
        }

        /// True if at least one bit is set.
        LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool any() const {
            for (uint64_t i = 0; i < numElements; ++i)
                if (m_data[i]) return true;
            return false;
        }

        /// True if all numBits bits are set (vacuously true when numBits == 0).
        LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool all() const {
            if constexpr (numElements == 0) {
                return true;
            } else {
                for (uint64_t i = 0; i + 1 < numElements; ++i)
                    if (m_data[i] != ~ElementType(0)) return false;
                // The last element only holds the logical bits selected by highMask()
                return m_data[numElements - 1] == highMask();
            }
        }

        /// True if no bit is set.
        LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool none() const {
            for (uint64_t i = 0; i < numElements; ++i)
                if (m_data[i]) return false;
            return true;
        }

        /// Index of the lowest set bit, or numBits if no bit is set.
        LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE uint64_t first() const {
            for (uint64_t i = 0; i < numElements; ++i) {
                if (m_data[i] != 0) {
                    return i * bitsPerElement + static_cast<uint64_t>(std::countr_zero(m_data[i]));
                }
            }
            return numBits;
        }

        /// Index of the highest set bit, or numBits if no bit is set.
        LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE uint64_t last() const {
            // Count down without underflowing the unsigned index: visits numElements-1 ... 0
            for (uint64_t i = numElements; i-- > 0;) {
                if (m_data[i] != 0) {
                    return i * bitsPerElement + (bitsPerElement - 1) -
                           static_cast<uint64_t>(std::countl_zero(m_data[i]));
                }
            }
            return numBits;
        }

        /// Number of set bits.
        LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE uint64_t popCount() const {
            uint64_t res = 0;
            for (uint64_t i = 0; i < numElements; ++i) { res += std::popcount(m_data[i]); }
            return res;
        }

        /// Bitwise OR `other` into this set. Bits of `other` beyond this set's numBits are ignored.
        template<uint64_t otherBits = numBits, bool otherStackAlloc = stackAlloc>
        BitSet &operator|=(const BitSet<otherBits, otherStackAlloc> &other) {
            constexpr uint64_t otherElements = BitSet<otherBits, otherStackAlloc>::numElements;
            constexpr uint64_t common = numElements < otherElements ? numElements : otherElements;

            for (uint64_t i = 0; i < common; ++i) { m_data[i] |= other.data()[i]; }

            // A wider `other` may contribute bits above numBits in the shared top element
            clearUnusedBits();
            return *this;
        }

        /// Bitwise AND `other` into this set. Bits this set has beyond `other`'s storage are
        /// cleared.
        template<uint64_t otherBits = numBits, bool otherStackAlloc = stackAlloc>
        BitSet &operator&=(const BitSet<otherBits, otherStackAlloc> &other) {
            constexpr uint64_t otherElements = BitSet<otherBits, otherStackAlloc>::numElements;
            constexpr uint64_t common = numElements < otherElements ? numElements : otherElements;

            uint64_t index = 0;
            for (; index < common; ++index) { m_data[index] &= other.data()[index]; }

            // Elements `other` does not have are ANDed with implicit zeros
            for (; index < numElements; ++index) { m_data[index] = 0; }

            return *this;
        }

        /// Bitwise XOR `other` into this set. Bits of `other` beyond this set's numBits are
        /// ignored.
        template<uint64_t otherBits = numBits, bool otherStackAlloc = stackAlloc>
        BitSet &operator^=(const BitSet<otherBits, otherStackAlloc> &other) {
            constexpr uint64_t otherElements = BitSet<otherBits, otherStackAlloc>::numElements;
            constexpr uint64_t common = numElements < otherElements ? numElements : otherElements;

            for (uint64_t i = 0; i < common; ++i) { m_data[i] ^= other.data()[i]; }

            clearUnusedBits();
            return *this;
        }

        /// Shift towards higher indices by `shift` bits (negative shifts go the other way). Bits
        /// moved past numBits are discarded; vacated low bits become zero.
        BitSet &operator<<=(int64_t shift) {
            // Negate through uint64_t so INT64_MIN does not overflow
            if (shift < 0) return shiftRightBy(uint64_t(0) - static_cast<uint64_t>(shift));
            return shiftLeftBy(static_cast<uint64_t>(shift));
        }

        /// Shift towards lower indices by `shift` bits (negative shifts go the other way). Bits
        /// moved below 0 are discarded; vacated high bits become zero.
        BitSet &operator>>=(int64_t shift) {
            if (shift < 0) return shiftLeftBy(uint64_t(0) - static_cast<uint64_t>(shift));
            return shiftRightBy(static_cast<uint64_t>(shift));
        }

        /// Bitwise OR. The result is as wide as the wider operand.
        template<uint64_t otherBits = numBits, bool otherStackAlloc = stackAlloc>
        auto operator|(const BitSet<otherBits, otherStackAlloc> &other) const
          -> BitSetMerger<otherBits, otherStackAlloc> {
            // Start from zero and OR *this in: works whether or not the merger type equals BitSet
            BitSetMerger<otherBits, otherStackAlloc> res;
            res |= *this;
            res |= other;
            return res;
        }

        /// Bitwise AND. The result is as wide as the wider operand.
        template<uint64_t otherBits = numBits, bool otherStackAlloc = stackAlloc>
        auto operator&(const BitSet<otherBits, otherStackAlloc> &other) const
          -> BitSetMerger<otherBits, otherStackAlloc> {
            BitSetMerger<otherBits, otherStackAlloc> res;
            res |= *this;
            res &= other;
            return res;
        }

        /// Bitwise XOR. The result is as wide as the wider operand.
        template<uint64_t otherBits = numBits, bool otherStackAlloc = stackAlloc>
        auto operator^(const BitSet<otherBits, otherStackAlloc> &other) const
          -> BitSetMerger<otherBits, otherStackAlloc> {
            BitSetMerger<otherBits, otherStackAlloc> res;
            res |= *this;
            res ^= other;
            return res;
        }

        /// Copy shifted towards higher indices; see operator<<=.
        BitSet operator<<(int64_t shift) const {
            BitSet res = *this;
            res <<= shift;
            return res;
        }

        /// Copy shifted towards lower indices; see operator>>=.
        BitSet operator>>(int64_t shift) const {
            BitSet res = *this;
            res >>= shift;
            return res;
        }

        /// Bitwise complement of the numBits logical bits.
        BitSet operator~() const {
            BitSet res = *this;

            for (uint64_t i = 0; i < numElements; ++i) { res.m_data[i] = ~res.m_data[i]; }
            res.clearUnusedBits();

            return res;
        }

        /// Element-wise equality of two same-sized BitSets.
        bool operator==(const BitSet &other) const {
            for (uint64_t i = 0; i < numElements; ++i) {
                if (m_data[i] != other.data()[i]) return false;
            }
            return true;
        }

        /// Raw storage: a std::array of words (stack) or a pointer to the word buffer (heap).
        const auto &data() const { return m_data; }
        auto &data() { return m_data; }

        /// The low word of the set.
        ///
        /// NOTE(review): the declared return type is `int`, so the value is narrowed to int
        /// whatever `Integer` is; kept as-is for interface compatibility.
        template<typename Integer = ElementType>
            requires(std::is_integral_v<Integer>)
        LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int toInt() const {
#if defined(LIBRAPID_DEBUG)
            static bool warned = false;
            // last() returns numBits for an empty set, so only check when some bit is set
            if (!warned && any() && last() >= sizeof(Integer) * 8) {
                warned = true;
                LIBRAPID_WARN(
                  "BitSet::toInt() called with BitSet that is too large for the "
                  "requested integer type");
            }
#endif

            if constexpr (numElements == 0) {
                return 0;
            } else {
                return m_data[0];
            }
        }

        /// Write the bits from most to least significant into `ctx`, one `format` call per bit.
        template<typename T, typename Char, typename Ctx>
        void str(const fmt::formatter<T, Char> &format, Ctx &ctx) const {
            for (int64_t i = numBits - 1; i >= 0; --i) { format.format(get(i), ctx); }
        }

    protected:
        /// Mask of the logical bits in the last storage element. It is all ones when numBits is a
        /// multiple of bitsPerElement, so that element is full.
        constexpr uint64_t highMask() const {
            constexpr uint64_t usedBits = numBits % bitsPerElement;
            if constexpr (usedBits == 0) {
                return ~ElementType(0);
            } else {
                return (ElementType(1) << usedBits) - 1;
            }
        }

        /// Clear every storage element.
        void zero() {
            for (uint64_t i = 0; i < numElements; ++i) { m_data[i] = 0; }
        }

        /// Allocate the heap buffer (no-op for stack storage). Contents are left uninitialised.
        void init() {
            if constexpr (!stackAlloc) { m_data = new ElementType[numElements]; }
        }

        /// Allocate (if needed) and zero the storage.
        void emptyInit() {
            init();
            zero();
        }

    private:
        /// Zero the storage bits above numBits in the last element.
        constexpr void clearUnusedBits() {
            if constexpr (numElements > 0) { m_data[numElements - 1] &= highMask(); }
        }

        /// Shift every bit `count` places towards higher indices, discarding bits past numBits.
        BitSet &shiftLeftBy(uint64_t count) {
            if (count == 0) return *this;
            if (count >= numBits) {
                zero();
                return *this;
            }

            // count < numBits, so elementShift < numElements
            const uint64_t elementShift = count / bitsPerElement;
            const uint64_t bitShift     = count % bitsPerElement;

            // Walk downwards so each source element is read before it is overwritten
            for (uint64_t i = numElements; i-- > elementShift;) {
                const uint64_t src = i - elementShift;
                ElementType value  = m_data[src] << bitShift;
                // Shifting a 64-bit word by 64 is undefined, so only carry when bitShift != 0
                if (bitShift != 0 && src > 0) {
                    value |= m_data[src - 1] >> (bitsPerElement - bitShift);
                }
                m_data[i] = value;
            }

            for (uint64_t i = 0; i < elementShift; ++i) { m_data[i] = 0; }

            clearUnusedBits();
            return *this;
        }

        /// Shift every bit `count` places towards lower indices, discarding bits shifted below 0.
        BitSet &shiftRightBy(uint64_t count) {
            if (count == 0) return *this;
            if (count >= numBits) {
                zero();
                return *this;
            }

            // count < numBits, so elementShift < numElements
            const uint64_t elementShift = count / bitsPerElement;
            const uint64_t bitShift     = count % bitsPerElement;
            const uint64_t keptElements = numElements - elementShift;

            // Walk upwards so each source element is read before it is overwritten
            for (uint64_t i = 0; i < keptElements; ++i) {
                const uint64_t src = i + elementShift;
                ElementType value  = m_data[src] >> bitShift;
                if (bitShift != 0 && src + 1 < numElements) {
                    value |= m_data[src + 1] << (bitsPerElement - bitShift);
                }
                m_data[i] = value;
            }

            for (uint64_t i = keptElements; i < numElements; ++i) { m_data[i] = 0; }

            return *this;
        }

        StorageType m_data;
    };

    /// Count the set bits in an integer.
    ///
    /// Unsigned values are passed straight to std::popcount. Signed integers are first converted
    /// to the unsigned type of the same width, so e.g. popCount(int8_t(-1)) is 8.
    template<typename T>
    uint64_t popCount(const T &value) {
        if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
            return static_cast<uint64_t>(
              std::popcount(static_cast<std::make_unsigned_t<T>>(value)));
        } else {
            return static_cast<uint64_t>(std::popcount(value));
        }
    }

    /// Count the set bits in a BitSet. Delegates to BitSet::popCount() so both spellings share
    /// a single implementation.
    template<uint64_t numBits, bool stackAlloc>
    uint64_t popCount(const BitSet<numBits, stackAlloc> &bitset) {
        return bitset.popCount();
    }
} // namespace librapid

/// fmt support for librapid::BitSet. The bits are written from most to least significant, and
/// each bit is formatted as an int using the parsed format spec.
template<uint64_t numBits, bool stackAlloc, typename Char>
struct fmt::formatter<librapid::BitSet<numBits, stackAlloc>, Char> {
private:
    using BitSetType   = librapid::BitSet<numBits, stackAlloc>;
    using IntFormatter = fmt::formatter<int, Char>;

    // Every bit goes through this formatter, so the user's spec applies to each bit individually
    IntFormatter m_bitFormatter;

public:
    template<typename ParseContext>
    FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const Char * {
        auto parsedUpTo = m_bitFormatter.parse(ctx);
        return parsedUpTo;
    }

    template<typename FormatContext>
    FMT_CONSTEXPR auto format(const BitSetType &val, FormatContext &ctx) const
      -> decltype(ctx.out()) {
        // BitSet::str drives m_bitFormatter once per bit, writing into ctx
        val.str(m_bitFormatter, ctx);
        return ctx.out();
    }
};

// NOTE(review): LIBRAPID_SIMPLE_IO_NORANGE is defined elsewhere in the project. From its name it
// presumably registers simple I/O support for BitSet; confirm against the macro definition.
// `COMMA` is presumably a macro expanding to ',' so the template parameter/argument lists are not
// split into separate macro arguments.
LIBRAPID_SIMPLE_IO_NORANGE(uint64_t numBits COMMA bool stackAlloc,
                           librapid::BitSet<numBits COMMA stackAlloc>)

// std::ostream support: writes exactly the text produced by fmt::format("{}", bitset).
template<uint64_t numBits, bool stackAlloc>
std::ostream &operator<<(std::ostream &os, const librapid::BitSet<numBits, stackAlloc> &bitset) {
    const auto text = fmt::format("{}", bitset);
    os << text;
    return os;
}

#endif // LIBRAPID_BITSET_HPP