Program Listing for File strideTools.hpp#

Return to documentation for file (librapid/include/librapid/array/strideTools.hpp)

#ifndef LIBRAPID_ARRAY_STRIDE_TOOLS_HPP
#define LIBRAPID_ARRAY_STRIDE_TOOLS_HPP

namespace librapid {
    namespace typetraits {
        LIBRAPID_DEFINE_AS_TYPE(typename ShapeType, Stride<ShapeType>);
    }

    // class Stride : public Shape {
    // public:
    //  /// Default Constructor
    //  LIBRAPID_ALWAYS_INLINE Stride() = default;

    //  /// Construct a Stride from a Shape object. This will assume that the data represented by
    //  /// the Shape object is a contiguous block of memory, and will calculate the corresponding
    //  /// strides based on this.
    //  /// \param shape
    //  LIBRAPID_ALWAYS_INLINE Stride(const Shape &shape);

    //  /// Copy a Stride object
    //  /// \param other The Stride object to copy.
    //  LIBRAPID_ALWAYS_INLINE Stride(const Stride &other) = default;

    //  /// Move a Stride object
    //  /// \param other The Stride object to move.
    //  LIBRAPID_ALWAYS_INLINE Stride(Stride &&other) noexcept = default;

    //  /// Assign a Stride object to this Stride object.
    //  /// \param other The Stride object to assign.
    //  LIBRAPID_ALWAYS_INLINE Stride &operator=(const Stride &other) = default;

    //  /// Move a Stride object to this Stride object.
    //  /// \param other The Stride object to move.
    //  LIBRAPID_ALWAYS_INLINE Stride &operator=(Stride &&other) noexcept = default;
    // };

    // LIBRAPID_ALWAYS_INLINE Stride::Stride(const Shape &shape) : Shape(shape) {
    //  if (this->m_dims == 0) {
    //      // Edge case for a zero-dimensional array
    //      this->m_data[0] = 1;
    //      return;
    //  }

    //  uint32_t tmp[MaxDimensions] {0};
    //  tmp[this->m_dims - 1] = 1;
    //  for (size_t i = this->m_dims - 1; i > 0; --i) tmp[i - 1] = tmp[i] * this->m_data[i];
    //  for (size_t i = 0; i < this->m_dims; ++i) this->m_data[i] = tmp[i];
    // }

    template<typename ShapeType_>
    class Stride {
    public:
        using ShapeType = ShapeType_;
        using IndexType = typename std::decay_t<decltype(std::declval<ShapeType>()[0])>;
        static constexpr size_t MaxDimensions = ShapeType::MaxDimensions;

        LIBRAPID_ALWAYS_INLINE Stride() = default;
        LIBRAPID_ALWAYS_INLINE Stride(const ShapeType &shape);
        LIBRAPID_ALWAYS_INLINE Stride(const Stride &other)                = default;
        LIBRAPID_ALWAYS_INLINE Stride(Stride &&other) noexcept            = default;
        LIBRAPID_ALWAYS_INLINE Stride &operator=(const Stride &other)     = default;
        LIBRAPID_ALWAYS_INLINE Stride &operator=(Stride &&other) noexcept = default;

        LIBRAPID_ALWAYS_INLINE auto operator[](size_t index) const -> IndexType;
        LIBRAPID_ALWAYS_INLINE auto operator[](size_t index) -> IndexType &;

        LIBRAPID_ALWAYS_INLINE auto ndim() const { return m_data.ndim(); }
        LIBRAPID_ALWAYS_INLINE auto substride(size_t start, size_t end) const -> Stride<Shape>;

        LIBRAPID_ALWAYS_INLINE auto data() const -> const ShapeType &;
        LIBRAPID_ALWAYS_INLINE auto data() -> ShapeType &;

        template<typename T_, typename Char, typename Ctx>
        LIBRAPID_ALWAYS_INLINE void str(const fmt::formatter<T_, Char> &format, Ctx &ctx) const;

    protected:
        ShapeType m_data;
    };

    template<typename ShapeType>
    LIBRAPID_ALWAYS_INLINE Stride<ShapeType>::Stride(const ShapeType &shape) : m_data(shape) {
        if (this->m_data.size() == 0) {
            // Edge case for a zero-dimensional array
            this->m_data[0] = 1;
            return;
        }

        uint32_t tmp[MaxDimensions] {0};
        tmp[shape.ndim() - 1] = 1;
        for (size_t i = shape.ndim() - 1; i > 0; --i) tmp[i - 1] = tmp[i] * this->m_data[i];
        for (size_t i = 0; i < shape.ndim(); ++i) this->m_data[i] = tmp[i];
    }

    template<typename ShapeType>
    LIBRAPID_ALWAYS_INLINE auto Stride<ShapeType>::operator[](size_t index) const -> IndexType {
        return this->m_data[index];
    }

    template<typename ShapeType>
    LIBRAPID_ALWAYS_INLINE auto Stride<ShapeType>::operator[](size_t index) -> IndexType & {
        return this->m_data[index];
    }

    template<typename ShapeType>
    LIBRAPID_ALWAYS_INLINE auto Stride<ShapeType>::substride(size_t start, size_t end) const
      -> Stride<Shape> {
        LIBRAPID_ASSERT(start < end, "Start index must be less than end index");
        LIBRAPID_ASSERT(end <= this->m_data.ndim(), "End index must be less than ndim()");

        Stride<Shape> res;
        res.data() = data().subshape(start, end);
        return res;
    }

    template<typename ShapeType>
    LIBRAPID_ALWAYS_INLINE auto Stride<ShapeType>::data() const -> const ShapeType & {
        return this->m_data;
    }

    template<typename ShapeType>
    LIBRAPID_ALWAYS_INLINE auto Stride<ShapeType>::data() -> ShapeType & {
        return this->m_data;
    }

    template<typename ShapeType>
    template<typename T_, typename Char, typename Ctx>
    LIBRAPID_ALWAYS_INLINE void Stride<ShapeType>::str(const fmt::formatter<T_, Char> &format,
                                                       Ctx &ctx) const {
        fmt::format_to(ctx.out(), "Stride(");
        for (size_t i = 0; i < m_data.ndim(); ++i) {
            format.format(m_data[i], ctx);
            if (i != m_data.ndim() - 1) fmt::format_to(ctx.out(), ", ");
        }
        fmt::format_to(ctx.out(), ")");
    }
} // namespace librapid

// Support FMT printing
template<typename T>
struct fmt::formatter<librapid::Stride<T>> : fmt::formatter<librapid::Shape> {
    template<typename FormatContext>
    auto format(const librapid::Stride<T> &stride, FormatContext &ctx) {
        return fmt::formatter<librapid::Shape>::format(stride, ctx);
    }
};

#endif // LIBRAPID_ARRAY_STRIDE_TOOLS_HPP