Program Listing for File function.hpp

Return to documentation for file (librapid/include/librapid/array/function.hpp)

#ifndef LIBRAPID_ARRAY_FUNCTION_HPP
#define LIBRAPID_ARRAY_FUNCTION_HPP

namespace librapid {
    namespace typetraits {
        // Extract allowVectorisation from the input types
        template<typename First, typename... T>
        constexpr bool checkAllowVectorisation() {
            if constexpr (sizeof...(T) == 0) {
                return TypeInfo<std::decay_t<First>>::allowVectorisation;
            } else {
                using T1 = typename TypeInfo<std::decay_t<First>>::Scalar;
                return TypeInfo<std::decay_t<First>>::allowVectorisation &&
                       checkAllowVectorisation<T...>() &&
                       (std::is_same_v<T1, typename TypeInfo<std::decay_t<T>>::Scalar> && ...);
            }
        }

        template<typename First, typename... Rest>
        constexpr auto commonBackend() {
            using FirstBackend = typename TypeInfo<std::decay_t<First>>::Backend;
            if constexpr (sizeof...(Rest) == 0) {
                return FirstBackend {};
            } else {
                using RestBackend = decltype(commonBackend<Rest...>());
                if constexpr (std::is_same_v<FirstBackend, backend::OpenCLIfAvailable> ||
                              std::is_same_v<RestBackend, backend::OpenCLIfAvailable>) {
                    return backend::OpenCLIfAvailable {};
                } else if constexpr (std::is_same_v<FirstBackend, backend::CUDAIfAvailable> ||
                                     std::is_same_v<RestBackend, backend::CUDAIfAvailable>) {
                    return backend::CUDAIfAvailable {};
                } else {
                    return backend::CPU {};
                }
            }
        }

        /// Type traits for an array-function expression (`detail::Function`).
        ///
        /// - `Scalar` is whatever `Functor_` returns when invoked on the operands' scalar
        ///   types.
        /// - `Backend` is the common backend of all operands (see `commonBackend`).
        /// - `ShapeType` is chosen by `detail::ShapeTypeHelper` from the operands' shape types.
        ///
        /// Every operand type is decayed before its `TypeInfo` is queried. Reference- or
        /// cv-qualified operand types therefore resolve to the traits of the underlying type.
        template<typename desc, typename Functor_, typename... Args>
        struct TypeInfo<::librapid::detail::Function<desc, Functor_, Args...>> {
            static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayFunction;
            using Scalar                               = decltype(std::declval<Functor_>()(
                std::declval<typename TypeInfo<std::decay_t<Args>>::Scalar>()...));
            using Packet                               = typename TypeInfo<Scalar>::Packet;
            using Backend                              = decltype(commonBackend<Args...>());
            // Decay the operands here too, matching Scalar/Backend/allowVectorisation; the
            // non-decayed form would look up TypeInfo<T &> for reference operands.
            using ShapeType = typename detail::ShapeTypeHelper<
              typename TypeInfo<std::decay_t<Args>>::ShapeType...>::Type;

            // Materialising the expression yields an Array of the result scalar/backend.
            using ArrayType   = Array<Scalar, Backend>;
            using StorageType = typename TypeInfo<ArrayType>::StorageType;

            static constexpr bool allowVectorisation = checkAllowVectorisation<Args...>();

            // Operator support is inherited from the result scalar type.
            static constexpr bool supportsArithmetic = TypeInfo<Scalar>::supportsArithmetic;
            static constexpr bool supportsLogical    = TypeInfo<Scalar>::supportsLogical;
            static constexpr bool supportsBinary     = TypeInfo<Scalar>::supportsBinary;
        };

        // NOTE(review): LIBRAPID_DEFINE_AS_TYPE is defined elsewhere. Presumably it marks
        // detail::Function as a LibRapid type for the library's type-detection traits; confirm
        // against the macro definition. COMMA (presumably expanding to ',') keeps the template
        // parameter/argument lists from being split into separate macro arguments.
        LIBRAPID_DEFINE_AS_TYPE(typename desc COMMA typename Functor_ COMMA typename... Args,
                                ::librapid::detail::Function<desc COMMA Functor_ COMMA Args...>);
    } // namespace typetraits

    namespace detail {
        // Descriptor is defined in "forward.hpp"

        /// Fetch the packet at `index` from a non-scalar operand.
        ///
        /// The operand's own `packet()` must return exactly `Packet`; any mismatch is rejected
        /// at compile time.
        template<
          typename Packet, typename T,
          typename std::enable_if_t<
            typetraits::TypeInfo<T>::type != ::librapid::detail::LibRapidType::Scalar, int> = 0>
        LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetExtractor(const T &operand,
                                                                         size_t index) {
            using OperandPacket = decltype(operand.packet(index));
            static_assert(std::is_same_v<Packet, OperandPacket>, "Packet types do not match");
            return operand.packet(index);
        }

        /// Build a packet from a scalar operand via `Packet(value)`; the index is ignored.
        template<
          typename Packet, typename T,
          typename std::enable_if_t<
            typetraits::TypeInfo<T>::type == ::librapid::detail::LibRapidType::Scalar, int> = 0>
        LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetExtractor(const T &value, size_t) {
            return Packet(value);
        }

        /// Fetch the element at `index` from a non-scalar operand via its `scalar()` method.
        template<typename T, typename std::enable_if_t<typetraits::TypeInfo<T>::type !=
                                                         ::librapid::detail::LibRapidType::Scalar,
                                                       int> = 0>
        LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalarExtractor(const T &operand, size_t index) {
            return operand.scalar(index);
        }

        /// Return a scalar operand by value; the same value is used for every index.
        template<typename T, typename std::enable_if_t<typetraits::TypeInfo<T>::type ==
                                                         ::librapid::detail::LibRapidType::Scalar,
                                                       int> = 0>
        LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalarExtractor(const T &value, size_t) {
            return value;
        }

        template<typename First, typename... Rest>
        constexpr auto scalarTypesAreSame() {
            if constexpr (sizeof...(Rest) == 0) {
                using Scalar = typename typetraits::TypeInfo<std::decay_t<First>>::Scalar;
                return Scalar {};
            } else {
                using RestType = decltype(scalarTypesAreSame<Rest...>());
                if constexpr (std::is_same_v<
                                typename typetraits::TypeInfo<std::decay_t<First>>::Scalar,
                                RestType>) {
                    return RestType {};
                } else {
                    return std::false_type {};
                }
            }
        }

        /// Expression object that applies `Functor_` element-wise to a set of operands.
        ///
        /// The functor and operands are stored as-is; elements are computed on request via
        /// `scalar()` / `packet()`, and `eval()` materialises the whole expression into an
        /// `Array<Scalar, Backend>`. `Scalar`, `Backend` and `ShapeType` come from the
        /// `typetraits::TypeInfo` specialisation for this class.
        ///
        /// \tparam desc      Descriptor type, exposed as `Descriptor`.
        /// \tparam Functor_  Callable applied to the operands (also needs a `packet()` member
        ///                   for the packet path).
        /// \tparam Args      Operand types, stored in a `std::tuple<Args...>`.
        template<typename desc, typename Functor_, typename... Args>
        class Function {
        public:
            using Type       = Function<desc, Functor_, Args...>;
            using Functor    = Functor_;
            using ShapeType  = typename typetraits::TypeInfo<Type>::ShapeType;
            using StrideType = ShapeType;
            using Scalar     = typename typetraits::TypeInfo<Type>::Scalar;
            using Backend    = typename typetraits::TypeInfo<Type>::Backend;
            using Packet     = typename typetraits::TypeInfo<Scalar>::Packet;
            using Iterator   = detail::ArrayIterator<Function>;

            using Descriptor = desc;
            // True when every operand has the same Scalar type (see scalarTypesAreSame).
            static constexpr bool argsAreSameType =
              !std::is_same_v<decltype(scalarTypesAreSame<Args...>()), std::false_type>;

            Function() = default;

            /// Store the functor and operands. The shape is computed from the operands by
            /// `typetraits::TypeInfo<Functor>::getShape`, and the element count is cached.
            LIBRAPID_ALWAYS_INLINE explicit Function(Functor &&functor, Args &&...args);

            LIBRAPID_ALWAYS_INLINE Function(const Function &other) = default;

            LIBRAPID_ALWAYS_INLINE Function(Function &&other) noexcept = default;

            LIBRAPID_ALWAYS_INLINE Function &operator=(const Function &other) = default;

            LIBRAPID_ALWAYS_INLINE Function &operator=(Function &&other) noexcept = default;

            /// Number of elements (cached `shape().size()`).
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto size() const -> size_t;

            /// Number of dimensions of `shape()`.
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto ndim() const -> size_t;

            /// Shape computed once at construction.
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto shape() const -> const ShapeType &;

            /// The stored operand tuple.
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto &args() const;

            /// Materialise into a new `Array<Scalar, Backend>` with the same shape.
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto eval() const;

            /// Index through a general array view of this expression.
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](int64_t index) const;

            /// Apply the functor's `packet()` to every operand's packet at `index`.
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packet(size_t index) const;

            /// Apply the functor to every operand's element at `index`.
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar scalar(size_t index) const;

            // NOTE(review): end() is built from shape()[0], so iteration presumably runs along
            // the leading dimension; confirm against ArrayIterator.
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator begin() const;
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator end() const;

            /// Format the expression through a general array view (used by the fmt support).
            template<typename T, typename Char, typename Ctx>
            void str(const fmt::formatter<T, Char> &format, char bracket, char separator,
                     Ctx &ctx) const;

        private:
            // Expand the operand tuple (indices I...) into one functor call.
            template<size_t... I>
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetImpl(std::index_sequence<I...>,
                                                                        size_t index) const;

            template<size_t... I>
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar scalarImpl(std::index_sequence<I...>,
                                                                        size_t index) const;

            // Declaration order matters: the constructor initialises m_shape from m_args and
            // m_size from m_shape, so these must stay in this sequence.
            Functor m_functor;
            std::tuple<Args...> m_args;
            ShapeType m_shape;
            size_t m_size = 0;
        };

        /// Take ownership of the functor and operands, then derive the shape and element count.
        ///
        /// The shape comes from `typetraits::TypeInfo<Functor>::getShape` applied to the stored
        /// operand tuple. The init-list order follows the member declaration order.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE Function<desc, Functor, Args...>::Function(Functor &&fn,
                                                                          Args &&...operands) :
                m_functor(std::forward<Functor>(fn)),
                m_args(std::forward<Args>(operands)...),
                m_shape(typetraits::TypeInfo<Functor>::getShape(m_args)),
                m_size(m_shape.size()) {}

        /// Shape of the expression, computed once in the constructor.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::shape() const
          -> const ShapeType & {
            return this->m_shape;
        }

        /// Total element count, cached from `shape().size()` at construction.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::size() const -> size_t {
            return this->m_size;
        }

        /// Number of dimensions, as reported by the stored shape.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::ndim() const -> size_t {
            return this->shape().ndim();
        }

        /// Read-only access to the stored operand tuple.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto &Function<desc, Functor, Args...>::args() const {
            return this->m_args;
        }

        /// Index the expression through a general array view built over `*this`.
        ///
        /// The view is a temporary; the result of its `operator[]` is returned by value (`auto`).
        /// NOTE(review): if that result refers back to the view, it would outlive it; confirm
        /// what GeneralArrayView::operator[] returns.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto
        Function<desc, Functor, Args...>::operator[](int64_t index) const {
            return createGeneralArrayView(*this)[index];
        }

        /// Materialise the expression.
        ///
        /// Allocates an `Array<Scalar, Backend>` with this expression's shape, assigns the
        /// expression into it, and returns the filled array.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::eval() const {
            Array<Scalar, Backend> result(shape());
            result = *this;
            return result;
        }

        template<typename desc, typename Functor, typename... Args>
        typename Function<desc, Functor, Args...>::Packet LIBRAPID_ALWAYS_INLINE
        Function<desc, Functor, Args...>::packet(size_t index) const {
            return packetImpl(std::make_index_sequence<sizeof...(Args)>(), index);
        }

        /// Unpack the operand tuple and call `m_functor.packet(...)` with one packet per operand.
        ///
        /// Non-scalar operands contribute `operand.packet(index)`; scalar operands contribute
        /// `Packet(value)` (see packetExtractor).
        template<typename desc, typename Functor, typename... Args>
        template<size_t... I>
        LIBRAPID_ALWAYS_INLINE auto
        Function<desc, Functor, Args...>::packetImpl(std::index_sequence<I...>, size_t index) const
          -> Packet {
            const auto &operands = m_args;
            return m_functor.packet(packetExtractor<Packet>(std::get<I>(operands), index)...);
        }

        /// Compute the element at `index` by applying the functor to every operand's element.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::scalar(size_t index) const
          -> Scalar {
            using OperandIndices = std::make_index_sequence<sizeof...(Args)>;
            return scalarImpl(OperandIndices {}, index);
        }

        /// Unpack the operand tuple and invoke `m_functor` with one element per operand.
        ///
        /// Non-scalar operands contribute `operand.scalar(index)`; scalar operands are passed
        /// through unchanged (see scalarExtractor).
        template<typename desc, typename Functor, typename... Args>
        template<size_t... I>
        LIBRAPID_ALWAYS_INLINE auto
        Function<desc, Functor, Args...>::scalarImpl(std::index_sequence<I...>, size_t index) const
          -> Scalar {
            const auto &operands = m_args;
            return m_functor(scalarExtractor(std::get<I>(operands), index)...);
        }

        /// Iterator positioned at index 0 of this expression.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::begin() const -> Iterator {
            const Function &self = *this;
            return Iterator(self, 0);
        }

        /// Iterator positioned at index `shape()[0]`, one past the last leading-axis position.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::end() const -> Iterator {
            const Function &self = *this;
            return Iterator(self, self.shape()[0]);
        }

        /// Write a formatted representation of the expression into `context`.
        ///
        /// All work is delegated to the `str()` method of a general array view over `*this`.
        template<typename desc, typename Functor, typename... Args>
        template<typename T, typename Char, typename Ctx>
        LIBRAPID_ALWAYS_INLINE void
        Function<desc, Functor, Args...>::str(const fmt::formatter<T, Char> &elementFormat,
                                              char bracketChar, char separatorChar,
                                              Ctx &context) const {
            createGeneralArrayView(*this).str(elementFormat, bracketChar, separatorChar, context);
        }
    } // namespace detail
} // namespace librapid

// Support FMT printing
// NOTE(review): ARRAY_TYPE_FMT_IML and LIBRAPID_SIMPLE_IO_NORANGE are defined elsewhere.
// Presumably the first provides the fmt::formatter support (which calls Function::str) and
// the second adds stream-output helpers for detail::Function; confirm against their
// definitions. COMMA (presumably ',') keeps the template lists intact as single macro arguments.
ARRAY_TYPE_FMT_IML(typename desc COMMA typename Functor COMMA typename... Args,
                   librapid::detail::Function<desc COMMA Functor COMMA Args...>)
LIBRAPID_SIMPLE_IO_NORANGE(typename desc COMMA typename Functor COMMA typename... Args,
                           librapid::detail::Function<desc COMMA Functor COMMA Args...>)

#endif // LIBRAPID_ARRAY_FUNCTION_HPP