Program Listing for File generalArrayView.hpp
↰ Return to documentation for file (librapid/include/librapid/array/generalArrayView.hpp)
#ifndef LIBRAPID_ARRAY_ARRAY_VIEW_HPP
#define LIBRAPID_ARRAY_ARRAY_VIEW_HPP
namespace librapid {
namespace typetraits {
	/// Type traits for GeneralArrayView. Scalar, Backend and StorageType come from the wrapped
	/// (decayed) type, while ShapeType is the view's own shape type `S`.
	template<typename T, typename S>
	struct TypeInfo<array::GeneralArrayView<T, S>> {
		static constexpr detail::LibRapidType type = detail::LibRapidType::GeneralArrayView;
		using Scalar		= typename TypeInfo<std::decay_t<T>>::Scalar;
		using Backend		= typename TypeInfo<std::decay_t<T>>::Backend;
		using ArrayViewType = std::decay_t<T>;
		// Must agree with GeneralArrayView<T, S>::ShapeType, which is `S`. A view created with
		// createGeneralArrayViewShapeModifier<X> uses X, not the wrapped array's shape type.
		using ShapeType		= S;
		using StorageType	= typename TypeInfo<ArrayViewType>::StorageType;
		static constexpr bool allowVectorisation = false;
	};

	LIBRAPID_DEFINE_AS_TYPE(typename T COMMA typename S, array::GeneralArrayView<T COMMA S>);
} // namespace typetraits
/// Build a GeneralArrayView over `array` whose shape type is the array's own ShapeType.
/// An lvalue argument produces a view that refers to `array`; an rvalue argument is stored
/// by value inside the view.
template<typename T>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto createGeneralArrayView(T &&array) {
	using WrappedShape = typename std::decay_t<T>::ShapeType;
	using ViewType	   = array::GeneralArrayView<T, WrappedShape>;
	return ViewType(std::forward<T>(array));
}
/// Like createGeneralArrayView, but the caller picks the view's shape type explicitly
/// (e.g. operator[] uses this to produce lower-dimensional sub-views).
template<typename ShapeType, typename T>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto createGeneralArrayViewShapeModifier(T &&array) {
	using ViewType = array::GeneralArrayView<T, ShapeType>;
	return ViewType(std::forward<T>(array));
}
namespace array {
/// \brief A strided view onto the elements of another array-like object.
///
/// The wrapped object is held in `m_ref` with type `ArrayViewType`. When the view comes from
/// `createGeneralArrayView` / `createGeneralArrayViewShapeModifier` with an lvalue argument,
/// `ArrayViewType` is an lvalue-reference type and the view aliases that object. With an rvalue
/// argument it is a value type and the view stores the object by value.
///
/// The view keeps its own shape, stride and offset. Element `coord` of the view is read from or
/// written to position `m_offset + sum_i(coord[i] * stride[i])` of the wrapped object (see
/// `scalar()` and the element-wise assignment operators).
///
/// \tparam ArrayViewType      Type of the wrapped object (possibly a reference type).
/// \tparam ArrayViewShapeType Shape type used by this view; defaults to the wrapped type's
///                            `ShapeType`.
template<typename ArrayViewType,
		 typename ArrayViewShapeType = typename std::decay_t<ArrayViewType>::ShapeType>
class GeneralArrayView {
public:
	using BaseType = typename std::decay_t<ArrayViewType>;
	using Scalar = typename typetraits::TypeInfo<BaseType>::Scalar;
	using Reference = BaseType &;
	using ConstReference = const BaseType &;
	using Backend = typename typetraits::TypeInfo<BaseType>::Backend;
	using ShapeType = ArrayViewShapeType;
	using StrideType = Stride<ShapeType>;
	using StorageType = typename typetraits::TypeInfo<BaseType>::StorageType;
	/// Type returned by eval(): an owning array with this view's shape type.
	using ArrayType = array::ArrayContainer<ShapeType, StorageType>;
	using Iterator = detail::ArrayIterator<GeneralArrayView>;

	/// A view must always wrap an object.
	GeneralArrayView() = delete;
	// LIBRAPID_ALWAYS_INLINE GeneralArrayView(ArrayViewType &array);

	/// Wrap `array`. Shape and stride are initialised from `array.shape()`; the offset is 0.
	LIBRAPID_ALWAYS_INLINE GeneralArrayView(ArrayViewType &&array);
	LIBRAPID_ALWAYS_INLINE GeneralArrayView(const GeneralArrayView &other);
	LIBRAPID_ALWAYS_INLINE GeneralArrayView(GeneralArrayView &&other);

	/// Element-wise copy: writes `other`'s values into the elements this view addresses (it
	/// does not rebind the view). Shapes must match (checked with
	/// LIBRAPID_ASSERT_WITH_EXCEPTION / std::range_error).
	LIBRAPID_ALWAYS_INLINE GeneralArrayView &operator=(const GeneralArrayView &other);

	/// NOTE(review): defaulted, i.e. member-wise, unlike the element-wise copy assignment
	/// above. When ArrayViewType is a reference type this defaulted operator is defined as
	/// deleted (reference member) and is ignored by overload resolution, so rvalue views use
	/// the element-wise copy assignment. For a by-value ArrayViewType it replaces the wrapped
	/// object instead of writing through. Confirm this asymmetry is intended.
	GeneralArrayView &operator=(GeneralArrayView &&other) noexcept = default;

	/// Write one value. Only valid for a zero-dimensional view (std::invalid_argument).
	LIBRAPID_ALWAYS_INLINE GeneralArrayView &operator=(const Scalar &scalar);

	// Element-wise assignment from an array, a lazy function, a transpose or a matrix
	// product. Shapes must match (std::range_error via LIBRAPID_ASSERT_WITH_EXCEPTION).
	template<typename ShapeType_, typename StorageType_>
	LIBRAPID_ALWAYS_INLINE GeneralArrayView &
	operator=(const ArrayContainer<ShapeType_, StorageType_> &other);
	template<typename desc, typename Functor, typename... Args>
	LIBRAPID_ALWAYS_INLINE GeneralArrayView &
	operator=(const detail::Function<desc, Functor, Args...> &function);
	template<typename TransposeType>
	LIBRAPID_ALWAYS_INLINE GeneralArrayView &
	operator=(const array::Transpose<TransposeType> &transpose);
	template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
			 typename StorageTypeB, typename Alpha, typename Beta>
	LIBRAPID_ALWAYS_INLINE GeneralArrayView &
	operator=(const linalg::ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB,
										  StorageTypeB, Alpha, Beta> &matmul);

	/// Sub-view at position `index` along the leading dimension (one dimension fewer).
	LIBRAPID_ALWAYS_INLINE const auto operator[](int64_t index) const;
	LIBRAPID_ALWAYS_INLINE auto operator[](int64_t index);

	/// Value of a zero-dimensional view (std::invalid_argument otherwise).
	template<typename CAST = Scalar>
	LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE CAST get() const;
	template<typename CAST>
	LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE explicit operator CAST() const;

	LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t size() const;
	LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ShapeType shape() const;
	LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE StrideType stride() const;
	LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t offset() const;
	LIBRAPID_ALWAYS_INLINE void setShape(const ShapeType &shape);
	LIBRAPID_ALWAYS_INLINE void setStride(const StrideType &stride);
	LIBRAPID_ALWAYS_INLINE void setOffset(const int64_t &offset);
	LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t ndim() const;

	/// Element at row-major linear position `index` within this view.
	LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalar(int64_t index) const;

	// Compound assignment: evaluates `*this op other` and assigns the result back
	// element-wise through operator=.
	template<typename T>
	LIBRAPID_ALWAYS_INLINE GeneralArrayView &operator+=(const T &other);
	template<typename T>
	LIBRAPID_ALWAYS_INLINE GeneralArrayView &operator-=(const T &other);
	template<typename T>
	LIBRAPID_ALWAYS_INLINE GeneralArrayView &operator*=(const T &other);
	template<typename T>
	LIBRAPID_ALWAYS_INLINE GeneralArrayView &operator/=(const T &other);

	/// Copy the viewed elements, in row-major order, into a new ArrayType of this view's shape.
	LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ArrayType eval() const;

	/// Iterate over the leading dimension: positions [0, shape()[0]).
	LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator begin() const;
	LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator end() const;

	// Presumably used by the fmt printing support; definition not shown in this listing.
	template<typename T, typename Char, size_t N, typename Ctx>
	void str(const fmt::formatter<T, Char> &format, char bracket, char separator,
			 const char (&formatString)[N], Ctx &ctx) const;

private:
	ArrayViewType m_ref;  // wrapped object (reference or value, see class docs)
	ShapeType m_shape;    // shape of this view
	StrideType m_stride;  // per-dimension step between elements of the wrapped object
	int64_t m_offset = 0; // position of this view's first element in the wrapped object
};
/// Wrap `array`. Shape and stride are initialised from the wrapped object's shape and the
/// offset defaults to 0.
///
/// If ArrayViewType is an lvalue-reference type, `m_ref` binds to `array`. Otherwise the
/// argument is moved into `m_ref` rather than copied: `array` is a named rvalue reference,
/// so without std::forward a by-value wrap would deep-copy a temporary.
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::GeneralArrayView(
  ArrayViewType &&array) :
		m_ref(std::forward<ArrayViewType>(array)),
		// m_ref is declared (and therefore initialised) first, so read the shape from it
		// instead of the possibly moved-from argument.
		m_shape(m_ref.shape()), m_stride(m_ref.shape()) {}
/// Copy a view: wraps the same object (or a copy of it, for a by-value ArrayViewType) with
/// the same shape, stride and offset.
///
/// The offset must be copied too. Otherwise a copied sub-view (e.g. the result of
/// operator[]) silently falls back to offset 0 and addresses the wrong elements.
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::GeneralArrayView(
  const GeneralArrayView &other) :
		m_ref(other.m_ref),
		m_shape(other.m_shape), m_stride(other.m_stride), m_offset(other.m_offset) {}
/// Move a view. For a reference ArrayViewType the new view refers to the same object. For a
/// by-value ArrayViewType the wrapped object is moved instead of copied; the original code
/// copied every member despite being the move constructor.
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::GeneralArrayView(
  GeneralArrayView &&other) :
		m_ref(std::forward<ArrayViewType>(other.m_ref)),
		m_shape(std::move(other.m_shape)), m_stride(std::move(other.m_stride)),
		m_offset(other.m_offset) {}
/// Store `scalar` into the single element addressed by a zero-dimensional view.
/// Any other view is rejected through LIBRAPID_ASSERT_WITH_EXCEPTION (std::invalid_argument).
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE auto
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator=(const Scalar &scalar)
  -> GeneralArrayView & {
	LIBRAPID_ASSERT_WITH_EXCEPTION(std::invalid_argument,
								   m_shape.ndim() == 0,
								   "Cannot assign to a non-scalar ArrayView with {}",
								   m_shape);

	// A 0-d view has exactly one element, located at m_offset
	auto &&target = m_ref.storage();
	target[m_offset] = static_cast<Scalar>(scalar);
	return *this;
}
/// Element-wise copy assignment: writes every element of `other` into the element this view
/// addresses at the same coordinate. The view itself is not rebound.
///
/// \throws std::range_error (via LIBRAPID_ASSERT_WITH_EXCEPTION) if the shapes differ.
/// NOTE(review): if `other` aliases the same storage with a different layout, later reads may
/// observe earlier writes.
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE GeneralArrayView<ArrayViewType, ArrayViewShapeType> &
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator=(
  const GeneralArrayView &other) {
	LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
								   m_shape.operator==(other.shape()),
								   "GeneralArrayView assignment shape mismatch. {} vs {}",
								   m_shape,
								   other.shape());

	const int64_t ndim = m_shape.ndim();

	// An empty view (some extent is zero) has nothing to write. Without this check the
	// do/while below writes one out-of-range element and never terminates, because the
	// incremented coordinate can never equal a zero extent.
	for (int64_t dim = 0; dim < ndim; ++dim) {
		if (m_shape[dim] == 0) return *this;
	}

	// Visit this view's elements in row-major order (last dimension fastest). `d` is the
	// linear index into `other`; `p` is the strided position (relative to m_offset) of the
	// same element in this view, updated odometer-style from `coord` and m_stride.
	ShapeType coord = ShapeType::zeros(m_shape.ndim());
	int64_t d = 0, p = 0;
	int64_t idim = 0, adim = 0;
	do {
		m_ref.storage()[p + m_offset] = other.scalar(d++);
		for (idim = 0; idim < ndim; ++idim) {
			adim = ndim - idim - 1;
			if (++coord[adim] == m_shape[adim]) {
				coord[adim] = 0;
				p			= p - (m_shape[adim] - 1) * m_stride[adim];
			} else {
				p = p + m_stride[adim];
				break;
			}
		}
	} while (idim < ndim);
	return *this;
}
/// Element-wise assignment from an ArrayContainer: element `d` (row-major) of `other` is
/// written to the element this view addresses at the same coordinate.
///
/// \throws std::range_error (via LIBRAPID_ASSERT_WITH_EXCEPTION) if the shapes differ.
template<typename ArrayViewType, typename ArrayViewShapeType>
template<typename ShapeType_, typename StorageType_>
LIBRAPID_ALWAYS_INLINE GeneralArrayView<ArrayViewType, ArrayViewShapeType> &
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator=(
  const ArrayContainer<ShapeType_, StorageType_> &other) {
	LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
								   m_shape.operator==(other.shape()),
								   "GeneralArrayView assignment shape mismatch. {} vs {}",
								   m_shape,
								   other.shape());

	const int64_t ndim = m_shape.ndim();

	// An empty view (some extent is zero) has nothing to write. Without this check the
	// do/while below writes one out-of-range element and never terminates.
	for (int64_t dim = 0; dim < ndim; ++dim) {
		if (m_shape[dim] == 0) return *this;
	}

	// Row-major walk: `d` indexes `other` linearly, `p` tracks the strided position
	// (relative to m_offset) of the same coordinate in this view.
	ShapeType coord = ShapeType::zeros(m_shape.ndim());
	int64_t d = 0, p = 0;
	int64_t idim = 0, adim = 0;
	do {
		m_ref.storage()[p + m_offset] = other.scalar(d++);
		for (idim = 0; idim < ndim; ++idim) {
			adim = ndim - idim - 1;
			if (++coord[adim] == m_shape[adim]) {
				coord[adim] = 0;
				p			= p - (m_shape[adim] - 1) * m_stride[adim];
			} else {
				p = p + m_stride[adim];
				break;
			}
		}
	} while (idim < ndim);
	return *this;
}
/// Element-wise assignment from a lazily evaluated function: element `d` (row-major) of
/// `function` is evaluated with function.scalar(d) and written to the matching element of
/// this view.
///
/// \throws std::range_error (via LIBRAPID_ASSERT_WITH_EXCEPTION) if the shapes differ.
template<typename ArrayViewType, typename ArrayViewShapeType>
template<typename desc, typename Functor, typename... Args>
LIBRAPID_ALWAYS_INLINE auto GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator=(
  const detail::Function<desc, Functor, Args...> &function) -> GeneralArrayView & {
	LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
								   m_shape.operator==(function.shape()),
								   "GeneralArrayView assignment shape mismatch. {} vs {}",
								   m_shape,
								   function.shape());

	const int64_t ndim = m_shape.ndim();

	// An empty view (some extent is zero) has nothing to write. Without this check the
	// do/while below writes one out-of-range element and never terminates.
	for (int64_t dim = 0; dim < ndim; ++dim) {
		if (m_shape[dim] == 0) return *this;
	}

	// Row-major walk: `d` indexes `function` linearly, `p` tracks the strided position
	// (relative to m_offset) of the same coordinate in this view.
	ShapeType coord = ShapeType::zeros(m_shape.ndim());
	int64_t d = 0, p = 0;
	int64_t idim = 0, adim = 0;
	do {
		m_ref.storage()[p + m_offset] = function.scalar(d++);
		for (idim = 0; idim < ndim; ++idim) {
			adim = ndim - idim - 1;
			if (++coord[adim] == m_shape[adim]) {
				coord[adim] = 0;
				p			= p - (m_shape[adim] - 1) * m_stride[adim];
			} else {
				p = p + m_stride[adim];
				break;
			}
		}
	} while (idim < ndim);
	return *this;
}
/// Element-wise assignment from a Transpose expression: element `d` (row-major) of
/// `transpose` is written to the matching element of this view.
///
/// \throws std::range_error (via LIBRAPID_ASSERT_WITH_EXCEPTION) if the shapes differ.
template<typename ArrayViewType, typename ArrayViewShapeType>
template<typename TransposeType>
LIBRAPID_ALWAYS_INLINE auto GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator=(
  const array::Transpose<TransposeType> &transpose) -> GeneralArrayView & {
	LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
								   m_shape.operator==(transpose.shape()),
								   "GeneralArrayView assignment shape mismatch. {} vs {}",
								   m_shape,
								   transpose.shape());

	const int64_t ndim = m_shape.ndim();

	// An empty view (some extent is zero) has nothing to write. Without this check the
	// do/while below writes one out-of-range element and never terminates.
	for (int64_t dim = 0; dim < ndim; ++dim) {
		if (m_shape[dim] == 0) return *this;
	}

	// Row-major walk: `d` indexes `transpose` linearly, `p` tracks the strided position
	// (relative to m_offset) of the same coordinate in this view.
	ShapeType coord = ShapeType::zeros(m_shape.ndim());
	int64_t d = 0, p = 0;
	int64_t idim = 0, adim = 0;
	do {
		m_ref.storage()[p + m_offset] = transpose.scalar(d++);
		for (idim = 0; idim < ndim; ++idim) {
			adim = ndim - idim - 1;
			if (++coord[adim] == m_shape[adim]) {
				coord[adim] = 0;
				p			= p - (m_shape[adim] - 1) * m_stride[adim];
			} else {
				p = p + m_stride[adim];
				break;
			}
		}
	} while (idim < ndim);
	return *this;
}
/// Element-wise assignment from a matrix-product expression: element `d` (row-major) of
/// `matmul` is written to the matching element of this view.
///
/// \throws std::range_error (via LIBRAPID_ASSERT_WITH_EXCEPTION) if the shapes differ.
template<typename ArrayViewType, typename ArrayViewShapeType>
template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
		 typename StorageTypeB, typename Alpha, typename Beta>
LIBRAPID_ALWAYS_INLINE auto GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator=(
  const linalg::ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha,
							  Beta> &matmul) -> GeneralArrayView & {
	LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
								   m_shape.operator==(matmul.shape()),
								   "GeneralArrayView assignment shape mismatch. {} vs {}",
								   m_shape,
								   matmul.shape());

	const int64_t ndim = m_shape.ndim();

	// An empty view (some extent is zero) has nothing to write. Without this check the
	// do/while below writes one out-of-range element and never terminates.
	for (int64_t dim = 0; dim < ndim; ++dim) {
		if (m_shape[dim] == 0) return *this;
	}

	// Row-major walk: `d` indexes `matmul` linearly, `p` tracks the strided position
	// (relative to m_offset) of the same coordinate in this view.
	ShapeType coord = ShapeType::zeros(m_shape.ndim());
	int64_t d = 0, p = 0;
	int64_t idim = 0, adim = 0;
	do {
		m_ref.storage()[p + m_offset] = matmul.scalar(d++);
		for (idim = 0; idim < ndim; ++idim) {
			adim = ndim - idim - 1;
			if (++coord[adim] == m_shape[adim]) {
				coord[adim] = 0;
				p			= p - (m_shape[adim] - 1) * m_stride[adim];
			} else {
				p = p + m_stride[adim];
				break;
			}
		}
	} while (idim < ndim);
	return *this;
}
/// Sub-view at position `index` along the leading dimension. The result has one fewer
/// dimension, refers to the same wrapped object, and starts at
/// `m_offset + index * m_stride[0]`.
///
/// The view's own stride (m_stride) is used, the same one scalar() and eval() use. The
/// previous code rebuilt contiguous strides from m_shape, which addresses the wrong elements
/// once a view's stride has been changed with setStride().
///
/// \throws std::out_of_range (via LIBRAPID_ASSERT_WITH_EXCEPTION) if the view is
///         zero-dimensional or `index` is outside [0, shape()[0]).
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE const auto
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator[](int64_t index) const {
	// A 0-d view has no leading dimension; m_shape[0] would be meaningless
	LIBRAPID_ASSERT_WITH_EXCEPTION(std::out_of_range,
								   ndim() > 0,
								   "Cannot index into a scalar GeneralArrayView with {}",
								   m_shape);
	LIBRAPID_ASSERT_WITH_EXCEPTION(
	  std::out_of_range,
	  index >= 0 && index < static_cast<int64_t>(m_shape[0]),
	  "Index {} out of bounds in GeneralArrayView::operator[] with leading dimension={}",
	  index,
	  m_shape[0]);
	auto view = createGeneralArrayViewShapeModifier<Shape>(m_ref);
	view.setShape(m_shape.subshape(1, ndim()));
	if (ndim() == 1)
		view.setStride(Stride<Shape>({1}));
	else
		view.setStride(m_stride.substride(1, ndim()));
	view.setOffset(m_offset + index * m_stride[0]);
	return view;
}

/// Non-const overload of operator[]; identical behaviour (see the const overload above).
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE auto
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator[](int64_t index) {
	// A 0-d view has no leading dimension; m_shape[0] would be meaningless
	LIBRAPID_ASSERT_WITH_EXCEPTION(std::out_of_range,
								   ndim() > 0,
								   "Cannot index into a scalar GeneralArrayView with {}",
								   m_shape);
	LIBRAPID_ASSERT_WITH_EXCEPTION(
	  std::out_of_range,
	  index >= 0 && index < static_cast<int64_t>(m_shape[0]),
	  "Index {} out of bounds in GeneralArrayView::operator[] with leading dimension={}",
	  index,
	  m_shape[0]);
	auto view = createGeneralArrayViewShapeModifier<Shape>(m_ref);
	view.setShape(m_shape.subshape(1, ndim()));
	if (ndim() == 1)
		view.setStride(Stride<Shape>({1}));
	else
		view.setStride(m_stride.substride(1, ndim()));
	view.setOffset(m_offset + index * m_stride[0]);
	return view;
}
/// Value of a zero-dimensional view, converted to `CAST` (defaults to Scalar).
///
/// \throws std::invalid_argument (via LIBRAPID_ASSERT_WITH_EXCEPTION) if the view is not
///         zero-dimensional.
template<typename ArrayViewType, typename ArrayViewShapeType>
template<typename CAST>
LIBRAPID_ALWAYS_INLINE CAST
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::get() const {
	LIBRAPID_ASSERT_WITH_EXCEPTION(
	  std::invalid_argument,
	  m_shape.ndim() == 0,
	  "Can only cast a scalar ArrayView to a scalar object. ArrayView had {}",
	  m_shape);
	// Explicit conversion, so CAST types that are only explicitly constructible from the
	// element value also work.
	return static_cast<CAST>(scalar(0));
}

/// Explicit conversion of a zero-dimensional view to `CAST`. The requested type is forwarded
/// to get(); previously get<Scalar>() was called and CAST was ignored.
template<typename ArrayViewType, typename ArrayViewShapeType>
template<typename CAST>
LIBRAPID_ALWAYS_INLINE
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator CAST() const {
	return this->template get<CAST>();
}
/// Number of elements addressed by the view, as reported by its shape.
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE auto
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::size() const -> int64_t {
	return this->m_shape.size();
}

/// Shape of the view (a copy).
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE typename GeneralArrayView<ArrayViewType, ArrayViewShapeType>::ShapeType
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::shape() const {
	return this->m_shape;
}

/// Per-dimension stride of the view (a copy).
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE typename GeneralArrayView<ArrayViewType, ArrayViewShapeType>::StrideType
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::stride() const {
	return this->m_stride;
}

/// Position of the view's first element within the wrapped object.
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE int64_t
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::offset() const {
	return this->m_offset;
}
/// Replace the view's shape. The stride is left unchanged.
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE auto
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::setShape(const ShapeType &shape) -> void {
	this->m_shape = shape;
}

/// Replace the view's per-dimension stride. The shape is left unchanged.
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE auto
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::setStride(const StrideType &stride)
  -> void {
	this->m_stride = stride;
}

/// Replace the position of the view's first element within the wrapped object.
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE auto
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::setOffset(const int64_t &offset)
  -> void {
	this->m_offset = offset;
}
/// Number of dimensions of the view, taken from its shape.
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE int64_t
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::ndim() const {
	const auto dims = this->m_shape.ndim();
	return dims;
}
/// Element at row-major linear position `index` within this view, read via m_ref.scalar().
/// A zero-dimensional view always returns the element at m_offset.
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE auto
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::scalar(int64_t index) const -> auto {
	const int64_t dims = ndim();
	if (dims == 0) return m_ref.scalar(m_offset);

	// Peel coordinates off the linear index from the last (fastest-varying) dimension to the
	// first, adding each coordinate's strided contribution as it is found.
	int64_t remaining = index;
	int64_t relative  = 0;
	for (int64_t dim = dims - 1; dim >= 0; --dim) {
		const auto extent = m_shape[dim];
		relative += (remaining % extent) * m_stride[dim];
		remaining /= extent;
	}
	return m_ref.scalar(m_offset + relative);
}
/// `*this += other`: builds the expression `*this + other` and assigns it back through
/// operator=, which writes element-wise into the viewed elements.
template<typename ArrayViewType, typename ArrayViewShapeType>
template<typename T>
LIBRAPID_ALWAYS_INLINE auto
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator+=(const T &other)
  -> GeneralArrayView & {
	return *this = *this + other;
}

/// `*this -= other`, evaluated as `*this = *this - other`.
template<typename ArrayViewType, typename ArrayViewShapeType>
template<typename T>
LIBRAPID_ALWAYS_INLINE auto
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator-=(const T &other)
  -> GeneralArrayView & {
	return *this = *this - other;
}

/// `*this *= other`, evaluated as `*this = *this * other`.
template<typename ArrayViewType, typename ArrayViewShapeType>
template<typename T>
LIBRAPID_ALWAYS_INLINE auto
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator*=(const T &other)
  -> GeneralArrayView & {
	return *this = *this * other;
}

/// `*this /= other`, evaluated as `*this = *this / other`.
template<typename ArrayViewType, typename ArrayViewShapeType>
template<typename T>
LIBRAPID_ALWAYS_INLINE auto
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator/=(const T &other)
  -> GeneralArrayView & {
	return *this = *this / other;
}
/// Materialise the view: copy its elements, in row-major order, into a new ArrayType with
/// the view's shape.
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE auto
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::eval() const -> ArrayType {
	ArrayType res(m_shape);
	const int64_t ndim = m_shape.ndim();

	// An empty view (some extent is zero) yields an empty result. Without this check the
	// do/while below reads and writes one out-of-range element and never terminates.
	for (int64_t dim = 0; dim < ndim; ++dim) {
		if (m_shape[dim] == 0) return res;
	}

	// Row-major walk: `d` is the linear index into `res`, `p` the strided position (relative
	// to m_offset) of the same coordinate in the wrapped object.
	ShapeType coord = ShapeType::zeros(m_shape.ndim());
	int64_t d = 0, p = 0;
	int64_t idim = 0, adim = 0;
	do {
		res.storage()[d++] = m_ref.scalar(p + m_offset);
		for (idim = 0; idim < ndim; ++idim) {
			adim = ndim - idim - 1;
			if (++coord[adim] == m_shape[adim]) {
				coord[adim] = 0;
				p			= p - (m_shape[adim] - 1) * m_stride[adim];
			} else {
				p = p + m_stride[adim];
				break;
			}
		}
	} while (idim < ndim);
	return res;
}
/// Iterator positioned at the first slice along the leading dimension.
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE typename GeneralArrayView<ArrayViewType, ArrayViewShapeType>::Iterator
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::begin() const {
	return Iterator(*this, 0);
}

/// Iterator one past the last slice along the leading dimension (position shape()[0]).
template<typename ArrayViewType, typename ArrayViewShapeType>
LIBRAPID_ALWAYS_INLINE typename GeneralArrayView<ArrayViewType, ArrayViewShapeType>::Iterator
GeneralArrayView<ArrayViewType, ArrayViewShapeType>::end() const {
	const auto leadingExtent = m_shape[0];
	return Iterator(*this, leadingExtent);
}
} // namespace array
} // namespace librapid
// Support FMT printing
// NOTE(review): ARRAY_TYPE_FMT_IML and LIBRAPID_SIMPLE_IO_NORANGE are project macros defined
// elsewhere. Presumably they provide the fmt::formatter specialisation and stream output for
// every GeneralArrayView<T, S>; confirm against their definitions. `COMMA` stands in for ','
// so the template argument lists survive macro argument splitting.
ARRAY_TYPE_FMT_IML(typename T COMMA typename S, librapid::array::GeneralArrayView<T COMMA S>)
LIBRAPID_SIMPLE_IO_NORANGE(typename T COMMA typename S,
						   librapid::array::GeneralArrayView<T COMMA S>)
#endif // LIBRAPID_ARRAY_ARRAY_VIEW_HPP