Program Listing for File shape.hpp
↰ Return to documentation for file (librapid/include/librapid/array/shape.hpp)
#ifndef LIBRAPID_ARRAY_SIZETYPE_HPP
#define LIBRAPID_ARRAY_SIZETYPE_HPP
/*
* This file defines the Shape, MatrixShape and VectorShape classes and some helper
* functions for comparing shapes and deducing a common shape type.
*/
namespace librapid {
namespace typetraits {
// Register the three shape classes with LibRapid's type-trait machinery.
// NOTE(review): LIBRAPID_DEFINE_AS_TYPE_NO_TEMPLATE is defined elsewhere and is invoked
// here before Shape/MatrixShape/VectorShape are declared in this file, so it presumably
// forward-declares or only names the type in a template context — confirm against the
// macro definition.
LIBRAPID_DEFINE_AS_TYPE_NO_TEMPLATE(Shape);
LIBRAPID_DEFINE_AS_TYPE_NO_TEMPLATE(MatrixShape);
LIBRAPID_DEFINE_AS_TYPE_NO_TEMPLATE(VectorShape);
} // namespace typetraits
/// A dynamically-ranked array shape holding up to MaxDimensions extents.
///
/// Only the first ndim() entries of the internal storage are meaningful; operations
/// never read or write beyond that range.
class Shape {
public:
    using SizeType = uint32_t;
    static constexpr size_t MaxDimensions = LIBRAPID_MAX_ARRAY_DIMS;

    /// Construct a zero-dimensional (empty) shape.
    LIBRAPID_ALWAYS_INLINE Shape() = default;

    /// Construct a shape from the compile-time extents of a FixedStorage object.
    /// Mainly used internally.
    template<typename Scalar, size_t... Dimensions>
    explicit LIBRAPID_ALWAYS_INLINE Shape(const FixedStorage<Scalar, Dimensions...> &fixed);

    /// Construct a shape whose extents are the given values, in order.
    template<typename V>
    LIBRAPID_ALWAYS_INLINE Shape(const std::initializer_list<V> &vals);

    /// Construct a shape whose extents are the given values, in order.
    template<typename V>
    explicit LIBRAPID_ALWAYS_INLINE Shape(const std::vector<V> &vals);

    LIBRAPID_ALWAYS_INLINE Shape(const Shape &other) = default;

    /// Convert a MatrixShape into a two-dimensional Shape.
    LIBRAPID_ALWAYS_INLINE Shape(const MatrixShape &other);

    /// Convert a VectorShape into a one-dimensional Shape.
    LIBRAPID_ALWAYS_INLINE Shape(const VectorShape &other);

    LIBRAPID_ALWAYS_INLINE Shape(Shape &&other) noexcept = default;

    // NOTE(review): `Dim` does not appear in the parameter list, so it can never be deduced
    // and this overload appears unusable — confirm before removing.
    template<size_t Dim>
    LIBRAPID_ALWAYS_INLINE Shape(Shape &&other) noexcept;

    /// Replace the extents with the given values.
    template<typename V>
    LIBRAPID_ALWAYS_INLINE auto operator=(const std::initializer_list<V> &vals) -> Shape &;

    /// Replace the extents with the given values.
    template<typename V>
    LIBRAPID_ALWAYS_INLINE auto operator=(const std::vector<V> &vals) -> Shape &;

    LIBRAPID_ALWAYS_INLINE auto operator=(Shape &&other) noexcept -> Shape & = default;
    LIBRAPID_ALWAYS_INLINE auto operator=(const Shape &other) -> Shape & = default;

    /// A shape with `dims` dimensions, every extent 0.
    LIBRAPID_ALWAYS_INLINE static auto zeros(int dims) -> Shape;

    /// A shape with `dims` dimensions, every extent 1.
    LIBRAPID_ALWAYS_INLINE static auto ones(int dims) -> Shape;

    /// Extent of dimension `index`. Throws std::out_of_range when index >= ndim().
    template<typename Index>
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](Index index) const
      -> const SizeType &;

    /// Mutable extent of dimension `index`.
    template<typename Index>
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](Index index) -> SizeType &;

    /// Equal when both rank and all extents match.
    LIBRAPID_ALWAYS_INLINE auto operator==(const Shape &other) const -> bool;
    LIBRAPID_ALWAYS_INLINE auto operator!=(const Shape &other) const -> bool;

    /// Number of dimensions.
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto ndim() const -> int;

    /// Dimensions [start, end) as a new Shape.
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto subshape(int start, int end) const -> Shape;

    /// Product of all extents (1 for a zero-dimensional shape).
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto size() const -> size_t;

    /// Write "Shape(a, b, ...)" to the fmt context, formatting each extent with `format`.
    template<typename T_, typename Char, typename Ctx>
    void str(const fmt::formatter<T_, Char> &format, Ctx &ctx) const;

protected:
    /// Number of meaningful entries in m_data.
    /// It is initialised to 0 so a default-constructed Shape is a well-defined
    /// zero-dimensional shape. Otherwise ndim(), size(), == and str() would read an
    /// indeterminate value.
    int m_dims = 0;

    /// Extent of each dimension; only indices [0, m_dims) are meaningful.
    std::array<SizeType, MaxDimensions> m_data;
};
/// A shape with exactly two dimensions: rows x cols.
class MatrixShape {
public:
    using SizeType = uint32_t;
    static constexpr size_t MaxDimensions = 2;

    /// Construct a 0x0 matrix shape.
    LIBRAPID_ALWAYS_INLINE MatrixShape() = default;

    /// Construct from the compile-time extents of a two-dimensional FixedStorage.
    template<typename Scalar, size_t Rows, size_t Cols>
    LIBRAPID_ALWAYS_INLINE explicit MatrixShape(const FixedStorage<Scalar, Rows, Cols> &fixed);

    /// Construct from up to two values: {rows, cols}, {rows} (cols = 1) or {} (0x0).
    template<typename V>
    LIBRAPID_ALWAYS_INLINE MatrixShape(const std::initializer_list<V> &vals);

    /// Construct from up to two values, with the same rules as the initializer_list form.
    template<typename V>
    LIBRAPID_ALWAYS_INLINE explicit MatrixShape(const std::vector<V> &vals);

    /// Convert a Shape with at most two dimensions.
    LIBRAPID_ALWAYS_INLINE MatrixShape(const Shape &other);

    LIBRAPID_ALWAYS_INLINE MatrixShape(const MatrixShape &other) = default;
    LIBRAPID_ALWAYS_INLINE MatrixShape(MatrixShape &&other) noexcept = default;

    template<typename V>
    LIBRAPID_ALWAYS_INLINE auto operator=(const std::initializer_list<V> &vals)
      -> MatrixShape &;
    template<typename V>
    LIBRAPID_ALWAYS_INLINE auto operator=(const std::vector<V> &vals) -> MatrixShape &;
    LIBRAPID_ALWAYS_INLINE MatrixShape &operator=(const MatrixShape &other) = default;
    LIBRAPID_ALWAYS_INLINE MatrixShape &operator=(MatrixShape &&other) noexcept = default;

    /// 0x0 shape.
    static LIBRAPID_ALWAYS_INLINE auto zeros() -> MatrixShape;
    /// 1x1 shape.
    static LIBRAPID_ALWAYS_INLINE auto ones() -> MatrixShape;
    /// 0x0 shape; the argument is ignored.
    static LIBRAPID_ALWAYS_INLINE auto zeros(size_t) -> MatrixShape;
    /// 1x1 shape; the argument is ignored.
    static LIBRAPID_ALWAYS_INLINE auto ones(size_t) -> MatrixShape;

    /// Index 0 yields rows, index 1 yields cols; any other index throws.
    template<typename Index>
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](Index index) const
      -> const SizeType &;
    template<typename Index>
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](Index index) -> SizeType &;

    /// Member-wise comparison (rows, then cols).
    LIBRAPID_ALWAYS_INLINE auto operator<=>(const MatrixShape &other) const = default;

    /// Always 2.
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto ndim() const -> int;
    /// Dimensions [start, end) as a general Shape.
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto subshape(int start, int end) const -> Shape;
    /// rows * cols.
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto size() const -> size_t;

    /// Write "MatrixShape(rows, cols)" to the fmt context.
    template<typename T_, typename Char, typename Ctx>
    void str(const fmt::formatter<T_, Char> &format, Ctx &ctx) const;

private:
    // Both extents are zero-initialised so a default-constructed MatrixShape is 0x0.
    // Otherwise the defaulted operator<=>, which reads both members, would read
    // indeterminate values.
    SizeType m_rows = 0;
    SizeType m_cols = 0;
};
/// A shape with exactly one dimension: the number of elements.
class VectorShape {
public:
    using SizeType = uint32_t;
    static constexpr size_t MaxDimensions = 1;

    /// Construct an empty (0-element) vector shape.
    LIBRAPID_ALWAYS_INLINE VectorShape() = default;

    /// Construct from the compile-time extent of a one-dimensional FixedStorage.
    template<typename Scalar, size_t Elements>
    LIBRAPID_ALWAYS_INLINE explicit VectorShape(const FixedStorage<Scalar, Elements> &fixed);

    /// Construct from exactly one value: the element count.
    template<typename V>
    LIBRAPID_ALWAYS_INLINE VectorShape(const std::initializer_list<V> &vals);

    /// Construct from exactly one value: the element count.
    template<typename V>
    LIBRAPID_ALWAYS_INLINE explicit VectorShape(const std::vector<V> &vals);

    /// Convert a one-dimensional Shape.
    LIBRAPID_ALWAYS_INLINE VectorShape(const Shape &other);

    LIBRAPID_ALWAYS_INLINE VectorShape(const VectorShape &other) = default;
    LIBRAPID_ALWAYS_INLINE VectorShape(VectorShape &&other) noexcept = default;

    template<typename V>
    LIBRAPID_ALWAYS_INLINE auto operator=(const std::initializer_list<V> &vals)
      -> VectorShape &;
    template<typename V>
    LIBRAPID_ALWAYS_INLINE auto operator=(const std::vector<V> &vals) -> VectorShape &;
    LIBRAPID_ALWAYS_INLINE VectorShape &operator=(const VectorShape &other) = default;
    LIBRAPID_ALWAYS_INLINE VectorShape &operator=(VectorShape &&other) noexcept = default;

    /// 0-element shape.
    static LIBRAPID_ALWAYS_INLINE auto zeros() -> VectorShape;
    /// 1-element shape.
    static LIBRAPID_ALWAYS_INLINE auto ones() -> VectorShape;
    /// 0-element shape; the argument is ignored.
    static LIBRAPID_ALWAYS_INLINE auto zeros(size_t) -> VectorShape;
    /// 1-element shape; the argument is ignored.
    static LIBRAPID_ALWAYS_INLINE auto ones(size_t) -> VectorShape;

    /// Only index 0 is valid; it yields the element count.
    template<typename Index>
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](Index index) const
      -> const SizeType &;
    template<typename Index>
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](Index index) -> SizeType &;

    /// Compares element counts.
    LIBRAPID_ALWAYS_INLINE auto operator<=>(const VectorShape &other) const = default;

    /// Always 1.
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto ndim() const -> int;
    /// Dimensions [start, end) as a general Shape.
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto subshape(int start, int end) const -> Shape;
    /// The element count.
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto size() const -> size_t;

    /// Write "VectorShape(n)" to the fmt context.
    template<typename T_, typename Char, typename Ctx>
    void str(const fmt::formatter<T_, Char> &format, Ctx &ctx) const;

private:
    // Zero-initialised so a default-constructed VectorShape (and the defaulted
    // operator<=>) never reads an indeterminate value.
    SizeType m_elements = 0;
};
namespace detail {
/// Build a runtime Shape from the compile-time extents of a FixedStorage type.
/// Only the template arguments are used; the storage object itself is never read.
template<typename T, size_t... Dims>
Shape shapeFromFixedStorage(const FixedStorage<T, Dims...> &) {
    const std::initializer_list<size_t> extents = {Dims...};
    return Shape(extents);
}
} // namespace detail
/// Construct from the compile-time extents of a FixedStorage object.
template<typename Scalar, size_t... Dimensions>
LIBRAPID_ALWAYS_INLINE Shape::Shape(const FixedStorage<Scalar, Dimensions...> &) :
        m_dims(sizeof...(Dimensions)), m_data({Dimensions...}) {
    static_assert(sizeof...(Dimensions) <= MaxDimensions,
                  "FixedStorage has more dimensions than Shape::MaxDimensions");
}

/// Construct from a list of extents.
/// @throws std::invalid_argument if more than MaxDimensions values are supplied
template<typename V>
LIBRAPID_ALWAYS_INLINE Shape::Shape(const std::initializer_list<V> &vals) :
        m_dims(static_cast<int>(vals.size())) {
    // m_data is a fixed-size std::array; writing past MaxDimensions would corrupt memory.
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::invalid_argument,
                                   vals.size() <= MaxDimensions,
                                   "Shape supports at most {} dimensions. Received {}",
                                   MaxDimensions,
                                   vals.size());
    size_t i = 0;
    for (const auto &val : vals) m_data[i++] = val;
}

/// Construct from a vector of extents.
/// @throws std::invalid_argument if more than MaxDimensions values are supplied
template<typename V>
LIBRAPID_ALWAYS_INLINE Shape::Shape(const std::vector<V> &vals) :
        m_dims(static_cast<int>(vals.size())) {
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::invalid_argument,
                                   vals.size() <= MaxDimensions,
                                   "Shape supports at most {} dimensions. Received {}",
                                   MaxDimensions,
                                   vals.size());
    for (size_t i = 0; i < vals.size(); ++i) { m_data[i] = vals[i]; }
}

/// Two-dimensional Shape with the matrix's (rows, cols).
LIBRAPID_ALWAYS_INLINE Shape::Shape(const MatrixShape &other) {
    m_dims    = 2;
    m_data[0] = other[0];
    m_data[1] = other[1];
}

/// One-dimensional Shape with the vector's element count.
LIBRAPID_ALWAYS_INLINE Shape::Shape(const VectorShape &other) {
    m_dims    = 1;
    m_data[0] = other[0];
}

/// Replace the extents with `vals`.
/// The size is validated before anything is modified, so `*this` is left unchanged if
/// this throws.
/// @throws std::invalid_argument if more than MaxDimensions values are supplied
template<typename V>
LIBRAPID_ALWAYS_INLINE auto Shape::operator=(const std::initializer_list<V> &vals) -> Shape & {
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::invalid_argument,
                                   vals.size() <= MaxDimensions,
                                   "Shape supports at most {} dimensions. Received {}",
                                   MaxDimensions,
                                   vals.size());
    m_dims   = static_cast<int>(vals.size());
    size_t i = 0;
    for (const auto &val : vals) m_data[i++] = val;
    return *this;
}

/// Replace the extents with `vals`; `*this` is left unchanged if this throws.
/// @throws std::invalid_argument if more than MaxDimensions values are supplied
template<typename V>
LIBRAPID_ALWAYS_INLINE auto Shape::operator=(const std::vector<V> &vals) -> Shape & {
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::invalid_argument,
                                   vals.size() <= MaxDimensions,
                                   "Shape supports at most {} dimensions. Received {}",
                                   MaxDimensions,
                                   vals.size());
    m_dims = static_cast<int>(vals.size());
    for (size_t i = 0; i < vals.size(); ++i) { m_data[i] = vals[i]; }
    return *this;
}
/// A shape with `dims` dimensions whose extents are all 0.
/// @throws std::invalid_argument if dims is negative or greater than MaxDimensions
LIBRAPID_ALWAYS_INLINE auto Shape::zeros(int dims) -> Shape {
    // m_data is a fixed-size array, and a negative rank would make ndim() meaningless.
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::invalid_argument,
                                   dims >= 0 && static_cast<size_t>(dims) <= MaxDimensions,
                                   "Shape supports between 0 and {} dimensions. Received {}",
                                   MaxDimensions,
                                   dims);
    Shape res;
    res.m_dims = dims;
    for (int i = 0; i < dims; ++i) res.m_data[i] = 0;
    return res;
}

/// A shape with `dims` dimensions whose extents are all 1.
/// @throws std::invalid_argument if dims is negative or greater than MaxDimensions
LIBRAPID_ALWAYS_INLINE auto Shape::ones(int dims) -> Shape {
    Shape res = zeros(dims); // zeros() validates `dims` and sets the rank
    for (int i = 0; i < dims; ++i) res.m_data[i] = 1;
    return res;
}
/// Read-only access to the extent of dimension `index`.
/// @throws std::out_of_range if index is negative or not less than ndim()
template<typename Index>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto Shape::operator[](Index index) const
  -> const SizeType & {
    static_assert(std::is_integral_v<Index>, "Index must be an integral type");
    // A sign check is only meaningful (and warning-free) for signed index types.
    if constexpr (std::is_signed_v<Index>) {
        LIBRAPID_ASSERT_WITH_EXCEPTION(
          std::out_of_range,
          index >= 0,
          "Index out of bounds. Must be greater than or equal to 0. Received {}",
          index);
    }
    // The index is non-negative here, so comparing as size_t avoids a signed/unsigned mix.
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::out_of_range,
                                   static_cast<size_t>(index) < static_cast<size_t>(m_dims),
                                   "Index {} out of bounds for Shape with {} dimensions",
                                   index,
                                   m_dims);
    return m_data[index];
}

/// Mutable access to the extent of dimension `index`.
/// @throws std::range_error if index is negative
/// @throws std::out_of_range if index is not less than ndim()
template<typename Index>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto Shape::operator[](Index index) -> SizeType & {
    static_assert(std::is_integral_v<Index>, "Index must be an integral type");
    if constexpr (std::is_signed_v<Index>) {
        // NOTE(review): the const overload throws std::out_of_range here; range_error is
        // kept for compatibility with existing callers.
        LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
                                       index >= 0,
                                       "Index {} out of bounds. Must be greater than or equal to 0",
                                       index);
    }
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::out_of_range,
                                   static_cast<size_t>(index) < static_cast<size_t>(m_dims),
                                   "Index {} out of bounds for Shape with {} dimensions",
                                   index,
                                   m_dims);
    return m_data[index];
}
/// Shapes are equal when their ranks match and every extent within that rank matches.
/// Storage beyond ndim() does not take part in the comparison.
LIBRAPID_ALWAYS_INLINE auto Shape::operator==(const Shape &other) const -> bool {
    bool same = (m_dims == other.m_dims);
    for (int dim = 0; same && dim < m_dims; ++dim) same = (m_data[dim] == other.m_data[dim]);
    return same;
}

/// Logical negation of operator==.
LIBRAPID_ALWAYS_INLINE auto Shape::operator!=(const Shape &other) const -> bool {
    return !operator==(other);
}
/// Number of dimensions (rank) of this shape.
LIBRAPID_NODISCARD auto Shape::ndim() const -> int { return this->m_dims; }

/// Copy of dimensions [start, end) as a new Shape.
/// @throws std::range_error if start > end
/// @throws std::out_of_range if end > ndim() or start < 0
LIBRAPID_NODISCARD auto Shape::subshape(int start, int end) const -> Shape {
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
                                   start <= end,
                                   "Start index ({}) must not be greater than end index ({})",
                                   start,
                                   end);
    LIBRAPID_ASSERT_WITH_EXCEPTION(
      std::out_of_range,
      end <= m_dims,
      "End index ({}) must be less than or equal to the number of dimensions ({}).",
      end,
      m_dims);
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::out_of_range,
                                   start >= 0,
                                   "Start index ({}) must be greater than or equal to 0",
                                   start);

    Shape result;
    result.m_dims = end - start;
    int dst       = 0;
    for (int src = start; src < end; ++src) result.m_data[dst++] = m_data[src];
    return result;
}

/// Total number of elements: the product of all extents (1 for a rank-0 shape).
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto Shape::size() const -> size_t {
    size_t elements = 1;
    // Each extent is widened to size_t before multiplying; the order does not matter.
    for (int dim = m_dims - 1; dim >= 0; --dim) elements *= m_data[dim];
    return elements;
}
/// Write "Shape(e0, e1, ...)" to `ctx`, formatting each extent with `format`.
template<typename T_, typename Char, typename Ctx>
LIBRAPID_ALWAYS_INLINE void Shape::str(const fmt::formatter<T_, Char> &format, Ctx &ctx) const {
    fmt::format_to(ctx.out(), "Shape(");
    for (int i = 0; i < m_dims; ++i) {
        // Separator goes before every extent except the first.
        if (i > 0) fmt::format_to(ctx.out(), ", ");
        format.format(m_data[i], ctx);
    }
    fmt::format_to(ctx.out(), ")");
}
/// Rows and columns taken from the FixedStorage template arguments.
template<typename Scalar, size_t Rows, size_t Cols>
LIBRAPID_ALWAYS_INLINE MatrixShape::MatrixShape(const FixedStorage<Scalar, Rows, Cols> &) :
        m_rows(Rows), m_cols(Cols) {}

/// {rows, cols} -> rows x cols; {rows} -> rows x 1; {} -> 0 x 0.
/// @throws std::invalid_argument if more than two values are supplied
template<typename V>
LIBRAPID_ALWAYS_INLINE MatrixShape::MatrixShape(const std::initializer_list<V> &vals) {
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::invalid_argument,
                                   vals.size() <= 2,
                                   "MatrixShape must be initialized with 2 values. Received {}",
                                   vals.size());
    const auto *first = vals.begin();
    switch (vals.size()) {
        case 2:
            m_rows = first[0];
            m_cols = first[1];
            break;
        case 1:
            m_rows = first[0];
            m_cols = 1;
            break;
        default:
            m_rows = 0;
            m_cols = 0;
            break;
    }
}

/// Same rules as the initializer_list constructor.
/// @throws std::invalid_argument if more than two values are supplied
template<typename V>
LIBRAPID_ALWAYS_INLINE MatrixShape::MatrixShape(const std::vector<V> &vals) {
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::invalid_argument,
                                   vals.size() <= 2,
                                   "MatrixShape must be initialized with 2 values. Received {}",
                                   vals.size());
    switch (vals.size()) {
        case 2:
            m_rows = vals[0];
            m_cols = vals[1];
            break;
        case 1:
            m_rows = vals[0];
            m_cols = 1;
            break;
        default:
            m_rows = 0;
            m_cols = 0;
            break;
    }
}

/// Convert a Shape of rank 2 (rows x cols), 1 (rows x 1) or 0 (0 x 0).
/// @throws std::invalid_argument if the Shape has more than two dimensions
LIBRAPID_ALWAYS_INLINE MatrixShape::MatrixShape(const Shape &other) {
    LIBRAPID_ASSERT_WITH_EXCEPTION(
      std::invalid_argument,
      other.ndim() <= 2,
      "MatrixShape must be initialized with 2 dimension, but received {}",
      other.ndim());
    switch (other.ndim()) {
        case 2:
            m_rows = other[0];
            m_cols = other[1];
            break;
        case 1:
            m_rows = other[0];
            m_cols = 1;
            break;
        default:
            m_rows = 0;
            m_cols = 0;
            break;
    }
}
/// Reassign from {rows, cols}, {rows} (cols = 1) or {} (0 x 0).
/// @throws std::invalid_argument if more than two values are supplied
template<typename V>
LIBRAPID_ALWAYS_INLINE auto MatrixShape::operator=(const std::initializer_list<V> &vals)
  -> MatrixShape & {
    LIBRAPID_ASSERT_WITH_EXCEPTION(
      std::invalid_argument,
      vals.size() <= 2,
      "MatrixShape must be initialized with 2 values, but received {}",
      vals.size());
    const auto *first = vals.begin();
    switch (vals.size()) {
        case 2:
            m_rows = first[0];
            m_cols = first[1];
            break;
        case 1:
            m_rows = first[0];
            m_cols = 1;
            break;
        default:
            m_rows = 0;
            m_cols = 0;
            break;
    }
    return *this;
}

/// Reassign from a vector, with the same rules as the initializer_list overload.
/// @throws std::invalid_argument if more than two values are supplied
template<typename V>
LIBRAPID_ALWAYS_INLINE auto MatrixShape::operator=(const std::vector<V> &vals)
  -> MatrixShape & {
    LIBRAPID_ASSERT_WITH_EXCEPTION(
      std::invalid_argument,
      vals.size() <= 2,
      "MatrixShape must be initialized with 2 values, but received {}",
      vals.size());
    switch (vals.size()) {
        case 2:
            m_rows = vals[0];
            m_cols = vals[1];
            break;
        case 1:
            m_rows = vals[0];
            m_cols = 1;
            break;
        default:
            m_rows = 0;
            m_cols = 0;
            break;
    }
    return *this;
}
/// The 0x0 matrix shape.
LIBRAPID_ALWAYS_INLINE auto MatrixShape::zeros() -> MatrixShape { return MatrixShape({0, 0}); }

/// The 1x1 matrix shape.
LIBRAPID_ALWAYS_INLINE auto MatrixShape::ones() -> MatrixShape { return MatrixShape({1, 1}); }

/// Same as zeros(); the dimension-count argument is ignored, since a MatrixShape is
/// always two-dimensional.
LIBRAPID_ALWAYS_INLINE auto MatrixShape::zeros(size_t) -> MatrixShape { return zeros(); }

/// Same as ones(); the dimension-count argument is ignored.
LIBRAPID_ALWAYS_INLINE auto MatrixShape::ones(size_t) -> MatrixShape { return ones(); }
/// Index 0 yields the row count, index 1 the column count.
/// @throws std::out_of_range for any other index
template<typename Index>
LIBRAPID_ALWAYS_INLINE auto MatrixShape::operator[](Index index) const -> const SizeType & {
    static_assert(std::is_integral_v<Index>, "Index must be an integral type");
    // A sign check is only meaningful (and warning-free) for signed index types.
    if constexpr (std::is_signed_v<Index>) {
        LIBRAPID_ASSERT_WITH_EXCEPTION(std::out_of_range,
                                       index >= 0,
                                       "Index {} out of bounds. Must be greater than or equal to 0",
                                       index);
    }
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::out_of_range,
                                   index < 2,
                                   "Index {} out of bounds for MatrixShape with 2 dimensions",
                                   index);
    return index == 0 ? m_rows : m_cols;
}

/// Mutable access: index 0 is the row count, index 1 the column count.
/// @throws std::out_of_range for any other index
template<typename Index>
LIBRAPID_ALWAYS_INLINE auto MatrixShape::operator[](Index index) -> SizeType & {
    static_assert(std::is_integral_v<Index>, "Index must be an integral type");
    if constexpr (std::is_signed_v<Index>) {
        LIBRAPID_ASSERT_WITH_EXCEPTION(std::out_of_range,
                                       index >= 0,
                                       "Index {} out of bounds. Must be greater than or equal to 0",
                                       index);
    }
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::out_of_range,
                                   index < 2,
                                   "Index {} out of bounds for MatrixShape with 2 dimensions",
                                   index);
    return index == 0 ? m_rows : m_cols;
}
/// A MatrixShape always has exactly MaxDimensions (= 2) dimensions.
LIBRAPID_ALWAYS_INLINE constexpr auto MatrixShape::ndim() const -> int {
    return static_cast<int>(MaxDimensions);
}

/// Dimensions [start, end) of (rows, cols) as a general Shape.
/// @throws std::invalid_argument if start > end
/// @throws std::out_of_range if end > 2 or start < 0
LIBRAPID_ALWAYS_INLINE auto MatrixShape::subshape(int start, int end) const -> Shape {
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::invalid_argument,
                                   start <= end,
                                   "Start index ({}) must not be greater than end index ({})",
                                   start,
                                   end);
    LIBRAPID_ASSERT_WITH_EXCEPTION(
      std::out_of_range,
      end <= 2,
      "End index ({}) must be less than or equal to the number of dimensions (2).",
      end);
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::out_of_range,
                                   start >= 0,
                                   "Start index ({}) must be greater than or equal to 0",
                                   start);
    // Promote to the equivalent rank-2 Shape and let Shape do the slicing.
    return Shape(*this).subshape(start, end);
}
/// Total number of elements (rows * cols).
///
/// Both extents are 32-bit. Multiplying them directly would be done in 32-bit arithmetic
/// and wrap for matrices with 2^32 or more elements, before the result is widened to
/// size_t. Widening each operand first keeps the full product wherever size_t is 64-bit.
LIBRAPID_ALWAYS_INLINE auto MatrixShape::size() const -> size_t {
    return static_cast<size_t>(m_rows) * static_cast<size_t>(m_cols);
}
/// Write "MatrixShape(rows, cols)" to `ctx`, formatting each extent with `format`.
template<typename T_, typename Char, typename Ctx>
LIBRAPID_ALWAYS_INLINE void MatrixShape::str(const fmt::formatter<T_, Char> &format,
                                             Ctx &ctx) const {
    const SizeType extents[2] = {m_rows, m_cols};
    fmt::format_to(ctx.out(), "MatrixShape(");
    for (int i = 0; i < 2; ++i) {
        if (i != 0) fmt::format_to(ctx.out(), ", ");
        format.format(extents[i], ctx);
    }
    fmt::format_to(ctx.out(), ")");
}
/// Element count taken from the FixedStorage template argument.
template<typename Scalar, size_t Elements>
LIBRAPID_ALWAYS_INLINE VectorShape::VectorShape(const FixedStorage<Scalar, Elements> &) :
        m_elements(Elements) {}

/// Construct from exactly one value: the element count.
/// @throws std::invalid_argument if vals does not hold exactly one value
template<typename V>
LIBRAPID_ALWAYS_INLINE VectorShape::VectorShape(const std::initializer_list<V> &vals) {
    // The message previously named the wrong class (a MatrixShape copy-paste).
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::invalid_argument,
                                   vals.size() == 1,
                                   "VectorShape must be initialized with 1 value. Received {}",
                                   vals.size());
    m_elements = *(vals.begin());
}

/// Construct from exactly one value: the element count.
/// @throws std::invalid_argument if vals does not hold exactly one value
template<typename V>
LIBRAPID_ALWAYS_INLINE VectorShape::VectorShape(const std::vector<V> &vals) {
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::invalid_argument,
                                   vals.size() == 1,
                                   "VectorShape must be initialized with 1 value. Received {}",
                                   vals.size());
    m_elements = vals[0];
}

/// Convert a one-dimensional Shape.
/// @throws std::invalid_argument if the Shape does not have exactly one dimension
LIBRAPID_ALWAYS_INLINE VectorShape::VectorShape(const Shape &other) {
    LIBRAPID_ASSERT_WITH_EXCEPTION(
      std::invalid_argument,
      other.ndim() == 1,
      "VectorShape must be initialized with 1 dimension, but received {}",
      other.ndim());
    m_elements = other[0];
}
/// Reassign the element count from a single value.
/// @throws std::invalid_argument if vals does not hold exactly one value
template<typename V>
LIBRAPID_ALWAYS_INLINE auto VectorShape::operator=(const std::initializer_list<V> &vals)
  -> VectorShape & {
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::invalid_argument,
                                   vals.size() == 1,
                                   "VectorShape must be initialized with 1 value. Received {}",
                                   vals.size());
    m_elements = *(vals.begin());
    return *this;
}

/// Reassign the element count from a single value.
/// @throws std::runtime_error if vals does not hold exactly one value
template<typename V>
LIBRAPID_ALWAYS_INLINE auto VectorShape::operator=(const std::vector<V> &vals)
  -> VectorShape & {
    // NOTE(review): every other VectorShape/MatrixShape size check throws
    // std::invalid_argument; runtime_error is kept here for compatibility with callers.
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::runtime_error,
                                   vals.size() == 1,
                                   "VectorShape must be initialized with 1 value. Received {}",
                                   vals.size());
    m_elements = vals[0];
    return *this;
}
/// A vector shape with 0 elements.
LIBRAPID_ALWAYS_INLINE auto VectorShape::zeros() -> VectorShape { return VectorShape({0}); }

/// A vector shape with 1 element.
LIBRAPID_ALWAYS_INLINE auto VectorShape::ones() -> VectorShape { return VectorShape({1}); }

/// Same as zeros(); the dimension-count argument is ignored, since a VectorShape is
/// always one-dimensional.
LIBRAPID_ALWAYS_INLINE auto VectorShape::zeros(size_t) -> VectorShape { return zeros(); }

/// Same as ones(); the dimension-count argument is ignored.
LIBRAPID_ALWAYS_INLINE auto VectorShape::ones(size_t) -> VectorShape { return ones(); }
/// Only index 0 is valid; it yields the element count.
/// @throws std::range_error for any other index
template<typename Index>
LIBRAPID_ALWAYS_INLINE auto VectorShape::operator[](Index index) const -> const SizeType & {
    static_assert(std::is_integral_v<Index>, "Index must be an integral type");
    // A sign check is only meaningful (and warning-free) for signed index types.
    if constexpr (std::is_signed_v<Index>) {
        LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
                                       index >= 0,
                                       "Index {} out of bounds. Must be greater than or equal to 0",
                                       index);
    }
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
                                   index < 1,
                                   "Index {} out of bounds for VectorShape with 1 dimension",
                                   index);
    return m_elements;
}

/// Mutable access to the element count; only index 0 is valid.
/// @throws std::range_error for any other index
template<typename Index>
LIBRAPID_ALWAYS_INLINE auto VectorShape::operator[](Index index) -> SizeType & {
    static_assert(std::is_integral_v<Index>, "Index must be an integral type");
    if constexpr (std::is_signed_v<Index>) {
        LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
                                       index >= 0,
                                       "Index {} out of bounds. Must be greater than or equal to 0",
                                       index);
    }
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
                                   index < 1,
                                   "Index {} out of bounds for VectorShape with 1 dimension",
                                   index);
    return m_elements;
}
/// A VectorShape always has exactly one dimension.
LIBRAPID_ALWAYS_INLINE constexpr auto VectorShape::ndim() const -> int { return 1; }

/// Dimensions [start, end) of this vector shape as a general Shape.
/// subshape(0, 1) yields Shape{size()} and subshape(0, 0) / subshape(1, 1) yield a
/// zero-dimensional Shape.
/// @throws std::invalid_argument if start > end
/// @throws std::range_error if end > 1 or start < 0
LIBRAPID_ALWAYS_INLINE auto VectorShape::subshape(int start, int end) const -> Shape {
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::invalid_argument,
                                   start <= end,
                                   "Start index ({}) must not be greater than end index ({})",
                                   start,
                                   end);
    LIBRAPID_ASSERT_WITH_EXCEPTION(
      std::range_error,
      end <= 1,
      "End index ({}) must be less than or equal to the number of dimensions (1).",
      end);
    LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
                                   start >= 0,
                                   "Start index ({}) must be greater than or equal to 0",
                                   start);
    // Build the full rank-1 shape carrying the real element count, then slice it.
    // This mirrors MatrixShape::subshape; returning a constant Shape would discard both
    // m_elements and the requested range.
    Shape res = Shape::zeros(1);
    res[0]    = m_elements;
    return res.subshape(start, end);
}
/// Number of elements described by this shape.
LIBRAPID_ALWAYS_INLINE auto VectorShape::size() const -> size_t {
    return static_cast<size_t>(m_elements);
}

/// Write "VectorShape(n)" to `ctx`, formatting the element count with `format`.
template<typename T_, typename Char, typename Ctx>
LIBRAPID_ALWAYS_INLINE void VectorShape::str(const fmt::formatter<T_, Char> &format,
                                             Ctx &ctx) const {
    fmt::format_to(ctx.out(), "VectorShape(");
    format.format(m_elements, ctx); // the only extent
    fmt::format_to(ctx.out(), ")");
}
// ---- Shape <-> MatrixShape -------------------------------------------------------------

/// A Shape equals a MatrixShape when it is two-dimensional with the same (rows, cols).
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator==(const Shape &lhs,
                                                          const MatrixShape &rhs) -> bool {
    if (lhs.ndim() != 2) return false;
    return lhs[0] == rhs[0] && lhs[1] == rhs[1];
}

/// Symmetric form of Shape == MatrixShape.
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator==(const MatrixShape &lhs,
                                                          const Shape &rhs) -> bool {
    return operator==(rhs, lhs);
}

LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator!=(const Shape &lhs,
                                                          const MatrixShape &rhs) -> bool {
    return !operator==(lhs, rhs);
}

LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator!=(const MatrixShape &lhs,
                                                          const Shape &rhs) -> bool {
    return !operator==(lhs, rhs);
}

// ---- Shape <-> VectorShape -------------------------------------------------------------

/// A Shape equals a VectorShape when it is one-dimensional with the same element count.
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator==(const Shape &lhs,
                                                          const VectorShape &rhs) -> bool {
    if (lhs.ndim() != 1) return false;
    return lhs[0] == rhs[0];
}

/// Symmetric form of Shape == VectorShape.
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator==(const VectorShape &lhs,
                                                          const Shape &rhs) -> bool {
    return operator==(rhs, lhs);
}

LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator!=(const Shape &lhs,
                                                          const VectorShape &rhs) -> bool {
    return !operator==(lhs, rhs);
}

LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator!=(const VectorShape &lhs,
                                                          const Shape &rhs) -> bool {
    return !operator==(lhs, rhs);
}

// ---- MatrixShape <-> VectorShape -------------------------------------------------------

/// Never equal: the two types have a different number of dimensions by construction.
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator==(const MatrixShape &,
                                                          const VectorShape &) -> bool {
    return false;
}

/// Never equal; see the MatrixShape == VectorShape overload.
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator==(const VectorShape &,
                                                          const MatrixShape &) -> bool {
    return false;
}

LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator!=(const MatrixShape &lhs,
                                                          const VectorShape &rhs) -> bool {
    return !operator==(lhs, rhs);
}

LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator!=(const VectorShape &lhs,
                                                          const MatrixShape &rhs) -> bool {
    return !operator==(lhs, rhs);
}
namespace typetraits {
template<typename T>
struct IsSizeType : std::false_type {};
template<>
struct IsSizeType<Shape> : std::true_type {};
template<>
struct IsSizeType<MatrixShape> : std::true_type {};
} // namespace typetraits
/// Return true when every shape in `shapes` is equal to every other.
///
/// Adjacent pairs are compared left-to-right (s0 == s1, s1 == s2, ...), stopping at the
/// first mismatch.
template<typename First, typename Second, typename... Rest>
    requires(typetraits::IsSizeType<First>::value && typetraits::IsSizeType<Second>::value &&
             (typetraits::IsSizeType<Rest>::value && ...))
LIBRAPID_NODISCARD LIBRAPID_INLINE bool
shapesMatch(const std::tuple<First, Second, Rest...> &shapes) {
    if constexpr (sizeof...(Rest) == 0) {
        return std::get<0>(shapes) == std::get<1>(shapes);
    } else {
        // Drop only the first element, so the second shape is compared with the third on
        // the next step. Dropping two at a time would skip every other adjacent pair, and
        // the last shape of an odd-length tuple would never be checked at all.
        return std::get<0>(shapes) == std::get<1>(shapes) &&
               shapesMatch(std::apply(
                 [](const auto &, const auto &...rest) { return std::make_tuple(rest...); },
                 shapes));
    }
}

/// A single shape trivially matches itself.
template<typename First>
    requires(typetraits::IsSizeType<First>::value)
LIBRAPID_NODISCARD LIBRAPID_INLINE bool
shapesMatch(const std::tuple<First> &) {
    return true;
}
namespace detail {
// ShapeTypeHelperImpl<First, Second>::Type is the shape type chosen when combining two
// operands. The primary template covers every pair not listed below and yields
// std::false_type. The partial specialisations further down treat std::false_type as a
// "no shape" marker.
template<typename First, typename Second>
struct ShapeTypeHelperImpl {
using Type = std::false_type;
};
// Shape paired with any of Shape / MatrixShape / VectorShape (either order) -> Shape.
template<>
struct ShapeTypeHelperImpl<Shape, Shape> {
using Type = Shape;
};
template<>
struct ShapeTypeHelperImpl<Shape, MatrixShape> {
using Type = Shape;
};
template<>
struct ShapeTypeHelperImpl<MatrixShape, Shape> {
using Type = Shape;
};
template<>
struct ShapeTypeHelperImpl<Shape, VectorShape> {
using Type = Shape;
};
template<>
struct ShapeTypeHelperImpl<VectorShape, Shape> {
using Type = Shape;
};
// Two MatrixShapes stay a MatrixShape.
template<>
struct ShapeTypeHelperImpl<MatrixShape, MatrixShape> {
using Type = MatrixShape;
};
// Mixing MatrixShape and VectorShape falls back to the general Shape.
template<>
struct ShapeTypeHelperImpl<MatrixShape, VectorShape> {
using Type = Shape;
};
template<>
struct ShapeTypeHelperImpl<VectorShape, MatrixShape> {
using Type = Shape;
};
// Two VectorShapes stay a VectorShape.
template<>
struct ShapeTypeHelperImpl<VectorShape, VectorShape> {
using Type = VectorShape;
};
// Any type followed by the std::false_type marker keeps that type.
template<typename NonFalseType>
struct ShapeTypeHelperImpl<NonFalseType, std::false_type> {
using Type = NonFalseType;
};
// NOTE(review): this is not the mirror image of the specialisation above. With the marker
// first, the result is always Shape rather than NonFalseType. Confirm whether that
// asymmetry is intentional.
template<typename NonFalseType>
struct ShapeTypeHelperImpl<std::false_type, NonFalseType> {
using Type = Shape;
};
// Both operands are the marker. This explicit specialisation is required because both
// partial specialisations above match <false_type, false_type> equally well (ambiguous).
template<>
struct ShapeTypeHelperImpl<std::false_type, std::false_type> {
using Type = VectorShape; // Fastest
};
// ShapeTypeHelper<Args...>::Type folds ShapeTypeHelperImpl left-to-right over Args:
// one argument yields itself; more arguments combine the first two, then recurse.
template<typename... Args>
struct ShapeTypeHelper;
template<typename First>
struct ShapeTypeHelper<First> {
using Type = First;
};
template<typename First, typename Second>
struct ShapeTypeHelper<First, Second> {
using Type = typename ShapeTypeHelperImpl<First, Second>::Type;
};
template<typename First, typename Second, typename... Rest>
struct ShapeTypeHelper<First, Second, Rest...> {
using FirstResult = typename ShapeTypeHelperImpl<First, Second>::Type;
using Type = typename ShapeTypeHelper<FirstResult, Rest...>::Type;
};
// SubscriptShapeType<T>::Type maps MatrixShape -> VectorShape and everything else
// (including VectorShape) -> Shape.
// NOTE(review): presumably the shape type produced by subscripting an array of shape T —
// meaning inferred from the name; confirm against callers.
template<typename T>
struct SubscriptShapeType {
using Type = Shape;
};
template<>
struct SubscriptShapeType<MatrixShape> {
using Type = VectorShape;
};
template<>
struct SubscriptShapeType<VectorShape> {
using Type = Shape;
};
} // namespace detail
} // namespace librapid
// Support FMT printing.
// Each formatter forwards the format spec (e.g. "{:5}") to the underlying SizeType
// formatter and lets the shape's str() method lay out the extents.
#ifdef FMT_API
template<>
struct fmt::formatter<librapid::Shape> {
private:
    using Type     = librapid::Shape;
    using SizeType = librapid::Shape::SizeType;
    using Base     = fmt::formatter<SizeType, char>;
    Base m_base;

public:
    template<typename ParseContext>
    FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * {
        return m_base.parse(ctx);
    }

    template<typename FormatContext>
    FMT_CONSTEXPR auto format(const Type &val, FormatContext &ctx) const -> decltype(ctx.out()) {
        val.str(m_base, ctx);
        return ctx.out();
    }
};

template<>
struct fmt::formatter<librapid::MatrixShape> {
private:
    using Type     = librapid::MatrixShape;
    using SizeType = librapid::MatrixShape::SizeType;
    using Base     = fmt::formatter<SizeType, char>;
    Base m_base;

public:
    template<typename ParseContext>
    FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * {
        return m_base.parse(ctx);
    }

    template<typename FormatContext>
    FMT_CONSTEXPR auto format(const Type &val, FormatContext &ctx) const -> decltype(ctx.out()) {
        val.str(m_base, ctx);
        return ctx.out();
    }
};

// VectorShape provides str() just like the other shape types but previously had no
// formatter, so fmt::format("{}", vectorShape) did not compile.
template<>
struct fmt::formatter<librapid::VectorShape> {
private:
    using Type     = librapid::VectorShape;
    using SizeType = librapid::VectorShape::SizeType;
    using Base     = fmt::formatter<SizeType, char>;
    Base m_base;

public:
    template<typename ParseContext>
    FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * {
        return m_base.parse(ctx);
    }

    template<typename FormatContext>
    FMT_CONSTEXPR auto format(const Type &val, FormatContext &ctx) const -> decltype(ctx.out()) {
        val.str(m_base, ctx);
        return ctx.out();
    }
};
#endif // FMT_API
#endif // LIBRAPID_ARRAY_SIZETYPE_HPP