//===----------------------------------------------------------------------===//
//
// Part of the CUDA Toolkit, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
#  error "CUDA synchronization primitives are only supported for sm_70 and up."
#endif

#ifndef _CUDA_BARRIER
#define _CUDA_BARRIER

#include "atomic"
#include "cstddef"

#include "detail/__config"

#include "detail/__pragma_push"

#include "detail/libcxx/include/barrier"

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA

// Barrier parameterized on a thread scope and an optional completion function.
//
// All barrier operations are inherited from std::__barrier_base<_CompletionF, _Sco>;
// this class only adds construction and the free-standing init() helpers.
template<thread_scope _Sco, class _CompletionF = std::__empty_completion>
class barrier : public std::__barrier_base<_CompletionF, _Sco> {
public:
    // Barriers can be neither copy-constructed nor copy-assigned.
    barrier(const barrier &) = delete;
    barrier & operator=(const barrier &) = delete;

    barrier() = default;

    // Forwards the expected arrival count and the completion function to the base class.
    _LIBCUDACXX_INLINE_VISIBILITY
    barrier(std::ptrdiff_t __expected, _CompletionF __completion = _CompletionF())
        : std::__barrier_base<_CompletionF, _Sco>(__expected, __completion)
    {}

    // Constructs a barrier in place at __b (placement new) with an explicit completion function.
    _LIBCUDACXX_INLINE_VISIBILITY
    friend void init(barrier * __b, std::ptrdiff_t __expected, _CompletionF __completion) {
        new (__b) barrier(__expected, __completion);
    }

    // Constructs a barrier in place at __b (placement new) with a default-constructed completion function.
    _LIBCUDACXX_INLINE_VISIBILITY
    friend void init(barrier * __b, std::ptrdiff_t __expected) {
        new (__b) barrier(__expected);
    }
};

// Empty tag type; the thread_scope_block barrier specialization below derives from it.
// NOTE(review): presumably lets other code recognize block-scope barriers through their
// base class -- no such use is visible in this file, confirm against callers.
struct __block_scope_barrier_base {};

// Specialization of barrier for thread_scope_block with the default (empty) completion function.
//
// When compiled for sm_80 or newer and the object is located in shared memory (tested with
// __isShared on every call), the operations are issued as PTX mbarrier instructions that act
// directly on the object's storage. In all other cases the calls are forwarded to the generic
// std::__barrier_base held in __barrier.
template<>
class barrier<thread_scope_block, std::__empty_completion> : public __block_scope_barrier_base {
    using __barrier_base = std::__barrier_base<std::__empty_completion, (int)thread_scope_block>;
    // Backing storage for both code paths; the constructor statically asserts it sits at offset 0.
    __barrier_base __barrier;

public:
    using arrival_token = typename __barrier_base::arrival_token;

    barrier() = default;

    barrier(const barrier &) = delete;
    barrier & operator=(const barrier &) = delete;

    // Initializes the barrier for __expected arrivals by delegating to init().
    _LIBCUDACXX_INLINE_VISIBILITY
    barrier(std::ptrdiff_t __expected, std::__empty_completion __completion = std::__empty_completion()) {
        static_assert(offsetof(barrier<thread_scope_block>, __barrier) == 0, "fatal error: bad barrier layout");
        init(this, __expected, __completion);
    }

    // On sm_80+ a barrier located in shared memory is invalidated with mbarrier.inval;
    // otherwise the destructor body does nothing.
    _LIBCUDACXX_INLINE_VISIBILITY
    ~barrier() {
#if __CUDA_ARCH__ >= 800
        if (__isShared(&__barrier)) {
            // The operand is a 32-bit shared-state-space address, so the instruction must carry
            // the .shared qualifier (as every other mbarrier instruction in this class does);
            // without it the address would be interpreted as a generic address.
            asm volatile ("mbarrier.inval.shared.b64 [%0];"
                :: "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__barrier)))
                : "memory");
        }
#endif
    }

    // Initializes the barrier at __b for __expected arrivals.
    // sm_80+ and shared memory: mbarrier.init on the storage; otherwise the generic
    // __barrier_base is constructed in place. __completion is not used.
    _LIBCUDACXX_INLINE_VISIBILITY
    friend void init(barrier * __b, std::ptrdiff_t __expected, std::__empty_completion __completion = std::__empty_completion()) {
#if __CUDA_ARCH__ >= 800
        if (__isShared(&__b->__barrier)) {
            asm volatile ("mbarrier.init.shared.b64 [%0], %1;"
                :: "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__b->__barrier))),
                    "r"(static_cast<std::uint32_t>(__expected))
                : "memory");
        }
        else
#endif
        {
            new (&__b->__barrier) __barrier_base(__expected);
        }
    }

    // Registers __update arrivals and returns the token to pass to wait().
    _LIBCUDACXX_NODISCARD_ATTRIBUTE _LIBCUDACXX_INLINE_VISIBILITY
    arrival_token arrive(std::ptrdiff_t __update = 1)
    {
#if __CUDA_ARCH__
        if (__isShared(&__barrier)) {
            arrival_token __token;
#if __CUDA_ARCH__ >= 800
            // The first __update - 1 arrivals are registered with the noComplete variant; the
            // final plain arrive below produces the token that is returned.
            if (__update > 1) {
                asm volatile ("mbarrier.arrive.noComplete.shared.b64 %0, [%1], %2;"
                    : "=l"(__token)
                    : "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__barrier))),
                        "r"(static_cast<std::uint32_t>(__update - 1))
                    : "memory");
            }
            asm volatile ("mbarrier.arrive.shared.b64 %0, [%1];"
                : "=l"(__token)
                : "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__barrier)))
                : "memory");
#else
            // Device code below sm_80: group the currently active lanes that pass the same
            // __update and the same barrier address, so one lane can arrive for the whole group.
            unsigned int __activeA = __match_any_sync(__activemask(), __update);
            unsigned int __activeB = __match_any_sync(__activemask(), reinterpret_cast<std::uintptr_t>(&__barrier));
            unsigned int __active = __activeA & __activeB;
            int __inc = __popc(__active) * __update;

            unsigned __laneid;
            asm volatile ("mov.u32 %0, %laneid;" : "=r"(__laneid));
            // Lowest-numbered lane of the group performs the combined arrival.
            int __leader = __ffs(__active) - 1;

            if(__leader == __laneid)
            {
                __token = __barrier.arrive(__inc);
            }
            // Broadcast the leader's token to every lane of the group.
            __token = __shfl_sync(__active, __token, __leader);
#endif
            return __token;
        }
        else
#endif
        {
            return __barrier.arrive(__update);
        }
    }

    // Blocks until the phase identified by __phase has completed.
    _LIBCUDACXX_INLINE_VISIBILITY
    void wait(arrival_token && __phase) const
    {
#if __CUDA_ARCH__ >= 800
        if (__isShared(&__barrier)) {
            // Poll mbarrier.test_wait; between failed polls sleep with a back-off that starts
            // at 1 ns and doubles until it reaches 2^20 ns.
            unsigned __sleep_timer_ns = 1;

            while (true) {
                int __wait_complete = 0;
                asm volatile ("{\n\t"
                        ".reg .pred p;\n\t"
                        "mbarrier.test_wait.shared.b64 p, [%1], %2;\n\t"
                        "selp.b32 %0, 1, 0, p;\n\t"
                        "}"
                    : "=r"(__wait_complete)
                    : "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__barrier))), "l"(__phase)
                    : "memory");
                if (__wait_complete) { break; }

                asm volatile ("nanosleep.u32 %0;"
                    :: "r"((unsigned)__sleep_timer_ns));

                if (__sleep_timer_ns < 1048576) { // 2^20
                    __sleep_timer_ns *= 2;
                }
            }
        }
        else
#endif
        {
            __barrier.wait(std::move(__phase));
        }
    }

    // Arrives once and waits for the resulting phase to complete.
    inline _LIBCUDACXX_INLINE_VISIBILITY
    void arrive_and_wait()
    {
        wait(arrive());
    }

    // Arrives and drops out of the barrier: mbarrier.arrive_drop on the sm_80+ shared-memory
    // path, __barrier_base::arrive_and_drop otherwise.
    _LIBCUDACXX_INLINE_VISIBILITY
    void arrive_and_drop()
    {
#if __CUDA_ARCH__ >= 800
        if (__isShared(&__barrier)) {
            asm volatile ("mbarrier.arrive_drop.shared.b64 _, [%0];"
                :: "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__barrier)))
                : "memory");
        }
        else
#endif
        {
            __barrier.arrive_and_drop();
        }
    }

    // Largest expected count reported for this barrier type: 2^20 - 1.
    _LIBCUDACXX_INLINE_VISIBILITY
    static constexpr std::ptrdiff_t max() noexcept
    {
        return (1 << 20) - 1;
    }
};

#if __CUDA_ARCH__ >= 800
// Primary template: used for every alignment without a dedicated cp.async specialization.
// Performs a plain memcpy and reports (by returning false) that nothing asynchronous was issued.
// __large defaults to true for alignments above 16 bytes.
template<std::size_t __alignment, bool __large = (__alignment > 16)>
struct __memcpy_async_impl {
    __device__ static inline bool __copy(char * __destination, char const * __source, std::size_t __total_size) {
        const bool __issued_async = false;
        memcpy(__destination, __source, __total_size);
        return __issued_async;
    }
};

// 4-byte variant: issues one cp.async.ca.shared.global per 4-byte chunk and returns true to
// signal that asynchronous copies were issued. The loop advances in whole 4-byte steps.
template<>
struct __memcpy_async_impl<4, false> {
    __device__ static inline bool __copy(char * __destination, char const * __source, std::size_t __total_size) {
        std::size_t __done = 0;
        while (__done < __total_size) {
            char * const __dst_chunk = __destination + __done;
            char const * const __src_chunk = __source + __done;
            asm volatile ("cp.async.ca.shared.global [%0], [%1], 4;"
                :: "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(__dst_chunk))),
                    "l"(__src_chunk)
                : "memory");
            __done += 4;
        }
        return true;
    }
};

// 8-byte variant: issues one cp.async.ca.shared.global per 8-byte chunk and returns true to
// signal that asynchronous copies were issued. The loop advances in whole 8-byte steps.
template<>
struct __memcpy_async_impl<8, false> {
    __device__ static inline bool __copy(char * __destination, char const * __source, std::size_t __total_size) {
        std::size_t __done = 0;
        while (__done < __total_size) {
            char * const __dst_chunk = __destination + __done;
            char const * const __src_chunk = __source + __done;
            asm volatile ("cp.async.ca.shared.global [%0], [%1], 8;"
                :: "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(__dst_chunk))),
                    "l"(__src_chunk)
                : "memory");
            __done += 8;
        }
        return true;
    }
};

// 16-byte variant: issues one cp.async.ca.shared.global per 16-byte chunk and returns true to
// signal that asynchronous copies were issued. The loop advances in whole 16-byte steps.
template<>
struct __memcpy_async_impl<16, false> {
    __device__ static inline bool __copy(char * __destination, char const * __source, std::size_t __total_size) {
        std::size_t __done = 0;
        while (__done < __total_size) {
            char * const __dst_chunk = __destination + __done;
            char const * const __src_chunk = __source + __done;
            asm volatile ("cp.async.ca.shared.global [%0], [%1], 16;"
                :: "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(__dst_chunk))),
                    "l"(__src_chunk)
                : "memory");
            __done += 16;
        }
        return true;
    }
};

// Alignments flagged as large (the second template argument is true, i.e. above 16 bytes by
// the primary template's default) reuse the 16-byte cp.async implementation.
template<std::size_t __alignment>
struct __memcpy_async_impl<__alignment, true> : __memcpy_async_impl<16, false> {};
#endif

// Byte-level worker behind memcpy_async.
//
// __native_alignment is the alignment statically guaranteed by the caller's element type.
// __size is a byte count.
//
// On sm_80+, when __destination is in shared memory and __source is in global memory, the copy
// is issued with cp.async in 4-, 8- or 16-byte chunks. Completion is then tied to __barrier with
// cp.async.mbarrier.arrive if the barrier is in shared memory, or awaited immediately with
// cp.async.wait_all otherwise. In every other situation a plain memcpy is performed.
template<std::size_t __native_alignment>
_LIBCUDACXX_INLINE_VISIBILITY
void inline __memcpy_async(char * __destination, char const * __source, std::size_t __size, barrier<thread_scope_block> & __barrier) {
#if __CUDA_ARCH__ >= 800
    bool __is_async = __isShared(__destination) && __isGlobal(__source);

    if (__is_async) {
        // Only a static alignment of 16 or more already guarantees the widest cp.async chunk.
        // Anything smaller goes through the run-time check below, so that
        //  - 4- and 8-byte aligned types are promoted to wider chunks when the actual pointers
        //    and size allow it, and
        //  - a byte count that is not a multiple of the chunk width falls back to a narrower
        //    chunk (or memcpy) instead of copying past the end of the buffers.
        if (__native_alignment < 16) {
            auto __source_address = reinterpret_cast<std::uintptr_t>(__source);
            auto __destination_address = reinterpret_cast<std::uintptr_t>(__destination);

            // Lowest bit set will tell us what the common alignment of the three values is.
            // __ffs is 1-based and returns 0 when no bit is set in its (32-bit) argument; that
            // case, like any position of 5 or more, means at least 16-byte alignment.
            auto __alignment = __ffs(__source_address | __destination_address | __size);

            switch (__alignment) {
                default: __is_async = __memcpy_async_impl<16>::__copy(__destination, __source, __size); break;
                case 4: __is_async = __memcpy_async_impl<8>::__copy(__destination, __source, __size); break;
                case 3: __is_async = __memcpy_async_impl<4>::__copy(__destination, __source, __size); break;
                case 2: // fallthrough
                case 1: __is_async = __memcpy_async_impl<1>::__copy(__destination, __source, __size); break;
            }
        }
        else {
            __is_async = __memcpy_async_impl<__native_alignment>::__copy(__destination, __source, __size);
        }
    }
    else
#endif
    {
        memcpy(__destination, __source, __size);
    }

#if __CUDA_ARCH__ >= 800
    // __is_async is still true only if cp.async instructions were actually issued above.
    if (__is_async) {
        if (__isShared(&__barrier)) {
            asm volatile ("cp.async.mbarrier.arrive.shared.b64 [%0];"
                :: "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__barrier)))
                : "memory");
        }
        else {
            // The barrier is not an mbarrier object, so wait for the copies right here.
            asm volatile ("cp.async.wait_all;"
                ::: "memory");
        }
    }
#endif
}

// Typed entry point: copies __size bytes from __source to __destination, forwarding to
// __memcpy_async with alignof(_Tp) as the statically known alignment.
template<class _Tp>
_LIBCUDACXX_INLINE_VISIBILITY
void memcpy_async(_Tp * __destination, _Tp const * __source, std::size_t __size, barrier<thread_scope_block> & __barrier) {
    // NVCC combined with GCC 4.8 misreports some user-defined types that really are trivially
    // copyable, so the check is disabled for that compiler to keep such types usable here.
    // FIXME: drop the preprocessor guard once GCC 4.8 support is removed.
#if !defined(_LIBCUDACXX_COMPILER_GCC) || _GNUC_VER > 408
    static_assert(std::is_trivially_copyable<_Tp>::value, "memcpy_async requires a trivially copyable type");
#endif

    char * const __destination_bytes = reinterpret_cast<char *>(__destination);
    char const * const __source_bytes = reinterpret_cast<char const *>(__source);

    __memcpy_async<alignof(_Tp)>(__destination_bytes, __source_bytes, __size, __barrier);
}

// Untyped entry point: nothing is known statically about the pointers, so the copy is forwarded
// to __memcpy_async with an alignment of 1.
_LIBCUDACXX_INLINE_VISIBILITY
inline void memcpy_async(void * __destination, void const * __source, std::size_t __size, barrier<thread_scope_block> & __barrier) {
    char * const __destination_bytes = static_cast<char *>(__destination);
    char const * const __source_bytes = static_cast<char const *>(__source);

    __memcpy_async<1>(__destination_bytes, __source_bytes, __size, __barrier);
}

_LIBCUDACXX_END_NAMESPACE_CUDA

#include "detail/__pragma_pop"

#endif //_CUDA_BARRIER
