Program Listing for File arrayView.hpp

Return to documentation for file (librapid/include/librapid/array/arrayView.hpp)

#ifndef LIBRAPID_ARRAY_ARRAY_VIEW_HPP
#define LIBRAPID_ARRAY_ARRAY_VIEW_HPP

namespace librapid {
    namespace typetraits {
        /// Type information for ArrayView<T>.
        ///
        /// The scalar and backend types are forwarded from the TypeInfo of the decayed
        /// viewed type, so a view reports the same element type and backend as the
        /// array it refers to.
        template<typename T>
        struct TypeInfo<array::ArrayView<T>> {
            using Scalar  = typename TypeInfo<std::decay_t<T>>::Scalar;
            using Backend = typename TypeInfo<std::decay_t<T>>::Backend;

            static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayView;

            // Vectorised evaluation is switched off for views.
            static constexpr bool allowVectorisation = false;
        };

        LIBRAPID_DEFINE_AS_TYPE(typename T, array::ArrayView<T>);
    } // namespace typetraits

    namespace array {
        /// A non-owning view onto (part of) an array object of type T.
        ///
        /// The view holds a reference to the viewed object together with its own shape,
        /// per-dimension strides and an element offset. Indexing with operator[] yields
        /// another view onto the same object (no element data is copied); eval() copies
        /// the viewed elements into a new ArrayType.
        ///
        /// Lifetime: only a reference is stored, so the viewed object must outlive every
        /// view onto it. Construction from an rvalue is deleted for that reason.
        ///
        /// @tparam T The viewed type; std::decay_t<T> must have a TypeInfo specialisation
        ///           providing Scalar and Backend.
        template<typename T>
        class ArrayView {
        public:
            // using ArrayType       = T;
            using BaseType       = typename std::decay_t<T>;
            using Scalar         = typename typetraits::TypeInfo<BaseType>::Scalar;
            using Reference      = BaseType &;
            using ConstReference = const BaseType &;
            using Backend        = typename typetraits::TypeInfo<BaseType>::Backend;
            // Concrete array type produced by eval().
            using ArrayType      = Array<Scalar, Backend>;
            using StrideType     = typename ArrayType::StrideType;
            using ShapeType      = typename ArrayType::ShapeType;
            using Iterator       = detail::ArrayIterator<ArrayView>;

            ArrayView() = delete;

            /// Create a view covering the whole of `array`.
            explicit ArrayView(T &array);

            /// Views of temporaries are rejected because the stored reference would dangle.
            explicit ArrayView(T &&array) = delete;

            ArrayView(const ArrayView &other) = default;

            ArrayView(ArrayView &&other) = default;

            // NOTE(review): m_ref is a reference member, so this defaulted copy assignment
            // operator is implicitly defined as deleted; selecting it is ill-formed.
            // Because a copy assignment is user-declared, no move assignment is
            // implicitly declared either.
            ArrayView &operator=(const ArrayView &other) = default;

            // ArrayView &operator=(ArrayView &&other) noexcept = default;

            /// Write `scalar` to the single element addressed by a 0-d view.
            ArrayView &operator=(const Scalar &scalar);

            /// Copy the elements of `other` (which must have the same shape) into the
            /// elements addressed by this view.
            template<typename RefType>
            ArrayView &operator=(const ArrayRef<RefType> &other);

            /// Sub-view selecting `index` along the leading dimension.
            const ArrayView<T> operator[](int64_t index) const;

            ArrayView<T> operator[](int64_t index);

            /// Value of a 0-d view, converted to CAST.
            template<typename CAST = Scalar>
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE CAST get() const;

            /// Explicit conversion of a 0-d view to CAST.
            template<typename CAST>
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE explicit operator CAST() const;

            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ShapeType shape() const;

            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE StrideType stride() const;

            /// Element offset of the view's origin within the viewed object's storage.
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t offset() const;

            // Raw setters: no consistency checks between shape, stride and offset.
            void setShape(const ShapeType &shape);

            void setStride(const StrideType &stride);

            void setOffset(const int64_t &offset);

            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t ndim() const;

            /// Element at row-major flat position `index` within the view.
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalar(int64_t index) const;

            /// Copy the viewed elements into a new ArrayType with the view's shape.
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ArrayType eval() const;

            /// Iterate along the leading dimension; end() is positioned at shape()[0].
            LIBRAPID_NODISCARD Iterator begin() const;
            LIBRAPID_NODISCARD Iterator end() const;

            /// String representation of the view. `format` is presumably the per-element
            /// format string (definition not visible here — confirm).
            LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const;

        private:
            // Viewed object (non-owning).
            T &m_ref;
            // Extents of this view.
            ShapeType m_shape;
            // Per-dimension element strides used to address m_ref's storage.
            StrideType m_stride;
            // Element offset of this view's first element within m_ref's storage.
            int64_t m_offset = 0;
        };

        /// Construct a view spanning the whole of `array`.
        ///
        /// The stride is derived from the shape (the same Shape -> Stride construction
        /// that operator[] relies on), and the offset keeps its in-class default of 0.
        /// m_shape is declared before m_stride, so it is already initialised when
        /// m_stride is built from it.
        template<typename T>
        ArrayView<T>::ArrayView(T &array) :
                m_ref(array),
                m_shape(array.shape()),
                m_stride(m_shape) {}

        template<typename T>
        ArrayView<T> &ArrayView<T>::operator=(const Scalar &scalar) {
            LIBRAPID_ASSERT(m_shape.ndim() == 0, "Cannot assign to a non-scalar ArrayView.");
            m_ref.storage()[m_offset] = static_cast<Scalar>(scalar);
            return *this;
        }

        /// Copy the elements of `other` into the elements addressed by this view.
        ///
        /// `other` must have exactly the same shape as the view. Source elements are read
        /// by flat index (other.scalar(0), other.scalar(1), ...). The destination position
        /// is advanced through the view's own shape and strides, with the last dimension
        /// varying fastest.
        ///
        /// @param other Source array reference; its shape must equal shape().
        /// @return *this
        template<typename T>
        template<typename RefType>
        ArrayView<T> &ArrayView<T>::operator=(const ArrayRef<RefType> &other) {
            // The explicit member call (rather than `==`) is kept from the original.
            // NOTE(review): presumably this avoids picking up an element-wise operator==.
            LIBRAPID_ASSERT(m_shape.operator==(other.shape()),
                            "Cannot assign to an ArrayView from an array with a different shape.");

            const int64_t ndim = m_shape.ndim();

            // A zero extent means there is nothing to write. Without this check the
            // odometer loop below would write one element out of bounds and then never
            // terminate, because ++coord[k] can never become equal to an extent of 0.
            for (int64_t i = 0; i < ndim; ++i) {
                if (m_shape[i] == 0) return *this;
            }

            ShapeType coord = ShapeType::zeros(m_shape.ndim());
            int64_t d = 0, p = 0;
            int64_t idim = 0, adim = 0;

            // Odometer walk. On each pass:
            //  - write one element;
            //  - advance the multi-index `coord` (last dimension fastest), keeping
            //    p == sum(coord[k] * m_stride[k]).
            // The loop ends once every dimension has wrapped, i.e. when idim reaches ndim.
            do {
                m_ref.storage()[p + m_offset] = other.scalar(d++);

                for (idim = 0; idim < ndim; ++idim) {
                    adim = ndim - idim - 1;
                    if (++coord[adim] == m_shape[adim]) {
                        coord[adim] = 0;
                        p           = p - (m_shape[adim] - 1) * m_stride[adim];
                    } else {
                        p = p + m_stride[adim];
                        break;
                    }
                }
            } while (idim < ndim);

            // Required: flowing off the end of a value-returning function is undefined behaviour.
            return *this;
        }

        /// Return a sub-view selecting `index` along the leading dimension.
        ///
        /// The sub-view refers to the same underlying object. It drops the leading
        /// extent, keeps the remaining extents and strides of *this* view, and shifts the
        /// offset by index * m_stride[0]. Using the view's own m_stride (instead of
        /// recomputing a contiguous stride from the shape) keeps indexing consistent with
        /// scalar() and eval(). It also keeps indexing correct for views whose stride was
        /// changed via setStride().
        ///
        /// @param index Position along the leading dimension, in [0, shape()[0]).
        /// @return The sub-view.
        template<typename T>
        auto ArrayView<T>::operator[](int64_t index) const -> const ArrayView<T> {
            LIBRAPID_ASSERT(ndim() > 0, "Cannot index into a scalar (0-dimensional) ArrayView");
            LIBRAPID_ASSERT(
              index >= 0 && index < static_cast<int64_t>(m_shape[0]),
              "Index {} out of bounds in ArrayView::operator[] with leading dimension={}",
              index,
              m_shape[0]);

            ArrayView<T> view(m_ref);
            view.setShape(m_shape.subshape(1, ndim()));
            // Indexing a 1-D view yields a 0-d view, for which a unit stride is stored.
            if (ndim() == 1)
                view.setStride(Stride({1}));
            else
                view.setStride(m_stride.subshape(1, ndim()));
            view.setOffset(m_offset + index * m_stride[0]);
            return view;
        }

        /// Return a sub-view selecting `index` along the leading dimension.
        ///
        /// The sub-view refers to the same underlying object. It drops the leading
        /// extent, keeps the remaining extents and strides of *this* view, and shifts the
        /// offset by index * m_stride[0]. Using the view's own m_stride (instead of
        /// recomputing a contiguous stride from the shape) keeps indexing consistent with
        /// scalar() and eval(). It also keeps indexing correct for views whose stride was
        /// changed via setStride().
        ///
        /// @param index Position along the leading dimension, in [0, shape()[0]).
        /// @return The sub-view (assignable, e.g. `view[i] = value` on a 1-D view).
        template<typename T>
        auto ArrayView<T>::operator[](int64_t index) -> ArrayView<T> {
            LIBRAPID_ASSERT(ndim() > 0, "Cannot index into a scalar (0-dimensional) ArrayView");
            LIBRAPID_ASSERT(
              index >= 0 && index < static_cast<int64_t>(m_shape[0]),
              "Index {} out of bounds in ArrayView::operator[] with leading dimension={}",
              index,
              m_shape[0]);

            ArrayView<T> view(m_ref);
            view.setShape(m_shape.subshape(1, ndim()));
            // Indexing a 1-D view yields a 0-d view, for which a unit stride is stored.
            if (ndim() == 1)
                view.setStride(Stride({1}));
            else
                view.setStride(m_stride.subshape(1, ndim()));
            view.setOffset(m_offset + index * m_stride[0]);
            return view;
        }

        /// Return the value of a 0-dimensional view, converted to CAST.
        ///
        /// static_cast is used so that both implicit and explicit conversions from the
        /// element type are supported.
        ///
        /// @tparam CAST Target type (defaults to Scalar in the declaration).
        /// @return The single element addressed by the view, as CAST.
        template<typename T>
        template<typename CAST>
        CAST ArrayView<T>::get() const {
            LIBRAPID_ASSERT(m_shape.ndim() == 0,
                            "Can only cast a scalar ArrayView to a scalar object");
            return static_cast<CAST>(scalar(0));
        }

        /// Explicit conversion of a 0-dimensional view to CAST.
        ///
        /// Forwards the requested target type to get<CAST>() so the conversion happens
        /// from the element directly. Plain get() would first produce a Scalar and then
        /// require an implicit Scalar -> CAST conversion.
        template<typename T>
        template<typename CAST>
        ArrayView<T>::operator CAST() const {
            return this->template get<CAST>();
        }

        template<typename T>
        auto ArrayView<T>::shape() const -> ShapeType {
            return m_shape;
        }

        template<typename T>
        auto ArrayView<T>::stride() const -> StrideType {
            return m_stride;
        }

        template<typename T>
        auto ArrayView<T>::offset() const -> int64_t {
            return m_offset;
        }

        /// Replace the view's extents. No validation is done against the stride or offset.
        template<typename T>
        void ArrayView<T>::setShape(const ShapeType &shape) {
            this->m_shape = shape;
        }

        /// Replace the view's per-dimension strides. No validation is performed.
        template<typename T>
        void ArrayView<T>::setStride(const StrideType &stride) {
            this->m_stride = stride;
        }

        /// Replace the element offset of the view's origin. No bounds check is performed.
        template<typename T>
        void ArrayView<T>::setOffset(const int64_t &offset) {
            this->m_offset = offset;
        }

        /// Number of dimensions of the view (0 for a view of a single element).
        template<typename T>
        int64_t ArrayView<T>::ndim() const {
            return static_cast<int64_t>(m_shape.ndim());
        }

        /// Return the element at row-major flat position `index` within the view.
        ///
        /// The flat index is unravelled into per-dimension coordinates (last dimension
        /// fastest). The coordinates are mapped through the view's strides, and the
        /// result is added to the view's offset before reading from the viewed object.
        /// A 0-d view ignores `index` and reads the element at its offset.
        template<typename T>
        auto ArrayView<T>::scalar(int64_t index) const -> auto {
            const int64_t dims = ndim();
            if (dims == 0) return m_ref.scalar(m_offset);

            // Unravel `index` into coordinates, innermost dimension first.
            ShapeType coord   = ShapeType::zeros(dims);
            int64_t remaining = index;
            for (int64_t k = dims - 1; k >= 0; --k) {
                if (k != dims - 1) remaining /= m_shape[k + 1];
                coord[k] = remaining % m_shape[k];
            }

            // Map the coordinates through the strides.
            int64_t linear = 0;
            for (int64_t k = 0; k < dims; ++k) linear += coord[k] * m_stride[k];

            return m_ref.scalar(m_offset + linear);
        }

        /// Copy the elements addressed by the view into a new ArrayType of the same shape.
        ///
        /// The result's storage is filled sequentially while the source position walks
        /// the view's shape/strides, with the last dimension varying fastest. This also
        /// gathers the elements of non-contiguous views correctly.
        ///
        /// @return A new ArrayType holding a copy of the viewed elements.
        template<typename T>
        auto ArrayView<T>::eval() const -> ArrayType {
            ArrayType res(m_shape);
            const int64_t ndim = m_shape.ndim();

            // A zero extent means the view holds no elements. Return the empty result
            // directly: the odometer loop below assumes at least one element. With a zero
            // extent it would write past the end of `res` and never terminate.
            for (int64_t i = 0; i < ndim; ++i) {
                if (m_shape[i] == 0) return res;
            }

            ShapeType coord = ShapeType::zeros(m_shape.ndim());
            int64_t d = 0, p = 0;
            int64_t idim = 0, adim = 0;

            // Odometer walk. On each pass:
            //  - copy one element;
            //  - advance `coord` (last dimension fastest), keeping
            //    p == sum(coord[k] * m_stride[k]).
            // The loop ends once every dimension has wrapped.
            do {
                res.storage()[d++] = m_ref.scalar(p + m_offset);

                for (idim = 0; idim < ndim; ++idim) {
                    adim = ndim - idim - 1;
                    if (++coord[adim] == m_shape[adim]) {
                        coord[adim] = 0;
                        p           = p - (m_shape[adim] - 1) * m_stride[adim];
                    } else {
                        p = p + m_stride[adim];
                        break;
                    }
                }
            } while (idim < ndim);

            return res;
        }

        template<typename T>
        auto ArrayView<T>::begin() const -> Iterator {
            return Iterator(*this, 0);
        }

        /// Iterator positioned one past the last entry along the leading dimension.
        ///
        /// A 0-dimensional view has no leading dimension, so reading m_shape[0] would be
        /// meaningless. That case is rejected up front.
        template<typename T>
        auto ArrayView<T>::end() const -> Iterator {
            LIBRAPID_ASSERT(ndim() > 0, "Cannot iterate over a scalar (0-dimensional) ArrayView");
            return Iterator(*this, m_shape[0]);
        }
    } // namespace array
} // namespace librapid

// Support FMT printing
// Only when {fmt} is available (FMT_API is defined): register ArrayView<T> with the
// project's I/O helper macros so views can be printed/formatted.
// NOTE(review): both macros are defined elsewhere. The NORANGE variant presumably stops
// fmt from treating ArrayView (which has begin()/end()) as a range — confirm.
#ifdef FMT_API
LIBRAPID_SIMPLE_IO_IMPL(typename T, librapid::array::ArrayView<T>)
LIBRAPID_SIMPLE_IO_NORANGE(typename T, librapid::array::ArrayView<T>)
#endif // FMT_API

#endif // LIBRAPID_ARRAY_ARRAY_VIEW_HPP