Program Listing for File generalArrayView.hpp

Return to documentation for file (librapid/include/librapid/array/generalArrayView.hpp)

#ifndef LIBRAPID_ARRAY_ARRAY_VIEW_HPP
#define LIBRAPID_ARRAY_ARRAY_VIEW_HPP

namespace librapid {
    namespace typetraits {
        /// Type traits for GeneralArrayView<T, S>.
        ///
        /// Scalar, backend, shape and storage types are all looked up on the decayed viewed type
        /// `T`, which is also exposed as `ArrayViewType`.
        // NOTE(review): `ShapeType` here is the viewed type's shape type, not the view's own `S`
        // (which is what GeneralArrayView<T, S>::ShapeType is). Confirm this is intentional.
        template<typename T, typename S>
        struct TypeInfo<array::GeneralArrayView<T, S>> {
            using ArrayViewType = std::decay_t<T>;

            static constexpr detail::LibRapidType type = detail::LibRapidType::GeneralArrayView;

            using Scalar      = typename TypeInfo<ArrayViewType>::Scalar;
            using Backend     = typename TypeInfo<ArrayViewType>::Backend;
            using ShapeType   = typename TypeInfo<ArrayViewType>::ShapeType;
            using StorageType = typename TypeInfo<ArrayViewType>::StorageType;

            // Views are never evaluated with vectorised kernels (presumably because their
            // elements need not be contiguous in memory).
            static constexpr bool allowVectorisation = false;
        };

        LIBRAPID_DEFINE_AS_TYPE(typename T COMMA typename S, array::GeneralArrayView<T COMMA S>);
    } // namespace typetraits

    /// Wrap `array` in a GeneralArrayView whose shape type is the array's own `ShapeType`.
    ///
    /// `T` is deduced as a forwarding reference: an lvalue argument produces a view whose
    /// `ArrayViewType` is an lvalue reference, and an rvalue argument is passed on as an rvalue.
    template<typename T>
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto createGeneralArrayView(T &&array) {
        using View = array::GeneralArrayView<T, typename std::decay_t<T>::ShapeType>;
        return View(std::forward<T>(array));
    }

    /// Wrap `array` in a GeneralArrayView whose shape type is chosen explicitly by the caller.
    ///
    /// \tparam ShapeType Shape type the resulting view should use.
    /// \param array      The object to view; forwarded into the view's constructor.
    template<typename ShapeType, typename T>
    LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto createGeneralArrayViewShapeModifier(T &&array) {
        using View = array::GeneralArrayView<T, ShapeType>;
        return View(std::forward<T>(array));
    }

    namespace array {
        /// A strided window onto the storage of another array-like object.
        ///
        /// The view holds `m_ref` (the viewed object itself when `ArrayViewType` is an lvalue
        /// reference type, otherwise its own instance of it), plus a shape, a per-dimension stride
        /// and a starting offset. The element at coordinate `c` lives at
        /// `m_offset + sum_i c[i] * m_stride[i]` in `m_ref`'s storage. The assignment operators
        /// write through to that storage, and `scalar` / `eval` read from it.
        ///
        /// \tparam ArrayViewType      Type of the viewed object (may be a reference type).
        /// \tparam ArrayViewShapeType Shape type used by the view; defaults to the decayed viewed
        ///                            type's `ShapeType`.
        template<typename ArrayViewType,
                 typename ArrayViewShapeType = typename std::decay_t<ArrayViewType>::ShapeType>
        class GeneralArrayView {
        public:
            using BaseType       = typename std::decay_t<ArrayViewType>;
            using Scalar         = typename typetraits::TypeInfo<BaseType>::Scalar;
            using Reference      = BaseType &;
            using ConstReference = const BaseType &;
            using Backend        = typename typetraits::TypeInfo<BaseType>::Backend;
            using ShapeType      = ArrayViewShapeType;
            using StrideType     = Stride<ShapeType>;
            using StorageType    = typename typetraits::TypeInfo<BaseType>::StorageType;
            // Type produced by eval(): an owning array with the view's shape.
            using ArrayType      = array::ArrayContainer<ShapeType, StorageType>;
            using Iterator       = detail::ArrayIterator<GeneralArrayView>;

            // A view must always be bound to an object.
            GeneralArrayView() = delete;

            // LIBRAPID_ALWAYS_INLINE GeneralArrayView(ArrayViewType &array);

            /// View the whole of `array`: zero offset, shape and stride derived from array.shape().
            LIBRAPID_ALWAYS_INLINE GeneralArrayView(ArrayViewType &&array);

            LIBRAPID_ALWAYS_INLINE GeneralArrayView(const GeneralArrayView &other);

            LIBRAPID_ALWAYS_INLINE GeneralArrayView(GeneralArrayView &&other);

            /// Elementwise copy of `other` into the viewed elements (shapes must match). This
            /// writes through to the referenced storage; it does not rebind the view.
            LIBRAPID_ALWAYS_INLINE GeneralArrayView &operator=(const GeneralArrayView &other);

            // NOTE(review): unlike the copy-assignment above, this defaulted move-assignment would
            // replace the view's members rather than write elementwise. When ArrayViewType is a
            // reference type it is defined as deleted (reference member) and, as a deleted
            // defaulted move operation, is ignored by overload resolution, so rvalue views fall
            // back to the copy-assignment. Confirm the non-reference behaviour is intended.
            GeneralArrayView &operator=(GeneralArrayView &&other) noexcept = default;

            /// Assign to the single element of a 0-dimensional view.
            LIBRAPID_ALWAYS_INLINE GeneralArrayView &operator=(const Scalar &scalar);

            // The following overloads evaluate the right-hand side element by element (via its
            // `scalar(i)`) into the viewed elements; shapes must match.
            template<typename ShapeType_, typename StorageType_>
            LIBRAPID_ALWAYS_INLINE GeneralArrayView &
            operator=(const ArrayContainer<ShapeType_, StorageType_> &other);

            template<typename desc, typename Functor, typename... Args>
            LIBRAPID_ALWAYS_INLINE GeneralArrayView &
            operator=(const detail::Function<desc, Functor, Args...> &function);

            template<typename TransposeType>
            LIBRAPID_ALWAYS_INLINE GeneralArrayView &
            operator=(const array::Transpose<TransposeType> &transpose);

            template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
                     typename StorageTypeB, typename Alpha, typename Beta>
            LIBRAPID_ALWAYS_INLINE GeneralArrayView &
            operator=(const linalg::ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB,
                                                  StorageTypeB, Alpha, Beta> &matmul);

            /// View of the `index`-th sub-array along the leading dimension. The const overload
            /// returns a const view, which can neither be assigned through nor moved from.
            LIBRAPID_ALWAYS_INLINE const auto operator[](int64_t index) const;

            LIBRAPID_ALWAYS_INLINE auto operator[](int64_t index);

            /// Value of a 0-dimensional view, converted to `CAST`.
            template<typename CAST = Scalar>
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE CAST get() const;

            template<typename CAST>
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE explicit operator CAST() const;

            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t size() const;

            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ShapeType shape() const;

            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE StrideType stride() const;

            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t offset() const;

            // Low-level setters (used by operator[] to build sub-views); they do not validate
            // their arguments against the viewed storage.
            LIBRAPID_ALWAYS_INLINE void setShape(const ShapeType &shape);

            LIBRAPID_ALWAYS_INLINE void setStride(const StrideType &stride);

            LIBRAPID_ALWAYS_INLINE void setOffset(const int64_t &offset);

            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t ndim() const;

            /// Element at flat row-major position `index` within the view.
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalar(int64_t index) const;

            // Compound operators: compute `*this op other` and assign the result back.
            template<typename T>
            LIBRAPID_ALWAYS_INLINE GeneralArrayView &operator+=(const T &other);

            template<typename T>
            LIBRAPID_ALWAYS_INLINE GeneralArrayView &operator-=(const T &other);

            template<typename T>
            LIBRAPID_ALWAYS_INLINE GeneralArrayView &operator*=(const T &other);

            template<typename T>
            LIBRAPID_ALWAYS_INLINE GeneralArrayView &operator/=(const T &other);

            /// Copy the viewed elements into a new owning ArrayType.
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ArrayType eval() const;

            // Iteration over the leading dimension: positions [0, shape()[0]).
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator begin() const;
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator end() const;

            // Formatting helper (definition not in this listing).
            template<typename T, typename Char, size_t N, typename Ctx>
            void str(const fmt::formatter<T, Char> &format, char bracket, char separator,
                     const char (&formatString)[N], Ctx &ctx) const;

        private:
            // The viewed object (a reference when ArrayViewType is a reference type).
            ArrayViewType m_ref;
            ShapeType m_shape;
            // Per-dimension step, in elements of m_ref's storage.
            StrideType m_stride;
            // Storage index of the view's first element.
            int64_t m_offset = 0;
        };

        /// Construct a view spanning the whole of `array`.
        ///
        /// The view starts at offset 0, and its shape and strides are both built from the array's
        /// shape. When `ArrayViewType` is an lvalue reference the view aliases the caller's
        /// object. Otherwise the rvalue argument is moved into the view instead of being copied.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::GeneralArrayView(
          ArrayViewType &&array) :
                m_ref(std::forward<ArrayViewType>(array)),
                // m_ref is declared (and therefore initialised) first, so read the shape from it
                // rather than from `array`, which may have been moved from.
                m_shape(m_ref.shape()), m_stride(m_ref.shape()) {}

        /// Copy-construct a view with the same referenced object (or a copy of it, when
        /// `ArrayViewType` is not a reference type), shape, stride and offset.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::GeneralArrayView(
          const GeneralArrayView &other) :
                m_ref(other.m_ref),
                m_shape(other.m_shape), m_stride(other.m_stride),
                // The offset must be copied too. Otherwise a copy of a sub-view (operator[] sets a
                // non-zero offset) would silently address the start of the storage instead.
                m_offset(other.m_offset) {}

        /// Move-construct a view from `other`.
        ///
        /// When `ArrayViewType` is an lvalue reference the new view refers to the same object.
        /// Otherwise the held object is moved rather than copied. Shape, stride and offset are
        /// taken over unchanged.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::GeneralArrayView(
          GeneralArrayView &&other) :
                m_ref(std::forward<ArrayViewType>(other.m_ref)),
                m_shape(std::move(other.m_shape)), m_stride(std::move(other.m_stride)),
                m_offset(other.m_offset) {}

        /// Store `scalar` in the single element of a 0-dimensional view.
        ///
        /// The value is written to `m_ref.storage()[m_offset]`. The dimensionality check goes
        /// through LIBRAPID_ASSERT_WITH_EXCEPTION with std::invalid_argument.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE auto
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator=(const Scalar &scalar)
          -> GeneralArrayView & {
            LIBRAPID_ASSERT_WITH_EXCEPTION(std::invalid_argument,
                                           m_shape.ndim() == 0,
                                           "Cannot assign to a non-scalar ArrayView with {}",
                                           m_shape);
            this->m_ref.storage()[this->m_offset] = static_cast<Scalar>(scalar);
            return *this;
        }

        /// Copy the elements of `other` into the elements this view refers to.
        ///
        /// This writes through to the referenced storage; the view itself is not rebound. Element
        /// `d` of `other` (`other.scalar(d)`, row-major order) is stored at the strided position
        /// of this view's d-th coordinate. The shape check goes through
        /// LIBRAPID_ASSERT_WITH_EXCEPTION with std::range_error.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE GeneralArrayView<ArrayViewType, ArrayViewShapeType> &
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator=(
          const GeneralArrayView &other) {
            LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
                                           m_shape.operator==(other.shape()),
                                           "GeneralArrayView assignment shape mismatch. {} vs {}",
                                           m_shape,
                                           other.shape());

            const int64_t ndim = m_shape.ndim();

            // Nothing to copy for an empty view. The do/while below always writes at least once,
            // and `++coord[adim]` never matches a zero extent, so without this guard it would
            // keep writing past the view's data.
            for (int64_t dim = 0; dim < ndim; ++dim) {
                if (m_shape[dim] == 0) return *this;
            }

            ShapeType coord = ShapeType::zeros(m_shape.ndim());
            int64_t d = 0, p = 0;
            int64_t idim = 0, adim = 0;

            // Odometer traversal: `coord` advances with the last axis fastest, and `p` tracks the
            // matching strided offset incrementally.
            do {
                m_ref.storage()[p + m_offset] = other.scalar(d++);

                for (idim = 0; idim < ndim; ++idim) {
                    adim = ndim - idim - 1;
                    if (++coord[adim] == m_shape[adim]) {
                        // Axis wrapped: rewind its contribution to `p` and carry outwards.
                        coord[adim] = 0;
                        p           = p - (m_shape[adim] - 1) * m_stride[adim];
                    } else {
                        p = p + m_stride[adim];
                        break;
                    }
                }
            } while (idim < ndim); // idim == ndim only once every axis has wrapped.

            return *this;
        }

        /// Copy the elements of an ArrayContainer into the elements this view refers to.
        ///
        /// Element `d` of `other` (`other.scalar(d)`, row-major order) is stored at the strided
        /// position of this view's d-th coordinate. The shape check goes through
        /// LIBRAPID_ASSERT_WITH_EXCEPTION with std::range_error.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        template<typename ShapeType_, typename StorageType_>
        LIBRAPID_ALWAYS_INLINE GeneralArrayView<ArrayViewType, ArrayViewShapeType> &
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator=(
          const ArrayContainer<ShapeType_, StorageType_> &other) {
            LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
                                           m_shape.operator==(other.shape()),
                                           "GeneralArrayView assignment shape mismatch. {} vs {}",
                                           m_shape,
                                           other.shape());

            const int64_t ndim = m_shape.ndim();

            // Nothing to copy for an empty view. The do/while below always writes at least once,
            // and `++coord[adim]` never matches a zero extent, so without this guard it would
            // keep writing past the view's data.
            for (int64_t dim = 0; dim < ndim; ++dim) {
                if (m_shape[dim] == 0) return *this;
            }

            ShapeType coord = ShapeType::zeros(m_shape.ndim());
            int64_t d = 0, p = 0;
            int64_t idim = 0, adim = 0;

            // Odometer traversal: `coord` advances with the last axis fastest, and `p` tracks the
            // matching strided offset incrementally.
            do {
                m_ref.storage()[p + m_offset] = other.scalar(d++);

                for (idim = 0; idim < ndim; ++idim) {
                    adim = ndim - idim - 1;
                    if (++coord[adim] == m_shape[adim]) {
                        // Axis wrapped: rewind its contribution to `p` and carry outwards.
                        coord[adim] = 0;
                        p           = p - (m_shape[adim] - 1) * m_stride[adim];
                    } else {
                        p = p + m_stride[adim];
                        break;
                    }
                }
            } while (idim < ndim); // idim == ndim only once every axis has wrapped.

            return *this;
        }

        /// Evaluate a `detail::Function` expression into the elements this view refers to.
        ///
        /// Element `d` of the expression (`function.scalar(d)`, row-major order) is stored at the
        /// strided position of this view's d-th coordinate. The shape check goes through
        /// LIBRAPID_ASSERT_WITH_EXCEPTION with std::range_error.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator=(
          const detail::Function<desc, Functor, Args...> &function) -> GeneralArrayView & {
            LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
                                           m_shape.operator==(function.shape()),
                                           "GeneralArrayView assignment shape mismatch. {} vs {}",
                                           m_shape,
                                           function.shape());

            const int64_t ndim = m_shape.ndim();

            // Nothing to write for an empty view. The do/while below always writes at least once,
            // and `++coord[adim]` never matches a zero extent, so without this guard it would
            // keep writing past the view's data.
            for (int64_t dim = 0; dim < ndim; ++dim) {
                if (m_shape[dim] == 0) return *this;
            }

            ShapeType coord = ShapeType::zeros(m_shape.ndim());
            int64_t d = 0, p = 0;
            int64_t idim = 0, adim = 0;

            // Odometer traversal: `coord` advances with the last axis fastest, and `p` tracks the
            // matching strided offset incrementally.
            do {
                m_ref.storage()[p + m_offset] = function.scalar(d++);

                for (idim = 0; idim < ndim; ++idim) {
                    adim = ndim - idim - 1;
                    if (++coord[adim] == m_shape[adim]) {
                        // Axis wrapped: rewind its contribution to `p` and carry outwards.
                        coord[adim] = 0;
                        p           = p - (m_shape[adim] - 1) * m_stride[adim];
                    } else {
                        p = p + m_stride[adim];
                        break;
                    }
                }
            } while (idim < ndim); // idim == ndim only once every axis has wrapped.

            return *this;
        }

        /// Evaluate an `array::Transpose` expression into the elements this view refers to.
        ///
        /// Element `d` of the transpose (`transpose.scalar(d)`, row-major order) is stored at the
        /// strided position of this view's d-th coordinate. The shape check goes through
        /// LIBRAPID_ASSERT_WITH_EXCEPTION with std::range_error.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        template<typename TransposeType>
        LIBRAPID_ALWAYS_INLINE auto GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator=(
          const array::Transpose<TransposeType> &transpose) -> GeneralArrayView & {
            LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
                                           m_shape.operator==(transpose.shape()),
                                           "GeneralArrayView assignment shape mismatch. {} vs {}",
                                           m_shape,
                                           transpose.shape());

            const int64_t ndim = m_shape.ndim();

            // Nothing to write for an empty view. The do/while below always writes at least once,
            // and `++coord[adim]` never matches a zero extent, so without this guard it would
            // keep writing past the view's data.
            for (int64_t dim = 0; dim < ndim; ++dim) {
                if (m_shape[dim] == 0) return *this;
            }

            ShapeType coord = ShapeType::zeros(m_shape.ndim());
            int64_t d = 0, p = 0;
            int64_t idim = 0, adim = 0;

            // Odometer traversal: `coord` advances with the last axis fastest, and `p` tracks the
            // matching strided offset incrementally.
            do {
                m_ref.storage()[p + m_offset] = transpose.scalar(d++);

                for (idim = 0; idim < ndim; ++idim) {
                    adim = ndim - idim - 1;
                    if (++coord[adim] == m_shape[adim]) {
                        // Axis wrapped: rewind its contribution to `p` and carry outwards.
                        coord[adim] = 0;
                        p           = p - (m_shape[adim] - 1) * m_stride[adim];
                    } else {
                        p = p + m_stride[adim];
                        break;
                    }
                }
            } while (idim < ndim); // idim == ndim only once every axis has wrapped.

            return *this;
        }

        /// Evaluate a `linalg::ArrayMultiply` expression into the elements this view refers to.
        ///
        /// Element `d` of the product (`matmul.scalar(d)`, row-major order) is stored at the
        /// strided position of this view's d-th coordinate. The shape check goes through
        /// LIBRAPID_ASSERT_WITH_EXCEPTION with std::range_error.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
                 typename StorageTypeB, typename Alpha, typename Beta>
        LIBRAPID_ALWAYS_INLINE auto GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator=(
          const linalg::ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha,
                                      Beta> &matmul) -> GeneralArrayView & {
            LIBRAPID_ASSERT_WITH_EXCEPTION(std::range_error,
                                           m_shape.operator==(matmul.shape()),
                                           "GeneralArrayView assignment shape mismatch. {} vs {}",
                                           m_shape,
                                           matmul.shape());

            const int64_t ndim = m_shape.ndim();

            // Nothing to write for an empty view. The do/while below always writes at least once,
            // and `++coord[adim]` never matches a zero extent, so without this guard it would
            // keep writing past the view's data.
            for (int64_t dim = 0; dim < ndim; ++dim) {
                if (m_shape[dim] == 0) return *this;
            }

            ShapeType coord = ShapeType::zeros(m_shape.ndim());
            int64_t d = 0, p = 0;
            int64_t idim = 0, adim = 0;

            // Odometer traversal: `coord` advances with the last axis fastest, and `p` tracks the
            // matching strided offset incrementally.
            do {
                m_ref.storage()[p + m_offset] = matmul.scalar(d++);

                for (idim = 0; idim < ndim; ++idim) {
                    adim = ndim - idim - 1;
                    if (++coord[adim] == m_shape[adim]) {
                        // Axis wrapped: rewind its contribution to `p` and carry outwards.
                        coord[adim] = 0;
                        p           = p - (m_shape[adim] - 1) * m_stride[adim];
                    } else {
                        p = p + m_stride[adim];
                        break;
                    }
                }
            } while (idim < ndim); // idim == ndim only once every axis has wrapped.

            return *this;
        }

        /// Return a (const) view of the `index`-th sub-array along the leading dimension.
        ///
        /// The result refers to the same underlying object. Its shape drops the leading
        /// dimension, its strides are this view's remaining strides, and its offset is advanced by
        /// `index` leading-dimension strides. The bounds check goes through
        /// LIBRAPID_ASSERT_WITH_EXCEPTION with std::out_of_range.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE const auto
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator[](int64_t index) const {
            LIBRAPID_ASSERT_WITH_EXCEPTION(
              std::out_of_range,
              index >= 0 && index < static_cast<int64_t>(m_shape[0]),
              "Index {} out of bounds in GeneralArrayView::operator[] with leading dimension={}",
              index,
              m_shape[0]);
            auto view = createGeneralArrayViewShapeModifier<Shape>(m_ref);
            view.setShape(m_shape.subshape(1, ndim()));
            // Derive the sub-view from this view's own strides. Recomputing them from the shape
            // would ignore any non-default stride installed with setStride.
            if (ndim() == 1)
                view.setStride(Stride<Shape>({1}));
            else
                view.setStride(m_stride.substride(1, ndim()));
            view.setOffset(m_offset + index * m_stride[0]);
            return view;
        }

        /// Return a view of the `index`-th sub-array along the leading dimension.
        ///
        /// The result refers to the same underlying object, so assigning through it modifies the
        /// viewed data. Its shape drops the leading dimension, its strides are this view's
        /// remaining strides, and its offset is advanced by `index` leading-dimension strides. The
        /// bounds check goes through LIBRAPID_ASSERT_WITH_EXCEPTION with std::out_of_range.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE auto
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator[](int64_t index) {
            LIBRAPID_ASSERT_WITH_EXCEPTION(
              std::out_of_range,
              index >= 0 && index < static_cast<int64_t>(m_shape[0]),
              "Index {} out of bounds in GeneralArrayView::operator[] with leading dimension={}",
              index,
              m_shape[0]);
            auto view = createGeneralArrayViewShapeModifier<Shape>(m_ref);
            view.setShape(m_shape.subshape(1, ndim()));
            // Derive the sub-view from this view's own strides. Recomputing them from the shape
            // would ignore any non-default stride installed with setStride.
            if (ndim() == 1)
                view.setStride(Stride<Shape>({1}));
            else
                view.setStride(m_stride.substride(1, ndim()));
            view.setOffset(m_offset + index * m_stride[0]);
            return view;
        }

        /// Return the value of a 0-dimensional view converted to `CAST` (defaults to Scalar).
        ///
        /// The dimensionality check goes through LIBRAPID_ASSERT_WITH_EXCEPTION with
        /// std::invalid_argument.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        template<typename CAST>
        LIBRAPID_ALWAYS_INLINE CAST
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::get() const {
            LIBRAPID_ASSERT_WITH_EXCEPTION(
              std::invalid_argument,
              m_shape.ndim() == 0,
              "Can only cast a scalar ArrayView to a scalar object. ArrayView had {}",
              m_shape);
            // Explicit conversion, so CAST types that are only explicitly constructible from the
            // element type are supported as well.
            return static_cast<CAST>(scalar(0));
        }

        /// Explicit conversion of a 0-dimensional view to `CAST`.
        ///
        /// The target type is passed on to get<CAST>(), which performs the 0-d check and the
        /// conversion. Previously the code called get<Scalar>() and implicitly converted the
        /// result, which rejects types that are only explicitly constructible from Scalar.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        template<typename CAST>
        LIBRAPID_ALWAYS_INLINE
          GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator CAST() const {
            return this->template get<CAST>();
        }

        /// Number of elements in the view, as reported by its shape.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE auto
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::size() const -> int64_t {
            const int64_t elementCount = this->m_shape.size();
            return elementCount;
        }

        /// The view's shape (returned by value).
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE typename GeneralArrayView<ArrayViewType, ArrayViewShapeType>::ShapeType
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::shape() const {
            return this->m_shape;
        }

        /// The view's per-dimension strides (returned by value).
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE typename GeneralArrayView<ArrayViewType, ArrayViewShapeType>::StrideType
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::stride() const {
            return this->m_stride;
        }

        /// Storage index of the view's first element.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE int64_t
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::offset() const {
            return this->m_offset;
        }

        /// Replace the view's shape. No validation against the viewed storage is done.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE auto
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::setShape(const ShapeType &shape)
          -> void {
            this->m_shape = shape;
        }

        /// Replace the view's strides. No validation against the viewed storage is done.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE auto
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::setStride(const StrideType &stride)
          -> void {
            this->m_stride = stride;
        }

        /// Replace the storage index of the view's first element. No validation is done.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE auto
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::setOffset(const int64_t &offset)
          -> void {
            this->m_offset = offset;
        }

        /// Number of dimensions of the view, taken from its shape.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE int64_t
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::ndim() const {
            return this->m_shape.ndim();
        }

        /// Read the element at flat (row-major) position `index` within the view.
        ///
        /// `index` is unravelled into a coordinate using the view's shape (last axis fastest).
        /// That coordinate is dotted with the view's strides, and the result plus the view offset
        /// is passed to `m_ref.scalar`. A 0-dimensional view ignores `index` and reads the element
        /// at the view offset. `index` is not bounds-checked.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE auto
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::scalar(int64_t index) const -> auto {
            const int64_t dims = ndim();
            if (dims == 0) return m_ref.scalar(m_offset);

            ShapeType coord    = ShapeType::zeros(dims);
            int64_t remaining  = index;
            int64_t elementPos = 0;
            // Peel one coordinate off per axis, innermost first, accumulating its strided offset.
            for (int64_t axis = dims - 1; axis >= 0; --axis) {
                if (axis != dims - 1) remaining /= m_shape[axis + 1];
                coord[axis] = remaining % m_shape[axis];
                elementPos += coord[axis] * m_stride[axis];
            }
            return m_ref.scalar(m_offset + elementPos);
        }

        /// Compound addition: compute `*this + other` and assign it back into this view.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        template<typename T>
        LIBRAPID_ALWAYS_INLINE auto
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator+=(const T &other)
          -> GeneralArrayView & {
            return *this = *this + other;
        }

        /// Compound subtraction: compute `*this - other` and assign it back into this view.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        template<typename T>
        LIBRAPID_ALWAYS_INLINE auto
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator-=(const T &other)
          -> GeneralArrayView & {
            return *this = *this - other;
        }

        /// Compound multiplication: compute `*this * other` and assign it back into this view.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        template<typename T>
        LIBRAPID_ALWAYS_INLINE auto
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator*=(const T &other)
          -> GeneralArrayView & {
            return *this = *this * other;
        }

        /// Compound division: compute `*this / other` and assign it back into this view.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        template<typename T>
        LIBRAPID_ALWAYS_INLINE auto
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::operator/=(const T &other)
          -> GeneralArrayView & {
            return *this = *this / other;
        }

        /// Copy the viewed elements into a new ArrayType with the view's shape.
        ///
        /// Elements are gathered in row-major order (last axis fastest) from
        /// `m_ref.scalar(m_offset + strided position)` and stored contiguously in the result.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE auto
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::eval() const -> ArrayType {
            ArrayType res(m_shape);
            const int64_t ndim = m_shape.ndim();

            // An empty view yields an empty result. The do/while below always copies at least one
            // element, and `++coord[adim]` never matches a zero extent, so without this guard it
            // would write past the end of `res`.
            for (int64_t dim = 0; dim < ndim; ++dim) {
                if (m_shape[dim] == 0) return res;
            }

            ShapeType coord = ShapeType::zeros(m_shape.ndim());
            int64_t d = 0, p = 0;
            int64_t idim = 0, adim = 0;

            // Odometer traversal: `coord` advances with the last axis fastest, and `p` tracks the
            // matching strided offset incrementally.
            do {
                res.storage()[d++] = m_ref.scalar(p + m_offset);

                for (idim = 0; idim < ndim; ++idim) {
                    adim = ndim - idim - 1;
                    if (++coord[adim] == m_shape[adim]) {
                        // Axis wrapped: rewind its contribution to `p` and carry outwards.
                        coord[adim] = 0;
                        p           = p - (m_shape[adim] - 1) * m_stride[adim];
                    } else {
                        p = p + m_stride[adim];
                        break;
                    }
                }
            } while (idim < ndim); // idim == ndim only once every axis has wrapped.

            return res;
        }

        /// Iterator positioned at the first entry (index 0) along the leading dimension.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE typename GeneralArrayView<ArrayViewType, ArrayViewShapeType>::Iterator
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::begin() const {
            return Iterator(*this, 0);
        }

        /// Iterator positioned one past the last entry along the leading dimension.
        template<typename ArrayViewType, typename ArrayViewShapeType>
        LIBRAPID_ALWAYS_INLINE typename GeneralArrayView<ArrayViewType, ArrayViewShapeType>::Iterator
        GeneralArrayView<ArrayViewType, ArrayViewShapeType>::end() const {
            const auto leadingExtent = m_shape[0];
            return Iterator(*this, leadingExtent);
        }
    } // namespace array
} // namespace librapid

// Support FMT printing
// Registers GeneralArrayView with the library's formatting and stream-output helper macros.
// Their expansions are defined elsewhere (presumably a fmt::formatter specialisation and an
// operator<< overload respectively -- TODO confirm against the macro definitions).
ARRAY_TYPE_FMT_IML(typename T COMMA typename S, librapid::array::GeneralArrayView<T COMMA S>)
LIBRAPID_SIMPLE_IO_NORANGE(typename T COMMA typename S,
                           librapid::array::GeneralArrayView<T COMMA S>)

#endif // LIBRAPID_ARRAY_ARRAY_VIEW_HPP