// -------------------------------------------------------------------------------------------------
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: (C) 2022 Jayesh Badwaik <j.badwaik@fz-juelich.de>
// -------------------------------------------------------------------------------------------------

#ifndef NOLA_TAG_INVOKE_HPP
#define NOLA_TAG_INVOKE_HPP

#include <nola/qualifier.hpp>
#include <utility>

namespace nola {
template <class Tag, class... Args>
concept tag_invokable = //
  requires(Tag tag, Args&&... args) {
    tag_invoke(std::forward<Tag>(tag), std::forward<Args...>(args...));
  };

template <class Ret, class Tag, class... Args>
concept tag_invokable_r = //
  requires(Tag tag, Args&&... args) {
    {
      static_cast<Ret>(tag_invoke(std::forward<Tag>(tag), std::forward<Args...>(args...)))
    };
  };

template <class Tag, class... Args>
concept nothrow_tag_invokable = //
  tag_invokable<Tag, Args...>
  && requires(Tag tag, Args&&... args) {
       {
         tag_invoke(std::forward<Tag>(tag), std::forward<Args...>(args...))
       } noexcept;
     };

template <class Tag, class... Args>
using tag_invoke_result_t = decltype(tag_invoke(std::declval<Tag>(), std::declval<Args>()...));

template <class Tag, class... Args>
struct tag_invoke_result {};

template <class Tag, class... Args>
  requires tag_invokable<Tag, Args...>
struct tag_invoke_result<Tag, Args...> {
  using type = tag_invoke_result_t<Tag, Args...>;
};

struct tag_invoke_t {
  template <class Tag, class... Args>
    requires tag_invokable<Tag, Args...>
  NOLA_HOST_DEVICE
  constexpr auto operator()(Tag tag, Args&&... args) const
    noexcept(nothrow_tag_invokable<Tag, Args...>) -> tag_invoke_result_t<Tag, Args...>
  {
    return tag_invoke(std::forward<Tag>(tag), std::forward<Args...>(args...));
  }
};

inline constexpr tag_invoke_t tag_invoke{};

template <auto& Tag>
using tag_t = std::decay_t<decltype(Tag)>;
} // namespace nola

#endif // NOLA_TAG_INVOKE_HPP
