Program Listing for File function.hpp

Return to documentation for file (librapid/include/librapid/array/function.hpp)

#ifndef LIBRAPID_ARRAY_FUNCTION_HPP
#define LIBRAPID_ARRAY_FUNCTION_HPP

namespace librapid {
    namespace typetraits {
        // Extract allowVectorisation from the input types
        template<typename First, typename... T>
        constexpr bool checkAllowVectorisation() {
            if constexpr (sizeof...(T) == 0) {
                return TypeInfo<std::decay_t<First>>::allowVectorisation;
            } else {
                using T1 = typename TypeInfo<std::decay_t<First>>::Scalar;
                return TypeInfo<std::decay_t<First>>::allowVectorisation &&
                       checkAllowVectorisation<T...>() &&
                       (std::is_same_v<T1, typename TypeInfo<std::decay_t<T>>::Scalar> && ...);
            }
        }

        template<typename First, typename... Rest>
        constexpr auto commonBackend() {
            using FirstBackend = typename TypeInfo<std::decay_t<First>>::Backend;
            if constexpr (sizeof...(Rest) == 0) {
                return FirstBackend {};
            } else {
                using RestBackend = decltype(commonBackend<Rest...>());
                if constexpr (std::is_same_v<FirstBackend, backend::OpenCLIfAvailable> ||
                              std::is_same_v<RestBackend, backend::OpenCLIfAvailable>) {
                    return backend::OpenCLIfAvailable {};
                } else if constexpr (std::is_same_v<FirstBackend, backend::CUDAIfAvailable> ||
                                     std::is_same_v<RestBackend, backend::CUDAIfAvailable>) {
                    return backend::CUDAIfAvailable {};
                } else {
                    return backend::CPU {};
                }
            }
        }

        // Normally, we want to use the scalar type of the input. This said, there are a few edge
        // cases where it is necessary to use the type itself.

        // Default: the type handed to the functor is the Scalar that TypeInfo reports for the
        // decayed input type.
        template<typename T>
        struct ScalarTypeHelper {
            using Type = typename TypeInfo<std::decay_t<T>>::Scalar;
        };

        // Vectors: the helper yields the Vector type itself instead of TypeInfo's Scalar for it.
        // NOTE(review): only an exact, unqualified Vector<T, NumDims> matches this specialisation;
        // cv- or reference-qualified Vector types select the primary template instead -- confirm
        // that this is intended.
        template<typename T, uint64_t NumDims>
        struct ScalarTypeHelper<Vector<T, NumDims>> {
            using Type = Vector<T, NumDims>;
        };

        // Once we have the correct scalar types, we need to check if the result is a lazy-evaluated
        // function. If so, we need to extract the actual return type from the function.

        // Default: T is not one of the vector expression types specialised for, so it is used
        // unchanged.
        template<typename T>
        struct ReturnTypeHelper {
            using Type = T;
        };

        // Binary vector operation: Type is whatever calling eval() on the BinaryVecOp expression
        // returns. IntermediateType names the unevaluated expression type.
        template<typename LHS, typename RHS, typename Op>
        struct ReturnTypeHelper<vectorDetail::BinaryVecOp<LHS, RHS, Op>> {
            using IntermediateType = vectorDetail::BinaryVecOp<LHS, RHS, Op>;
            using Type             = decltype(std::declval<IntermediateType>().eval());
        };

        // Unary vector operation: Type is whatever calling eval() on the UnaryVecOp expression
        // returns. IntermediateType names the unevaluated expression type.
        template<typename Val, typename Op>
        struct ReturnTypeHelper<vectorDetail::UnaryVecOp<Val, Op>> {
            using IntermediateType = vectorDetail::UnaryVecOp<Val, Op>;
            using Type             = decltype(std::declval<IntermediateType>().eval());
        };

        // Type traits for detail::Function<desc, Functor_, Args...>: element type, packet type,
        // backend, shape type, the Array type it corresponds to, and capability flags.
        template<typename desc, typename Functor_, typename... Args>
        struct TypeInfo<::librapid::detail::Function<desc, Functor_, Args...>> {
            static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayFunction;

            // Earlier definitions of Scalar, kept for reference. Both used the functor's result
            // type directly, without passing it through ReturnTypeHelper.
            // using Scalar = decltype(std::declval<Functor_>()(
            //      std::declval<typename TypeInfo<std::decay_t<Args>>::Scalar>()...));

            // using Scalar = decltype(std::declval<Functor_>()(
            //   std::declval<typename ScalarTypeHelper<Args>::Type>()...));

            // Result type of invoking the functor with one value per argument, where each value
            // has the type ScalarTypeHelper selects for that argument
            using TempScalar = decltype(std::declval<Functor_>()(
              std::declval<typename ScalarTypeHelper<Args>::Type>()...));

            // TempScalar with vector expression types replaced by the type their eval() returns
            using Scalar = typename ReturnTypeHelper<TempScalar>::Type;

            using Packet  = typename TypeInfo<Scalar>::Packet;
            // Backend tag chosen from all arguments by commonBackend()
            using Backend = decltype(commonBackend<Args...>());
            // Shape type derived from the arguments' shape types by detail::ShapeTypeHelper.
            // Note that TypeInfo is applied to Args as given here, not to the decayed types.
            using ShapeType =
              typename detail::ShapeTypeHelper<typename TypeInfo<Args>::ShapeType...>::Type;

            // Array type with this function's Scalar and Backend, and that array's storage type
            using ArrayType   = Array<Scalar, Backend>;
            using StorageType = typename TypeInfo<ArrayType>::StorageType;

            // True only if every argument allows vectorisation and all share one scalar type
            static constexpr bool allowVectorisation = checkAllowVectorisation<Args...>();

            // Capability flags are taken over from the resulting Scalar type
            static constexpr bool supportsArithmetic = TypeInfo<Scalar>::supportsArithmetic;
            static constexpr bool supportsLogical    = TypeInfo<Scalar>::supportsLogical;
            static constexpr bool supportsBinary     = TypeInfo<Scalar>::supportsBinary;
        };

        // NOTE(review): LIBRAPID_DEFINE_AS_TYPE and COMMA are defined outside this file.
        // Presumably COMMA expands to ',' so that template argument lists survive macro argument
        // splitting, and the macro makes TypeInfo available for further forms of the Function
        // type -- confirm against the macro definition.
        LIBRAPID_DEFINE_AS_TYPE(typename desc COMMA typename Functor_ COMMA typename... Args,
                                ::librapid::detail::Function<desc COMMA Functor_ COMMA Args...>);
    } // namespace typetraits

    namespace detail {
        // Descriptor is defined in "forward.hpp"

        // Obtain the Packet that operand `obj` contributes at position `index`.
        //
        // Operands for which detail::IsArrayType is true supply the packet themselves via
        // obj.packet(index); its type must be exactly Packet. Any other operand is converted
        // with Packet(obj), and `index` is not used.
        template<typename Packet, typename T>
        LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetExtractor(const T &obj, size_t index) {
            if constexpr (!detail::IsArrayType<T>::val) {
                return Packet(obj);
            } else {
                using LoadedPacket = decltype(obj.packet(index));
                static_assert(std::is_same_v<Packet, LoadedPacket>, "Packet types do not match");
                return obj.packet(index);
            }
        }

        // Obtain the scalar that operand `obj` contributes at position `index`.
        //
        // Operands for which detail::IsArrayType is true are queried with obj.scalar(index). Any
        // other operand is returned as-is (by value), and `index` is not used.
        template<typename T>
        LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalarExtractor(const T &obj, size_t index) {
            if constexpr (!detail::IsArrayType<T>::val) {
                return obj;
            } else {
                return obj.scalar(index);
            }
        }

        template<typename First, typename... Rest>
        constexpr auto scalarTypesAreSame() {
            if constexpr (sizeof...(Rest) == 0) {
                using Scalar = typename typetraits::TypeInfo<std::decay_t<First>>::Scalar;
                return Scalar {};
            } else {
                using RestType = decltype(scalarTypesAreSame<Rest...>());
                if constexpr (std::is_same_v<
                                typename typetraits::TypeInfo<std::decay_t<First>>::Scalar,
                                RestType>) {
                    return RestType {};
                } else {
                    return std::false_type {};
                }
            }
        }

        // Expression object that stores a functor together with the operands it is applied to
        // and the shape of the result. Individual elements are computed on request through
        // scalar() and packet(); eval() assigns the whole expression to an Array.
        //
        // desc     -- descriptor type, exposed as Descriptor
        // Functor_ -- callable used per element (operator()) and per packet (packet())
        // Args     -- operand types, stored in a std::tuple exactly as given
        template<typename desc, typename Functor_, typename... Args>
        class Function {
        public:
            using Type       = Function<desc, Functor_, Args...>;
            using Functor    = Functor_;
            using ShapeType  = typename typetraits::TypeInfo<Type>::ShapeType;
            using StrideType = ShapeType;
            using Scalar     = typename typetraits::TypeInfo<Type>::Scalar;
            using Backend    = typename typetraits::TypeInfo<Type>::Backend;
            using Packet     = typename typetraits::TypeInfo<Scalar>::Packet;
            using Iterator   = detail::ArrayIterator<Function>;

            using Descriptor = desc;
            // True unless scalarTypesAreSame() reports a mismatch by returning std::false_type
            static constexpr bool argsAreSameType =
              !std::is_same_v<decltype(scalarTypesAreSame<Args...>()), std::false_type>;

            // Default construction leaves m_size at 0 and default-initialises the other members
            Function() = default;

            // Store the functor and operands; the shape and size are computed from the operands
            LIBRAPID_ALWAYS_INLINE explicit Function(Functor &&functor, Args &&...args);

            LIBRAPID_ALWAYS_INLINE Function(const Function &other) = default;

            LIBRAPID_ALWAYS_INLINE Function(Function &&other) noexcept = default;

            LIBRAPID_ALWAYS_INLINE Function &operator=(const Function &other) = default;

            LIBRAPID_ALWAYS_INLINE Function &operator=(Function &&other) noexcept = default;

            // Stored element count (m_size)
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto size() const -> size_t;

            // Number of dimensions reported by the stored shape
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto ndim() const -> size_t;

            // Stored result shape
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto shape() const -> const ShapeType &;

            // Stored operand tuple
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto &args() const;

            // Evaluate the expression into an Array<Scalar, Backend> of the stored shape
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto eval() const;

            // Index into a general array view created from this function
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](int64_t index) const;

            // Apply the functor's packet() to the operands' packets at `index`
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packet(size_t index) const;

            // Apply the functor to the operands' scalars at `index`
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar scalar(size_t index) const;

            // Iterators over this function: begin() is at index 0, end() is at shape()[0]
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator begin() const;
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator end() const;

            // Format this function via a general array view of it; all arguments are forwarded
            template<typename T, typename Char, size_t N, typename Ctx>
            void str(const fmt::formatter<T, Char> &format, char bracket, char separator,
                     const char (&formatString)[N], Ctx &ctx) const;

        private:
            // Implementation of packet(): the index sequence enumerates the entries of m_args
            template<size_t... I>
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetImpl(std::index_sequence<I...>,
                                                                        size_t index) const;

            // Implementation of scalar(): the index sequence enumerates the entries of m_args
            template<size_t... I>
            LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar scalarImpl(std::index_sequence<I...>,
                                                                        size_t index) const;

            // NOTE: the constructor's initialisers depend on this declaration order
            // (m_args before m_shape before m_size)
            Functor m_functor;
            std::tuple<Args...> m_args;
            ShapeType m_shape;
            size_t m_size = 0;
        };

        // Construct a Function from a functor and its operands.
        //
        // Members are initialised in their declaration order (m_functor, m_args, m_shape,
        // m_size), and the initialisers rely on it: the shape is computed from the stored tuple
        // m_args rather than from the forwarded `args`, and the size is read from the freshly
        // initialised m_shape.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE Function<desc, Functor, Args...>::Function(Functor &&functor,
                                                                          Args &&...args) :
                m_functor(std::forward<Functor>(functor)),
                m_args(std::forward<Args>(args)...),
                // The functor's TypeInfo provides getShape(), which is given the operand tuple
                m_shape(typetraits::TypeInfo<Functor>::getShape(m_args)), m_size(m_shape.size()) {}

        // Return a const reference to the stored result shape (m_shape).
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::shape() const
          -> const ShapeType & {
            return m_shape;
        }

        // Return the stored element count (m_size), which the constructor takes from
        // m_shape.size() and which is 0 for a default-constructed Function.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::size() const -> size_t {
            return m_size;
        }

        // Return the number of dimensions as reported by the stored shape's ndim().
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::ndim() const -> size_t {
            return m_shape.ndim();
        }

        // Return the stored operand tuple by reference. Because this member function is const,
        // `auto &` deduces to a reference to const std::tuple<Args...>.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto &Function<desc, Functor, Args...>::args() const {
            return m_args;
        }

        // Index into this function by wrapping it in a general array view and indexing that view.
        // The return type is plain `auto`, so the indexed result is returned by value.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto
        Function<desc, Functor, Args...>::operator[](int64_t index) const {
            return createGeneralArrayView(*this)[index];
        }

        // Evaluate this function into a concrete array.
        //
        // An Array<Scalar, Backend> with this function's shape is created first; the function is
        // then assigned to it, and the array is returned by value.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::eval() const {
            Array<Scalar, Backend> result(shape());
            result = *this;
            return result;
        }

        template<typename desc, typename Functor, typename... Args>
        typename Function<desc, Functor, Args...>::Packet LIBRAPID_ALWAYS_INLINE
        Function<desc, Functor, Args...>::packet(size_t index) const {
            return packetImpl(std::make_index_sequence<sizeof...(Args)>(), index);
        }

        // Implementation of packet(). For every operand index I, packetExtractor<Packet> obtains
        // that operand's packet at `index`; all of them are then passed, in operand order, to the
        // functor's packet() member.
        template<typename desc, typename Functor, typename... Args>
        template<size_t... I>
        LIBRAPID_ALWAYS_INLINE auto
        Function<desc, Functor, Args...>::packetImpl(std::index_sequence<I...>, size_t index) const
          -> Packet {
            return m_functor.packet(packetExtractor<Packet>(std::get<I>(m_args), index)...);
        }

        // Compute the single result element at position `index`.
        //
        // Delegates to scalarImpl() with an index sequence covering every stored operand.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::scalar(size_t index) const
          -> Scalar {
            using ArgIndices = std::make_index_sequence<sizeof...(Args)>;
            return scalarImpl(ArgIndices {}, index);
        }

        // Implementation of scalar(). For every operand index I, scalarExtractor obtains that
        // operand's value at `index`; all of them are then passed, in operand order, to the
        // functor's call operator.
        template<typename desc, typename Functor, typename... Args>
        template<size_t... I>
        LIBRAPID_ALWAYS_INLINE auto
        Function<desc, Functor, Args...>::scalarImpl(std::index_sequence<I...>, size_t index) const
          -> Scalar {
            return m_functor(scalarExtractor(std::get<I>(m_args), index)...);
        }

        // Return an iterator over this function positioned at index 0.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::begin() const -> Iterator {
            return Iterator(*this, 0);
        }

        // Return the past-the-end iterator: it is positioned at shape()[0], the first entry of
        // the stored shape.
        template<typename desc, typename Functor, typename... Args>
        LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::end() const -> Iterator {
            return Iterator(*this, shape()[0]);
        }

        // Format this function by creating a general array view of it and calling that view's
        // str(); every argument is forwarded unchanged.
        template<typename desc, typename Functor, typename... Args>
        template<typename T, typename Char, size_t N, typename Ctx>
        LIBRAPID_ALWAYS_INLINE void
        Function<desc, Functor, Args...>::str(const fmt::formatter<T, Char> &format, char bracket,
                                              char separator, const char (&formatString)[N],
                                              Ctx &ctx) const {
            createGeneralArrayView(*this).str(format, bracket, separator, formatString, ctx);
        }
    } // namespace detail
} // namespace librapid

// Support FMT printing
// NOTE(review): ARRAY_TYPE_FMT_IML, LIBRAPID_SIMPLE_IO_NORANGE and COMMA are defined outside this
// file. Presumably they generate the formatting/stream-output support for every instantiation of
// librapid::detail::Function, with COMMA standing in for ',' inside macro arguments -- confirm
// against the macro definitions.
ARRAY_TYPE_FMT_IML(typename desc COMMA typename Functor COMMA typename... Args,
                   librapid::detail::Function<desc COMMA Functor COMMA Args...>)
LIBRAPID_SIMPLE_IO_NORANGE(typename desc COMMA typename Functor COMMA typename... Args,
                           librapid::detail::Function<desc COMMA Functor COMMA Args...>)

#endif // LIBRAPID_ARRAY_FUNCTION_HPP