Program Listing for File set.hpp
↰ Return to documentation for file (librapid/include/librapid/datastructures/set.hpp)
#ifndef LIBRAPID_SET_HPP
#define LIBRAPID_SET_HPP
namespace librapid {
/// \brief A sorted, duplicate-free collection of values.
///
/// Elements are kept in ascending order (by `operator<`) inside a contiguous
/// `std::vector`. Membership tests, insertion-point lookup and removal use binary
/// search; the set-algebra operators (`|`, `&`, `^`, `-`) are linear merges of the
/// two sorted sequences.
///
/// \tparam ElementType_ Stored type. Must be copyable and provide `operator<`;
///                      duplicate removal in the bulk constructors uses `operator==`.
template<typename ElementType_>
class Set {
public:
    using ElementType         = ElementType_;
    using VectorType          = std::vector<ElementType>;
    using VectorIterator      = typename VectorType::iterator;
    using VectorConstIterator = typename VectorType::const_iterator;

    Set()                 = default;
    Set(const Set &other) = default;
    Set(Set &&other)      = default;

    /// \brief Build a set from every element of a librapid array.
    ///
    /// Elements are read in storage order, then sorted and de-duplicated.
    template<typename ShapeType, typename StorageType>
    Set(const array::ArrayContainer<ShapeType, StorageType> &arr) {
        reserve(arr.size());
        for (size_t i = 0; i < arr.size(); ++i) { pushBack(arr.storage()[i]); }
        sort();
        prune();
    }

    /// \brief Build a set from the values of a vector (sorted, duplicates removed).
    Set(const std::vector<ElementType> &data) : m_data(data) {
        sort();
        prune();
    }

    /// \brief Build a set from a braced list of values (sorted, duplicates removed).
    Set(const std::initializer_list<ElementType> &data) : m_data(data) {
        sort();
        prune();
    }

    Set &operator=(const Set &other) = default;
    Set &operator=(Set &&other)      = default;

    /// \brief Number of (unique) elements in the set.
    LIBRAPID_NODISCARD int64_t size() const { return static_cast<int64_t>(m_data.size()); }

    /// \brief Element at position `index` in ascending order.
    /// \param index Zero-based position; asserted to lie in `[0, size())`.
    LIBRAPID_NODISCARD const ElementType &operator[](int64_t index) const {
        LIBRAPID_ASSERT(index >= 0 && index < size(),
                        "Index out of bounds: {} (size: {})",
                        index,
                        m_data.size());
        return m_data[static_cast<size_t>(index)];
    }

    /// \brief True if an element equivalent to `val` is in the set (binary search).
    LIBRAPID_NODISCARD bool contains(const ElementType &val) const {
        return std::binary_search(m_data.begin(), m_data.end(), val);
    }

    /// \brief Insert `val` at its sorted position; no-op if it is already present.
    /// \return `*this`, for chaining.
    Set &insert(const ElementType &val) {
        // lower_bound yields the first element not less than `val`: either an
        // equivalent element (already present) or exactly where `val` belongs.
        auto it = std::lower_bound(m_data.begin(), m_data.end(), val);
        if (it == m_data.end() || val < *it) { m_data.insert(it, val); }
        return *this;
    }

    /// \brief Insert every value of `data` (values already present are skipped).
    Set &insert(const std::vector<ElementType> &data) {
        for (const auto &val : data) { insert(val); }
        return *this;
    }

    /// \brief Insert every value of a braced list (values already present are skipped).
    Set &insert(const std::initializer_list<ElementType> &data) {
        for (const auto &val : data) { insert(val); }
        return *this;
    }

    Set &operator+=(const ElementType &val) { return insert(val); }
    Set &operator+=(const std::vector<ElementType> &data) { return insert(data); }
    Set &operator+=(const std::initializer_list<ElementType> &data) { return insert(data); }

    /// \brief Copy of this set with `val` added.
    LIBRAPID_NODISCARD Set operator+(const ElementType &val) const {
        Set result = *this;
        result += val;
        return result;
    }

    /// \brief Copy of this set with every value of `data` added.
    LIBRAPID_NODISCARD Set operator+(const std::vector<ElementType> &data) const {
        Set result = *this;
        result += data;
        return result;
    }

    /// \brief Copy of this set with every value of a braced list added.
    LIBRAPID_NODISCARD Set operator+(const std::initializer_list<ElementType> &data) const {
        Set result = *this;
        result += data;
        return result;
    }

    /// \brief Remove `val` if present; silently does nothing otherwise.
    /// \return `*this`, for chaining.
    Set &discard(const ElementType &val) {
        auto it = std::lower_bound(m_data.begin(), m_data.end(), val);
        if (it != m_data.end() && !(val < *it)) { m_data.erase(it); }
        return *this;
    }

    /// \brief Discard every value of `data` (missing values are ignored).
    Set &discard(const std::vector<ElementType> &data) {
        for (const auto &val : data) { discard(val); }
        return *this;
    }

    /// \brief Discard every value of a braced list (missing values are ignored).
    Set &discard(const std::initializer_list<ElementType> &data) {
        for (const auto &val : data) { discard(val); }
        return *this;
    }

    /// \brief Remove `val`, asserting that it is present.
    Set &remove(const ElementType &val) {
        LIBRAPID_ASSERT(contains(val), "Set does not contain value: {}", val);
        return discard(val);
    }

    /// \brief Remove every value of `data`, asserting each one is present.
    Set &remove(const std::vector<ElementType> &data) {
        for (const auto &val : data) { remove(val); }
        return *this;
    }

    /// \brief Remove every value of a braced list, asserting each one is present.
    Set &remove(const std::initializer_list<ElementType> &data) {
        for (const auto &val : data) { remove(val); }
        return *this;
    }

    Set &operator-=(const ElementType &val) { return discard(val); }
    Set &operator-=(const std::vector<ElementType> &data) { return discard(data); }
    Set &operator-=(const std::initializer_list<ElementType> &data) { return discard(data); }

    /// \brief Copy of this set with `val` discarded.
    LIBRAPID_NODISCARD Set operator-(const ElementType &val) const {
        Set result = *this;
        result -= val;
        return result;
    }

    /// \brief Copy of this set with every value of `data` discarded.
    LIBRAPID_NODISCARD Set operator-(const std::vector<ElementType> &data) const {
        Set result = *this;
        result -= data;
        return result;
    }

    /// \brief Copy of this set with every value of a braced list discarded.
    LIBRAPID_NODISCARD Set operator-(const std::initializer_list<ElementType> &data) const {
        Set result = *this;
        result -= data;
        return result;
    }

    /// \brief Union: elements present in either set.
    LIBRAPID_NODISCARD Set operator|(const Set &other) const {
        Set result;
        result.reserve(m_data.size() + other.m_data.size()); // upper bound on the size
        auto a          = m_data.begin();
        auto b          = other.m_data.begin();
        const auto aEnd = m_data.end();
        const auto bEnd = other.m_data.end();
        while (a != aEnd && b != bEnd) {
            if (*a < *b) {
                result.pushBack(*a);
                ++a;
            } else if (*b < *a) {
                result.pushBack(*b);
                ++b;
            } else {
                // Present in both: keep a single copy.
                result.pushBack(*a);
                ++a;
                ++b;
            }
        }
        // At most one of these ranges is non-empty; it is already sorted.
        result.insert(result.end(), a, aEnd);
        result.insert(result.end(), b, bEnd);
        return result;
    }

    /// \brief Intersection: elements present in both sets.
    LIBRAPID_NODISCARD Set operator&(const Set &other) const {
        Set result;
        result.reserve(std::min(m_data.size(), other.m_data.size())); // upper bound
        auto a          = m_data.begin();
        auto b          = other.m_data.begin();
        const auto aEnd = m_data.end();
        const auto bEnd = other.m_data.end();
        while (a != aEnd && b != bEnd) {
            if (*a < *b) {
                ++a;
            } else if (*b < *a) {
                ++b;
            } else {
                result.pushBack(*a);
                ++a;
                ++b;
            }
        }
        return result;
    }

    /// \brief Symmetric difference: elements present in exactly one of the two sets.
    LIBRAPID_NODISCARD Set operator^(const Set &other) const {
        Set result;
        result.reserve(m_data.size() + other.m_data.size()); // upper bound
        auto a          = m_data.begin();
        auto b          = other.m_data.begin();
        const auto aEnd = m_data.end();
        const auto bEnd = other.m_data.end();
        while (a != aEnd && b != bEnd) {
            if (*a < *b) {
                result.pushBack(*a);
                ++a;
            } else if (*b < *a) {
                result.pushBack(*b);
                ++b;
            } else {
                // Present in both: excluded from the symmetric difference.
                ++a;
                ++b;
            }
        }
        // Whatever is left in either input has no counterpart in the other, so it
        // belongs to the result (previously these tails were dropped).
        result.insert(result.end(), a, aEnd);
        result.insert(result.end(), b, bEnd);
        return result;
    }

    /// \brief Difference: elements of this set that are not in `other`.
    LIBRAPID_NODISCARD Set operator-(const Set &other) const {
        Set result;
        result.reserve(m_data.size()); // the result can never exceed this set
        auto a          = m_data.begin();
        auto b          = other.m_data.begin();
        const auto aEnd = m_data.end();
        const auto bEnd = other.m_data.end();
        while (a != aEnd && b != bEnd) {
            if (*a < *b) {
                result.pushBack(*a);
                ++a;
            } else if (*b < *a) {
                ++b;
            } else {
                ++a;
                ++b;
            }
        }
        // Remaining elements of this set are larger than everything in `other`.
        result.insert(result.end(), a, aEnd);
        return result;
    }

    LIBRAPID_NODISCARD auto operator<=>(const Set &other) const = default;

    /// \brief Read-only iterators over the elements in ascending order.
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto begin() const { return m_data.begin(); }
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto end() const { return m_data.end(); }

    /// \brief Write the set as "(e0, e1, ...)" to `ctx`, formatting each element with
    /// `format`.
    template<typename T, typename Char, typename Ctx>
    void str(const fmt::formatter<T, Char> &format, Ctx &ctx) const {
        fmt::format_to(ctx.out(), "(");
        bool first = true;
        for (const auto &elem : m_data) {
            if (!first) { fmt::format_to(ctx.out(), ", "); }
            format.format(elem, ctx);
            first = false;
        }
        fmt::format_to(ctx.out(), ")");
    }

protected:
    /// \brief Reserve capacity for at least `elements` items.
    void reserve(size_t elements) { m_data.reserve(elements); }

    /// \brief Sort the underlying storage in ascending order.
    void sort() { std::sort(m_data.begin(), m_data.end()); }

    /// \brief Remove adjacent duplicates; after sort() this leaves only unique values.
    void prune() { m_data.erase(std::unique(m_data.begin(), m_data.end()), m_data.end()); }

    /// \brief Append without checking order or uniqueness; callers must keep the
    /// invariant (or call sort()/prune() afterwards).
    LIBRAPID_ALWAYS_INLINE void pushBack(const ElementType &val) { m_data.emplace_back(val); }

    /// \brief Raw range insertion into the underlying vector; callers must keep the
    /// sorted/unique invariant.
    LIBRAPID_ALWAYS_INLINE void insert(VectorConstIterator insertLocation,
                                       VectorConstIterator begin, VectorConstIterator end) {
        m_data.insert(insertLocation, begin, end);
    }

private:
    std::vector<ElementType> m_data; // sorted ascending, no duplicates
};
} // namespace librapid
template<typename ElementType, typename Char>
struct fmt::formatter<librapid::Set<ElementType>, Char> {
private:
    using SetType          = librapid::Set<ElementType>;
    using ElementFormatter = fmt::formatter<ElementType, Char>;

    // Formatter for the individual elements. The format spec parsed for the whole
    // set is handed to it, and Set::str applies it to every element.
    ElementFormatter m_elementFormatter;

public:
    /// Delegate spec parsing entirely to the element formatter.
    template<typename ParseContext>
    FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const Char * {
        return m_elementFormatter.parse(ctx);
    }

    /// Emit the set via Set::str, then hand back the context's output iterator.
    template<typename FormatContext>
    FMT_CONSTEXPR auto format(const SetType &val, FormatContext &ctx) const
      -> decltype(ctx.out()) {
        val.str(m_elementFormatter, ctx);
        return ctx.out();
    }
};
// NOTE(review): LIBRAPID_SIMPLE_IO_NORANGE is defined elsewhere in librapid. Presumably
// it stops fmt from treating Set (which exposes begin()/end()) as a generic range, so
// that the dedicated fmt::formatter specialization is selected — confirm against the
// macro definition.
LIBRAPID_SIMPLE_IO_NORANGE(typename ElementType, librapid::Set<ElementType>)
// std ostream support
template<typename ElementType>
std::ostream &operator<<(std::ostream &os, const librapid::Set<ElementType> &set) {
return os << fmt::format("{}", set);
}
#endif // LIBRAPID_SET_HPP