Program Listing for File function.hpp
↰ Return to documentation for file (librapid/include/librapid/array/function.hpp)
#ifndef LIBRAPID_ARRAY_FUNCTION_HPP
#define LIBRAPID_ARRAY_FUNCTION_HPP
namespace librapid {
namespace typetraits {
// Extract allowVectorisation from the input types
template<typename First, typename... T>
constexpr bool checkAllowVectorisation() {
if constexpr (sizeof...(T) == 0) {
return TypeInfo<std::decay_t<First>>::allowVectorisation;
} else {
using T1 = typename TypeInfo<std::decay_t<First>>::Scalar;
return TypeInfo<std::decay_t<First>>::allowVectorisation &&
checkAllowVectorisation<T...>() &&
(std::is_same_v<T1, typename TypeInfo<std::decay_t<T>>::Scalar> && ...);
}
}
template<typename First, typename... Rest>
constexpr auto commonBackend() {
using FirstBackend = typename TypeInfo<std::decay_t<First>>::Backend;
if constexpr (sizeof...(Rest) == 0) {
return FirstBackend {};
} else {
using RestBackend = decltype(commonBackend<Rest...>());
if constexpr (std::is_same_v<FirstBackend, backend::OpenCLIfAvailable> ||
std::is_same_v<RestBackend, backend::OpenCLIfAvailable>) {
return backend::OpenCLIfAvailable {};
} else if constexpr (std::is_same_v<FirstBackend, backend::CUDAIfAvailable> ||
std::is_same_v<RestBackend, backend::CUDAIfAvailable>) {
return backend::CUDAIfAvailable {};
} else {
return backend::CPU {};
}
}
}
// ScalarTypeHelper chooses the type that is handed to a functor for a given argument type.
// In the general case this is the argument's scalar type; a few edge cases need the
// argument type itself and are handled by specialisations.

// General case: the Scalar reported by TypeInfo for the decayed argument type
template<typename T>
struct ScalarTypeHelper {
    using Type = typename TypeInfo<std::decay_t<T>>::Scalar;
};
// Vector arguments: the functor receives the whole Vector rather than its element type.
// NOTE(review): this specialisation only matches an unqualified Vector<T, NumDims>; a
// const- or reference-qualified Vector would not select it — confirm how Args are passed.
template<typename T, uint64_t NumDims>
struct ScalarTypeHelper<Vector<T, NumDims>> {
    using Type = Vector<T, NumDims>;
};
// ReturnTypeHelper post-processes the type a functor returns. If that type is itself a
// lazy-evaluated operation, the specialisations replace it with the type obtained by
// evaluating it.

// General case: the functor's return type is used unchanged
template<typename T>
struct ReturnTypeHelper {
    using Type = T;
};
// Lazy binary vector operation: Type is whatever calling eval() on the operation yields
template<typename LHS, typename RHS, typename Op>
struct ReturnTypeHelper<vectorDetail::BinaryVecOp<LHS, RHS, Op>> {
    // The unevaluated operation type
    using IntermediateType = vectorDetail::BinaryVecOp<LHS, RHS, Op>;

    // The evaluated result type
    using Type = decltype(std::declval<IntermediateType>().eval());
};
// Lazy unary vector operation: Type is whatever calling eval() on the operation yields
template<typename Val, typename Op>
struct ReturnTypeHelper<vectorDetail::UnaryVecOp<Val, Op>> {
    // The unevaluated operation type
    using IntermediateType = vectorDetail::UnaryVecOp<Val, Op>;

    // The evaluated result type
    using Type = decltype(std::declval<IntermediateType>().eval());
};
// Type traits for a detail::Function built from a functor and its arguments.
template<typename desc, typename Functor_, typename... Args>
struct TypeInfo<::librapid::detail::Function<desc, Functor_, Args...>> {
    static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayFunction;

    // Type obtained by calling the functor with one value per argument, where each value
    // has the type ScalarTypeHelper selects for that argument.
    using TempScalar = decltype(std::declval<Functor_>()(
      std::declval<typename ScalarTypeHelper<Args>::Type>()...));

    // Final element type, after ReturnTypeHelper has processed the functor's return type
    using Scalar = typename ReturnTypeHelper<TempScalar>::Type;
    using Packet = typename TypeInfo<Scalar>::Packet;

    // Backend shared by all arguments
    using Backend = decltype(commonBackend<Args...>());

    // NOTE(review): this queries TypeInfo<Args> directly, whereas the helper functions in
    // this file use TypeInfo<std::decay_t<...>> — confirm the difference is intentional.
    using ShapeType =
      typename detail::ShapeTypeHelper<typename TypeInfo<Args>::ShapeType...>::Type;

    // Concrete array type with this Scalar and Backend, and its storage
    using ArrayType   = Array<Scalar, Backend>;
    using StorageType = typename TypeInfo<ArrayType>::StorageType;

    static constexpr bool allowVectorisation = checkAllowVectorisation<Args...>();

    // Capabilities are inherited from the resulting scalar type
    static constexpr bool supportsArithmetic = TypeInfo<Scalar>::supportsArithmetic;
    static constexpr bool supportsLogical    = TypeInfo<Scalar>::supportsLogical;
    static constexpr bool supportsBinary     = TypeInfo<Scalar>::supportsBinary;
};
// Invoke LIBRAPID_DEFINE_AS_TYPE for detail::Function. The first argument is the template
// parameter list and the second is the type. COMMA is used in place of ',' so that each is
// passed to the preprocessor as a single macro argument.
// NOTE(review): the macro's expansion is not visible in this file — confirm what it
// generates against its definition.
LIBRAPID_DEFINE_AS_TYPE(typename desc COMMA typename Functor_ COMMA typename... Args,
::librapid::detail::Function<desc COMMA Functor_ COMMA Args...>);
} // namespace typetraits
namespace detail {
// Descriptor is defined in "forward.hpp"
// Produce the packet that a functor should receive for one argument.
//
// obj   - the argument (an array-like object or a plain value)
// index - position passed to obj.packet(); unused when obj is not an array type
//
// Array types return obj.packet(index), which must already be of type Packet.
// Anything else is converted by constructing a Packet from obj.
template<typename Packet, typename T>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetExtractor(const T &obj, size_t index) {
    if constexpr (!detail::IsArrayType<T>::val) {
        return Packet(obj);
    } else {
        using ExtractedPacket = decltype(obj.packet(index));
        static_assert(std::is_same_v<Packet, ExtractedPacket>, "Packet types do not match");
        return obj.packet(index);
    }
}
// Produce the scalar that a functor should receive for one argument.
//
// obj   - the argument (an array-like object or a plain value)
// index - position passed to obj.scalar(); unused when obj is not an array type
//
// Array types return obj.scalar(index); anything else is returned as-is (by value).
template<typename T>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalarExtractor(const T &obj, size_t index) {
    if constexpr (!detail::IsArrayType<T>::val) {
        return obj;
    } else {
        return obj.scalar(index);
    }
}
template<typename First, typename... Rest>
constexpr auto scalarTypesAreSame() {
if constexpr (sizeof...(Rest) == 0) {
using Scalar = typename typetraits::TypeInfo<std::decay_t<First>>::Scalar;
return Scalar {};
} else {
using RestType = decltype(scalarTypesAreSame<Rest...>());
if constexpr (std::is_same_v<
typename typetraits::TypeInfo<std::decay_t<First>>::Scalar,
RestType>) {
return RestType {};
} else {
return std::false_type {};
}
}
}
// Expression object that stores a functor together with the arguments it is applied to.
// Individual results are computed on request by scalar() / packet(), and eval() assigns
// the whole expression into an Array<Scalar, Backend>.
//
// Template parameters:
//   desc     - descriptor tag, re-exported as Descriptor
//   Functor_ - callable type stored in m_functor, re-exported as Functor
//   Args     - argument types, stored in a std::tuple<Args...>
template<typename desc, typename Functor_, typename... Args>
class Function {
public:
using Type = Function<desc, Functor_, Args...>;
using Functor = Functor_;
// Shape, scalar, backend and packet types all come from the TypeInfo specialisation
using ShapeType = typename typetraits::TypeInfo<Type>::ShapeType;
using StrideType = ShapeType;
using Scalar = typename typetraits::TypeInfo<Type>::Scalar;
using Backend = typename typetraits::TypeInfo<Type>::Backend;
using Packet = typename typetraits::TypeInfo<Scalar>::Packet;
using Iterator = detail::ArrayIterator<Function>;
using Descriptor = desc;
// True unless scalarTypesAreSame reports a mismatch (i.e. returns std::false_type)
static constexpr bool argsAreSameType =
!std::is_same_v<decltype(scalarTypesAreSame<Args...>()), std::false_type>;
Function() = default;
// Store the functor and arguments; the shape and size are computed in the constructor
LIBRAPID_ALWAYS_INLINE explicit Function(Functor &&functor, Args &&...args);
LIBRAPID_ALWAYS_INLINE Function(const Function &other) = default;
LIBRAPID_ALWAYS_INLINE Function(Function &&other) noexcept = default;
LIBRAPID_ALWAYS_INLINE Function &operator=(const Function &other) = default;
LIBRAPID_ALWAYS_INLINE Function &operator=(Function &&other) noexcept = default;
// Cached value of m_shape.size(), taken at construction
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto size() const -> size_t;
// Forwards to m_shape.ndim()
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto ndim() const -> size_t;
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto shape() const -> const ShapeType &;
// Reference to the stored argument tuple
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto &args() const;
// Create an Array<Scalar, Backend> of this shape and assign the expression into it
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto eval() const;
// Index into a general array view created from this expression
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](int64_t index) const;
// Result of applying the functor to the arguments' packets / scalars at `index`
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packet(size_t index) const;
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar scalar(size_t index) const;
// Iterators constructed at position 0 and at shape()[0] respectively
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator begin() const;
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator end() const;
// Formatting hook; delegates to the str() of a general array view of this expression
template<typename T, typename Char, size_t N, typename Ctx>
void str(const fmt::formatter<T, Char> &format, char bracket, char separator,
const char (&formatString)[N], Ctx &ctx) const;
private:
// Helpers that expand the argument tuple by index for packet() and scalar()
template<size_t... I>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetImpl(std::index_sequence<I...>,
size_t index) const;
template<size_t... I>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar scalarImpl(std::index_sequence<I...>,
size_t index) const;
// Declaration order matters: the constructor initialises m_shape from m_args and
// m_size from m_shape, and members are initialised in the order declared here.
Functor m_functor;
std::tuple<Args...> m_args;
ShapeType m_shape;
size_t m_size = 0;
};
// Construct a Function from a functor and its arguments.
//
// functor - moved/forwarded into m_functor
// args    - forwarded into the argument tuple m_args
//
// The shape is obtained from TypeInfo<Functor>::getShape(m_args), so it reads the tuple
// that was just initialised; the element count is then cached from m_shape.size().
template<typename desc, typename Functor, typename... Args>
LIBRAPID_ALWAYS_INLINE
Function<desc, Functor, Args...>::Function(Functor &&functor, Args &&...args) :
        m_functor(std::forward<Functor>(functor)),
        m_args(std::forward<Args>(args)...),
        m_shape(typetraits::TypeInfo<Functor>::getShape(m_args)),
        m_size(m_shape.size()) {}
// Return a const reference to the shape computed when the Function was constructed.
template<typename desc, typename Functor, typename... Args>
LIBRAPID_ALWAYS_INLINE auto
Function<desc, Functor, Args...>::shape() const -> const ShapeType & {
    return m_shape;
}
// Return the cached size (m_shape.size() at construction; 0 for a default-constructed
// Function).
template<typename desc, typename Functor, typename... Args>
LIBRAPID_ALWAYS_INLINE auto
Function<desc, Functor, Args...>::size() const -> size_t {
    return m_size;
}
// Return the number of dimensions, as reported by the stored shape.
template<typename desc, typename Functor, typename... Args>
LIBRAPID_ALWAYS_INLINE auto
Function<desc, Functor, Args...>::ndim() const -> size_t {
    return m_shape.ndim();
}
// Return a reference to the tuple holding the stored arguments.
template<typename desc, typename Functor, typename... Args>
LIBRAPID_ALWAYS_INLINE auto &
Function<desc, Functor, Args...>::args() const {
    return m_args;
}
// Index into the expression.
//
// A general array view is created over *this and indexed with `index`; the return type
// is whatever that view's operator[] yields.
template<typename desc, typename Functor, typename... Args>
LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::operator[](int64_t index) const {
    return createGeneralArrayView(*this)[index];
}
// Materialise the expression.
//
// Allocates an Array<Scalar, Backend> with this Function's shape, assigns the Function
// into it, and returns the array by value.
template<typename desc, typename Functor, typename... Args>
LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::eval() const {
    Array<Scalar, Backend> result(shape());
    result = *this;
    return result;
}
template<typename desc, typename Functor, typename... Args>
typename Function<desc, Functor, Args...>::Packet LIBRAPID_ALWAYS_INLINE
Function<desc, Functor, Args...>::packet(size_t index) const {
return packetImpl(std::make_index_sequence<sizeof...(Args)>(), index);
}
// Implementation of packet(): I... enumerates the positions in the argument tuple.
//
// Each argument is turned into a Packet by packetExtractor, and the functor's packet()
// member is called with the resulting pack.
template<typename desc, typename Functor, typename... Args>
template<size_t... I>
LIBRAPID_ALWAYS_INLINE auto
Function<desc, Functor, Args...>::packetImpl(std::index_sequence<I...>, size_t index) const
  -> Packet {
    return m_functor.packet(
      packetExtractor<Packet>(std::get<I>(m_args), index)...);
}
// Compute the result scalar at `index`.
//
// Delegates to scalarImpl with an index sequence covering every stored argument.
template<typename desc, typename Functor, typename... Args>
LIBRAPID_ALWAYS_INLINE auto
Function<desc, Functor, Args...>::scalar(size_t index) const -> Scalar {
    return scalarImpl(std::index_sequence_for<Args...> {}, index);
}
// Implementation of scalar(): I... enumerates the positions in the argument tuple.
//
// Each argument is reduced to a scalar by scalarExtractor, and the functor is invoked
// with the resulting pack.
template<typename desc, typename Functor, typename... Args>
template<size_t... I>
LIBRAPID_ALWAYS_INLINE auto
Function<desc, Functor, Args...>::scalarImpl(std::index_sequence<I...>, size_t index) const
  -> Scalar {
    return m_functor(
      scalarExtractor(std::get<I>(m_args), index)...);
}
// Return an iterator over this Function constructed at position 0.
template<typename desc, typename Functor, typename... Args>
LIBRAPID_ALWAYS_INLINE auto
Function<desc, Functor, Args...>::begin() const -> Iterator {
    return Iterator(*this, 0);
}
// Return an iterator over this Function constructed at position shape()[0], i.e. the
// first entry of the shape.
template<typename desc, typename Functor, typename... Args>
LIBRAPID_ALWAYS_INLINE auto
Function<desc, Functor, Args...>::end() const -> Iterator {
    return Iterator(*this, shape()[0]);
}
// Format the expression into `ctx`.
//
// All parameters are passed through unchanged to the str() member of a general array
// view created over *this; this function adds no formatting logic of its own.
template<typename desc, typename Functor, typename... Args>
template<typename T, typename Char, size_t N, typename Ctx>
LIBRAPID_ALWAYS_INLINE void Function<desc, Functor, Args...>::str(
  const fmt::formatter<T, Char> &format, char bracket, char separator,
  const char (&formatString)[N], Ctx &ctx) const {
    createGeneralArrayView(*this).str(format, bracket, separator, formatString, ctx);
}
} // namespace detail
} // namespace librapid
// Support FMT printing
// Both macros are invoked at global scope (after the librapid namespace is closed) with
// the template parameter list and the fully-qualified Function type. COMMA is used in
// place of ',' so that each is passed to the preprocessor as a single macro argument.
// NOTE(review): the macro expansions are not visible in this file — confirm what each
// one generates against its definition.
ARRAY_TYPE_FMT_IML(typename desc COMMA typename Functor COMMA typename... Args,
librapid::detail::Function<desc COMMA Functor COMMA Args...>)
LIBRAPID_SIMPLE_IO_NORANGE(typename desc COMMA typename Functor COMMA typename... Args,
librapid::detail::Function<desc COMMA Functor COMMA Args...>)
#endif // LIBRAPID_ARRAY_FUNCTION_HPP