/*

Copyright (c) 2019, NVIDIA Corporation

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

*/

#include <cstddef>
#include <thread>
#include <vector>
#include "atomic_wait"

//#define __BARRIER_NO_BUTTERFLY
//#define __BARRIER_NO_WAIT
//#define __BARRIER_NO_SPECIALIZATION

/// Default completion callable for `barrier`: calling it does nothing.
struct EmptyCompletionF {
    // Member functions defined inside the class body are implicitly inline.
    void operator()() noexcept {}
};

#ifndef __BARRIER_NO_BUTTERFLY

// Per-thread slot hint: barrier::__get_id starts probing for a free slot at
// this index and writes back the index it actually claimed. This is only a
// declaration (extern); the definition is not part of this file as shown.
extern thread_local size_t __barrier_favorite_hash;

/// Barrier that synchronizes its participants through per-participant flag
/// slots instead of a single shared counter.
///
/// On every arrive() a thread claims a slot id for the current phase. Then,
/// for each of ceil(log2(expected)) rounds k, it publishes the next phase
/// value into the slot of participant (id + 2^k) % expected and spins until
/// its own slot for that round has been updated. The thread holding id 0 then
/// runs the completion callable, applies pending arrive_and_drop()
/// adjustments and publishes the new phase, which releases wait().
///
/// @tparam CompletionF callable invoked with no arguments by the thread that
///         holds slot 0 once all rounds are done.
template<class CompletionF = EmptyCompletionF>
class barrier {

    // One round slot per value bit of ptrdiff_t; the additional slot at index
    // __max_steps is the "id claimed" flag used by __try_get_id.
    static constexpr size_t __max_steps = CHAR_BIT * sizeof(ptrdiff_t) - 1;

    using __phase_t = uint8_t;

    // Sentinel returned by __get_id when no slot could be claimed. Typed as
    // size_t so the comparison in arrive() is exact on every platform.
    static constexpr size_t __invalid_id = ~static_cast<size_t>(0);

    // Per-participant flags, aligned to 64 bytes so participants do not share
    // a slot's storage.
    struct alignas(64) __state_t {
        std::atomic<__phase_t> v[__max_steps + 1] = {0};
    };

    alignas(64) ptrdiff_t              expected;            // participants in the current phase
    ptrdiff_t                          expected_steps;      // ceil(log2(expected))
    std::atomic<ptrdiff_t>             expected_adjustment; // pending change from arrive_and_drop()
    std::atomic<__phase_t>             phase;               // current phase, cycles through 0..3
    alignas(64) std::vector<__state_t> state;               // one entry per initial participant
    CompletionF                        completion;

    /// Phase value that follows `old` (phases wrap modulo 4).
    static constexpr __phase_t __next_phase(__phase_t old) {
        return (old + 1) & 3;
    }

    /// Try to claim slot `id` for the phase that follows `old_phase`.
    /// @return true if this thread flipped the slot's claim flag.
    inline bool __try_get_id(size_t id, __phase_t old_phase) {
        return state[id].v[__max_steps].compare_exchange_strong(old_phase,
                                                                __next_phase(old_phase),
                                                                std::memory_order_relaxed);
    }

    /// Claim a free slot in [0, count), probing from this thread's favorite
    /// slot and wrapping around.
    /// @return the claimed slot index, or __invalid_id if every slot was taken.
    inline size_t __get_id(__phase_t const old_phase, ptrdiff_t count) {
        auto const slots = static_cast<size_t>(count);
        // Reduce the hint once; in the common case it is already in range.
        size_t start = __barrier_favorite_hash;
        if(start >= slots)
            start %= slots;
        for(size_t i = 0; i < slots; ++i) {
            size_t id = start + i;
            // Wrap with >=: an index equal to `slots` is already past the
            // end of `state` and must not be used.
            if(id >= slots)
                id -= slots;
            if(__builtin_expect(__try_get_id(id, old_phase), 1)) {
                // Only write the thread-local hint when it actually changes.
                if(__barrier_favorite_hash != id)
                    __barrier_favorite_hash = id;
                return id;
            }
        }
        return __invalid_id;
    }

    /// floor(log2(count)); yields 0 for count <= 1.
    static constexpr uint32_t __log2_floor(ptrdiff_t count) {
        return count <= 1 ? 0 : 1 + __log2_floor(count >> 1);
    }

    /// ceil(log2(count)); yields 0 for count <= 1.
    static constexpr uint32_t __log2_ceil(ptrdiff_t count) {
        auto const t = __log2_floor(count);
        // Shift in ptrdiff_t so the comparison stays valid for t >= 31.
        return count == (static_cast<ptrdiff_t>(1) << t) ? t : t + 1;
    }

public:
    /// Reference to the barrier's phase plus the phase observed on arrival.
    using arrival_token = std::tuple<std::atomic<__phase_t>&, __phase_t>;

    /// @param expected   number of participating threads (asserted >= 0).
    /// @param completion callable stored and run by the slot-0 participant.
    barrier(ptrdiff_t expected, CompletionF completion = CompletionF())
            : expected(expected), expected_steps(__log2_ceil(expected)),
              expected_adjustment(0), phase(0),
              state(expected), completion(completion) {
        assert(expected >= 0);
    }

    ~barrier() = default;

    barrier(barrier const&) = delete;
    barrier& operator=(barrier const&) = delete;

    /// Arrive at the barrier without blocking on the phase change itself.
    /// Note that this call still spins in each round until the matching
    /// partner has signalled.
    /// @param update accepted for interface compatibility; this
    ///        implementation does not read it (each call is one arrival).
    /// @return token to pass to wait().
    [[nodiscard]] inline arrival_token arrive(ptrdiff_t update = 1) {

        size_t id = 0; // assume, for now
        auto const old_phase = phase.load(std::memory_order_relaxed);
        auto const count = expected;
        assert(count > 0);
        auto const steps = expected_steps;
        if(0 != steps) {
            id = __get_id(old_phase, count);
            assert(id != __invalid_id);
            auto const slots = static_cast<size_t>(count);
            for(uint32_t k = 0; k < steps; ++k) {
                auto const index = steps - k - 1;
                // Signal the partner for this round, then wait for our own
                // signal from the participant 2^k slots behind us.
                auto const partner = (id + (static_cast<size_t>(1) << k)) % slots;
                state[partner].v[index].store(__next_phase(old_phase), std::memory_order_release);
                while(state[id].v[index].load(std::memory_order_acquire) == old_phase)
                    ;
            }
        }
        if(0 == id) {
            completion();
            // Fold in participants dropped during this phase before
            // publishing the new phase.
            expected += expected_adjustment.load(std::memory_order_relaxed);
            expected_steps = __log2_ceil(expected);
            expected_adjustment.store(0, std::memory_order_relaxed);
            phase.store(__next_phase(old_phase), std::memory_order_release);
        }
        return std::tie(phase, old_phase);
    }

    /// Block until the phase recorded in `token` has ended.
    /// Unless __BARRIER_NO_WAIT is defined, the loop backs off with elapsed
    /// time: it spins first, then yields, then sleeps for up to 1500ns per
    /// iteration. With __BARRIER_NO_WAIT it yields on every iteration.
    inline void wait(arrival_token&& token) const {
        auto const& current_phase = std::get<0>(token);
        auto const old_phase = std::get<1>(token);
        if(__builtin_expect(old_phase != current_phase.load(std::memory_order_acquire),1))
            return;
#ifndef __BARRIER_NO_WAIT
        using __clock = std::conditional<std::chrono::high_resolution_clock::is_steady,
                                            std::chrono::high_resolution_clock,
                                            std::chrono::steady_clock>::type;
        auto const start = __clock::now();
#endif
        while (old_phase == current_phase.load(std::memory_order_acquire)) {
#ifndef __BARRIER_NO_WAIT
            auto const elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(__clock::now() - start);
            const auto step = std::min(elapsed/4 + std::chrono::nanoseconds(100), std::chrono::nanoseconds(1500));
            if(step > std::chrono::nanoseconds(1000))
                std::this_thread::sleep_for(step);
            else if(step > std::chrono::nanoseconds(500))
#endif
                std::this_thread::yield();
        }
    }

    /// Arrive, then block until the phase completes.
    inline void arrive_and_wait() {
        wait(arrive());
    }

    /// Arrive for the current phase and reduce the participant count by one
    /// for the phases that follow.
    inline void arrive_and_drop() {
        expected_adjustment.fetch_sub(1, std::memory_order_relaxed);
        (void)arrive();
    }
};

#else

/// Barrier built on one shared countdown.
///
/// `arrived` starts at the participant count and is decremented by every
/// arrive(). The thread that brings it to zero runs the completion callable,
/// reloads the counter from `expected`, flips the boolean `phase` and, unless
/// __BARRIER_NO_WAIT is defined, notifies waiters.
///
/// @tparam CompletionF callable invoked with no arguments by the last arriver.
template<class CompletionF = EmptyCompletionF>
class barrier {

    alignas(64) std::atomic<bool> phase;     // toggles each time a phase completes
    std::atomic<ptrdiff_t> expected;         // participant count used to re-arm `arrived`
    std::atomic<ptrdiff_t> arrived;          // arrivals still outstanding in this phase
    CompletionF completion;

public:
    /// The phase value observed on arrival.
    using arrival_token = bool;

    /// @param expected   number of arrivals that complete a phase.
    /// @param completion callable stored and run by the last arriver.
    barrier(ptrdiff_t expected, CompletionF completion = CompletionF())
        : phase(false), expected(expected), arrived(expected), completion(completion) {
    }

    ~barrier() = default;

    barrier(barrier const&) = delete;
    barrier& operator=(barrier const&) = delete;

    /// Record `update` arrivals without blocking.
    /// @return the phase observed on entry; pass it to wait().
    [[nodiscard]] arrival_token arrive(ptrdiff_t update = 1) {
        bool const entry_phase = phase.load(std::memory_order_relaxed);
        ptrdiff_t const remaining = arrived.fetch_sub(update, std::memory_order_acq_rel) - update;
        assert(remaining >= 0);
        ptrdiff_t const next_count = expected.load(std::memory_order_relaxed);

        // Everyone except the final arriver is done here.
        if(remaining != 0)
            return entry_phase;

        completion();
        // Re-arm the counter before the release store that publishes the flip.
        arrived.store(next_count, std::memory_order_relaxed);
        phase.store(!entry_phase, std::memory_order_release);
#ifndef __BARRIER_NO_WAIT
        atomic_notify_all(&phase);
#endif
        return entry_phase;
    }

    /// Block until `phase` differs from the value captured by arrive().
    void wait(arrival_token&& old_phase) const {
#ifdef __BARRIER_NO_WAIT
        // Waiting support disabled: busy-spin on the phase flag.
        while(phase.load(std::memory_order_acquire) == old_phase) {
        }
#else
        atomic_wait_explicit(&phase, old_phase, std::memory_order_acquire);
#endif
    }

    /// Arrive, then block until the phase completes.
    void arrive_and_wait() {
        wait(arrive());
    }

    /// Lower the participant count used for later phases by one, then arrive
    /// for the current phase.
    void arrive_and_drop() {
        expected.fetch_sub(1, std::memory_order_relaxed);
        (void)arrive();
    }
};

#ifndef __BARRIER_NO_SPECIALIZATION

/// Specialization for the no-op completion: the whole barrier state lives in
/// one 64-bit atomic word.
///
/// Layout of `phase_arrived_expected`:
///   bit  63      phase bit
///   bits 32..62  arrival field, initialised to 2^31 - count
///   bits  0..31  expected field, initialised to 2^31 - count
/// Each arrival adds to the arrival field. When that field reaches 2^31 the
/// carry flips the phase bit, and the arriving thread re-arms the arrival
/// field from the expected field.
template< >
class barrier<EmptyCompletionF> {

    static constexpr uint64_t expected_unit = 1ull;
    static constexpr uint64_t arrived_unit = 1ull << 32;
    static constexpr uint64_t expected_mask = arrived_unit - 1;
    static constexpr uint64_t phase_bit = 1ull << 63;
    static constexpr uint64_t arrived_mask = (phase_bit - 1) & ~expected_mask;

    alignas(64) std::atomic<uint64_t> phase_arrived_expected;

    /// Initial word for `count` participants: both fields hold 2^31 - count.
    static inline constexpr uint64_t __init(ptrdiff_t count) noexcept {
        uint64_t const comp = (1u << 31) - count;
        return (comp << 32) | comp;
    }

public:
    /// The phase bit observed on arrival (either 0 or phase_bit).
    using arrival_token = uint64_t;

    /// @param count number of arrivals that complete a phase.
    barrier(ptrdiff_t count, EmptyCompletionF = EmptyCompletionF())
        : phase_arrived_expected(__init(count)) {
    }

    ~barrier() = default;

    barrier(barrier const&) = delete;
    barrier& operator=(barrier const&) = delete;

    /// Record `update` arrivals without blocking.
    /// @param update number of arrivals to record (default 1). It is not
    ///        validated; it should be positive and no larger than the
    ///        arrivals still outstanding in the current phase.
    /// @return token to pass to wait().
    [[nodiscard]] inline arrival_token arrive(ptrdiff_t update = 1) {

        // Scale the arrival unit so that `update` is honoured rather than
        // always counting a single arrival.
        uint64_t const inc = arrived_unit * static_cast<uint64_t>(update);
        auto const old = phase_arrived_expected.fetch_add(inc, std::memory_order_acq_rel);
        // A changed phase bit means this call supplied the final arrival(s).
        if((old ^ (old + inc)) & phase_bit) {
            // Re-arm the arrival field from the expected field.
            phase_arrived_expected.fetch_add((old & expected_mask) << 32, std::memory_order_relaxed);
#ifndef __BARRIER_NO_WAIT
            atomic_notify_all(&phase_arrived_expected);
#endif
        }
        return old & phase_bit;
    }

    /// Block until the phase bit differs from the one captured in `phase`.
    /// With __BARRIER_NO_WAIT defined the loop re-reads the word without
    /// blocking.
    inline void wait(arrival_token&& phase) const {

        while(1) {
            uint64_t const current = phase_arrived_expected.load(std::memory_order_acquire);
            if((current & phase_bit) != phase)
                return;
#ifndef __BARRIER_NO_WAIT
            atomic_wait_explicit(&phase_arrived_expected, current, std::memory_order_relaxed);
#endif
        }
    }

    /// Arrive, then block until the phase completes.
    inline void arrive_and_wait() {
        wait(arrive());
    }

    /// Remove one participant from later phases (by incrementing the
    /// expected field), then arrive for the current phase.
    inline void arrive_and_drop() {
        phase_arrived_expected.fetch_add(expected_unit, std::memory_order_relaxed);
        (void)arrive();
    }
};

#endif //__BARRIER_NO_SPECIALIZATION

#endif //__BARRIER_NO_BUTTERFLY
