Program Listing for File kernel_header.h

Return to documentation for file (librapid/include/librapid/cuda/kernel_header.h)

#pragma once

#include <string>

namespace librapid::imp {
    /// Absolute paths of the CUDA library headers made available to JIT-compiled kernels
    /// (cuRAND, cuBLAS, half-precision and bfloat16 types). Each entry is CUDA_INCLUDE_DIRS
    /// followed by the header's file name, so these names match the absolute-path #include
    /// lines emitted by genKernelHeader().
    /// NOTE(review): presumably passed to jitify as the program's header list -- confirm at
    /// the call site.
    inline const jitify::detail::vector<std::string> cudaHeaders = {
      CUDA_INCLUDE_DIRS + std::string("/curand.h"),
      CUDA_INCLUDE_DIRS + std::string("/curand_kernel.h"),
      CUDA_INCLUDE_DIRS + std::string("/cublas_v2.h"),
      CUDA_INCLUDE_DIRS + std::string("/cublas_api.h"),
      CUDA_INCLUDE_DIRS + std::string("/cuda_fp16.h"),
      CUDA_INCLUDE_DIRS + std::string("/cuda_bf16.h")};

    /// Compiler options for the JIT-compiled kernels.
    ///
    /// The options suppress warnings, compile as C++17 and add CUDA_INCLUDE_DIRS to the
    /// include search path.
    inline const std::vector<std::string> cudaParams = {
      "--disable-warnings", "-std=c++17", std::string("-I") + CUDA_INCLUDE_DIRS};

    /// Build the device-side source prelude and return it as a string.
    ///
    /// The prelude does the following:
    ///  - includes the cuRAND headers by absolute path (CUDA_INCLUDE_DIRS), plus
    ///    <stdint.h> and <type_traits>;
    ///  - defines librapid::Complex<T> with arithmetic against scalars and other Complex<V>.
    ///    Division follows (a+bi)/(c+di) = ((ac+bd) + (bc-ad)i) / (c^2+d^2), and angle() is
    ///    atan2(imag, real).
    ///  - wraps the definitions in an LIBRAPID_CUSTOM_COMPLEX include guard.
    ///
    /// NOTE(review): presumably prepended to kernel sources before JIT compilation -- confirm
    /// at the call site.
    /// NOTE(review): `operator std::string()` calls `str()`, which the generated class does not
    /// declare. The std::string / std::complex conversions also rely on headers not included
    /// in this prelude. Verify they compile under the JIT toolchain, or remove them.
    ///
    /// \return The generated C++ source text.
    inline std::string genKernelHeader() {
        // The template below is run through fmt::format. `{0}` is replaced by
        // CUDA_INCLUDE_DIRS, and every literal brace of the generated C++ must be doubled
        // (`{{` / `}}`).
        return fmt::format(R"V0G0N(
#include <{0}/curand_kernel.h>
#include <{0}/curand.h>
#include <stdint.h>
#include <type_traits>

#ifndef LIBRAPID_CUSTOM_COMPLEX
#define LIBRAPID_CUSTOM_COMPLEX

namespace librapid {{

    template<class T>
    class Complex {{
    public:
        Complex(const T &real_val = T(), const T &imag_val = T())
                : m_real(real_val), m_imag(imag_val) {{}}

        Complex &operator=(const T &val) {{
            m_real = val;
            m_imag = 0;
            return *this;
        }}

        template<class V>
        Complex(const Complex<V> &other)
                : Complex(static_cast<T>(other.real()), static_cast<T>(other.imag())) {{}}

        template<class V>
        Complex &operator=(const Complex<V> &other) {{
            m_real = static_cast<T>(other.real());
            m_imag = static_cast<T>(other.imag());
            return *this;
        }}

        Complex copy() const {{
            return Complex<T>(m_real, m_imag);
        }}

        inline Complex operator-() const {{
            return Complex<T>(-m_real, -m_imag);
        }}

        template<typename V, typename std::enable_if<std::is_scalar<V>::value, int>::type = 0>
        inline Complex operator+(const V &other) const {{
            return Complex<T>(m_real + other, m_imag);
        }}

        template<typename V, typename std::enable_if<std::is_scalar<V>::value, int>::type = 0>
        inline Complex operator-(const V &other) const {{
            return Complex<T>(m_real - other, m_imag);
        }}

        template<typename V, typename std::enable_if<std::is_scalar<V>::value, int>::type = 0>
        inline Complex operator*(const V &other) const {{
            return Complex<T>(m_real * other, m_imag * other);
        }}

        template<typename V, typename std::enable_if<std::is_scalar<V>::value, int>::type = 0>
        inline Complex operator/(const V &other) const {{
            return Complex<T>(m_real / other, m_imag / other);
        }}

        template<typename V, typename std::enable_if<std::is_scalar<V>::value, int>::type = 0>
        inline Complex &operator+=(const V &other) {{
            m_real += other;
            return *this;
        }}

        template<typename V, typename std::enable_if<std::is_scalar<V>::value, int>::type = 0>
        inline Complex &operator-=(const V &other) {{
            m_real -= other;
            return *this;
        }}

        template<typename V, typename std::enable_if<std::is_scalar<V>::value, int>::type = 0>
        inline Complex &operator*=(const V &other) {{
            m_real *= other;
            m_imag *= other;
            return *this;
        }}

        template<typename V, typename std::enable_if<std::is_scalar<V>::value, int>::type = 0>
        inline Complex &operator/=(const V &other) {{
            m_real /= other;
            m_imag /= other;
            return *this;
        }}

        template<typename V>
        inline Complex operator+(const Complex<V> &other) const {{
            return Complex(m_real + other.real(),
                           m_imag + other.imag());
        }}

        template<typename V>
        inline Complex operator-(const Complex<V> &other) const {{
            return Complex(m_real - other.real(),
                           m_imag - other.imag());
        }}

        template<typename V>
        inline Complex operator*(const Complex<V> &other) const {{
            return Complex((m_real * other.real()) - (m_imag * other.imag()),
                           (m_real * other.imag()) + (m_imag * other.real()));
        }}

        template<typename V>
        inline Complex operator/(const Complex<V> &other) const {{
            const auto denom = (other.real() * other.real()) + (other.imag() * other.imag());
            return Complex(((m_real * other.real()) + (m_imag * other.imag())) / denom,
                           ((m_imag * other.real()) - (m_real * other.imag())) / denom);
        }}

        template<typename V>
        inline Complex &operator+=(const Complex<V> &other) {{
            m_real = m_real + other.real();
            m_imag = m_imag + other.imag();
            return *this;
        }}

        template<typename V>
        inline Complex &operator-=(const Complex<V> &other) {{
            m_real = m_real - other.real();
            m_imag = m_imag - other.imag();
            return *this;
        }}

        template<typename V>
        inline Complex &operator*=(const Complex<V> &other) {{
            *this = *this * other;
            return *this;
        }}

        template<typename V>
        inline Complex &operator/=(const Complex<V> &other) {{
            *this = *this / other;
            return *this;
        }}

        template<typename V>
        inline bool operator==(const Complex<V> &other) const {{
            return m_real == other.real() && m_imag == other.imag();
        }}

        template<typename V>
        inline bool operator!=(const Complex<V> &other) const {{
            return !(*this == other);
        }}

        template<typename V>
        inline bool operator==(const V &other) const {{
            return m_real == other && m_imag == 0;
        }}

        template<typename V>
        inline bool operator!=(const V &other) const {{
            return !(*this == other);
        }}

        inline T mag() const {{
            return std::sqrt(m_real * m_real + m_imag * m_imag);
        }}

        inline T angle() const {{
            return std::atan2(m_imag, m_real);
        }}

        inline Complex<T> log() const {{
            return Complex<T>(std::log(mag()), angle());
        }}

        inline Complex<T> conjugate() const {{
            return Complex<T>(m_real, -m_imag);
        }}

        inline Complex<T> reciprocal() const {{
            return Complex<T>((m_real) / (m_real * m_real + m_imag * m_imag),
                              -(m_imag) / (m_real * m_real + m_imag * m_imag));
        }}

        inline const T &real() const {{
            return m_real;
        }}

        inline T &real() {{
            return m_real;
        }}

        inline const T &imag() const {{
            return m_imag;
        }}

        inline T &imag() {{
            return m_imag;
        }}

        inline explicit operator std::string() const {{
            return str();
        }}

        template<typename V>
        inline operator V() const {{
            return m_real;
        }}

        template<typename V>
        inline explicit operator std::complex<V>() const {{
            return std::complex<V>(m_real, m_imag);
        }}

    private:
        T m_real = 0;
        T m_imag = 0;
    }};

    template<typename A, typename B, typename std::enable_if<std::is_scalar<A>::value, int>::type = 0>
    inline Complex<B> operator+(const A &a, const Complex<B> &b) {{
        return Complex<B>(a) + b;
    }}

    template<typename A, typename B, typename std::enable_if<std::is_scalar<A>::value, int>::type = 0>
    inline Complex<B> operator-(const A &a, const Complex<B> &b) {{
        return Complex<B>(a) - b;
    }}

    template<typename A, typename B, typename std::enable_if<std::is_scalar<A>::value, int>::type = 0>
    inline Complex<B> operator*(const A &a, const Complex<B> &b) {{
        return Complex<B>(a) * b;
    }}

    template<typename A, typename B, typename std::enable_if<std::is_scalar<A>::value, int>::type = 0>
    inline Complex<B> operator/(const A &a, const Complex<B> &b) {{
        return Complex<B>(a) / b;
    }}

    template<typename A, typename B, typename std::enable_if<std::is_scalar<A>::value, int>::type = 0>
    inline A &operator+=(A &a, const Complex<B> &b) {{
        a += b.real();
        return a;
    }}

    template<typename A, typename B, typename std::enable_if<std::is_scalar<A>::value, int>::type = 0>
    inline A &operator-=(A &a, const Complex<B> &b) {{
        a -= b.real();
        return a;
    }}

    template<typename A, typename B, typename std::enable_if<std::is_scalar<A>::value, int>::type = 0>
    inline A &operator*=(A &a, const Complex<B> &b) {{
        a *= b.real();
        return a;
    }}

    template<typename A, typename B, typename std::enable_if<std::is_scalar<A>::value, int>::type = 0>
    inline A &operator/=(A &a, const Complex<B> &b) {{
        a /= b.real();
        return a;
    }}

}}

#endif // LIBRAPID_CUSTOM_COMPLEX
        )V0G0N",
                           CUDA_INCLUDE_DIRS);
    }
} // namespace librapid::imp