// -------------------------------------------------------------------------------------------------
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: (C) 2022 Jayesh Badwaik <j.badwaik@fz-juelich.de>
// -------------------------------------------------------------------------------------------------

#ifndef NOLA_STRIDED_ITERATOR_HPP
#define NOLA_STRIDED_ITERATOR_HPP

#include <iterator>
#include <memory>
#include <nola/qualifier.hpp>
#include <type_traits>

namespace nola {
/**
 * @brief Iterator adaptor that steps through an underlying iterator `stride` positions at a time.
 *
 * A `strided_iterator` stores a base iterator and a stride. Incrementing / decrementing moves the
 * base iterator by the stride, and `it[n]` addresses the element `n * stride` base positions away.
 *
 * @tparam Iter Underlying iterator type. It is advanced by arbitrary differences (`+=`, `-=`,
 *              `[]`), so in practice it must be a random-access iterator.
 *
 * The adaptor advertises itself as a random-access iterator, never as a contiguous one: skipping
 * elements means consecutive strided elements are in general not adjacent in memory, even when the
 * base iterator is contiguous.
 **/
template <typename Iter>
class strided_iterator {

public:
  using iterator_type = Iter;

private:
  using iterator_traits = std::iterator_traits<iterator_type>;

public:
  using value_type = typename iterator_traits::value_type;
  using difference_type = typename iterator_traits::difference_type;
  using pointer = typename iterator_traits::pointer;
  using reference = typename iterator_traits::reference;
  // Random access, NOT contiguous: with a stride other than 1 neighbouring elements are not
  // adjacent in memory. Advertising `contiguous_iterator_tag` would let generic code (and the
  // C++20 `std::contiguous_iterator` concept) treat the strided range as one flat memory block.
  using iterator_category = std::random_access_iterator_tag;
  using iterator_concept = std::random_access_iterator_tag;

private:
  iterator_type iter_{};
  difference_type stride_{1};

public:
  /// Singular (value-initialised) iterator with a stride of 1. Forward and random-access iterators
  /// are required to be default constructible.
  strided_iterator() = default;

  /// Converts from a strided iterator over a base iterator type convertible to `Iter`
  /// (e.g. `T*` to `T const*`), taking the position from `u` and using the given `stride`.
  template <typename U>
    requires std::is_convertible_v<U, Iter>
  NOLA_HOST_DEVICE
  constexpr explicit strided_iterator(
    strided_iterator<U> const& u, difference_type stride) noexcept;

  /// Wraps base iterator `i`; each increment advances it by `stride` positions.
  NOLA_HOST_DEVICE
  constexpr explicit strided_iterator(iterator_type i, difference_type stride) noexcept;

public:
  /// Element at the current position.
  NOLA_HOST_DEVICE
  constexpr auto operator*() const noexcept -> reference;
  /// Pointer to the element at the current position.
  NOLA_HOST_DEVICE
  constexpr auto operator->() const noexcept -> pointer;

  /// Pre-increment / pre-decrement: move by one stride.
  NOLA_HOST_DEVICE
  constexpr auto operator++() noexcept -> strided_iterator&;
  NOLA_HOST_DEVICE
  constexpr auto operator--() noexcept -> strided_iterator&;

  /// Post-increment / post-decrement: move by one stride, return the previous position.
  NOLA_HOST_DEVICE
  constexpr auto operator++(int) noexcept -> strided_iterator;
  NOLA_HOST_DEVICE
  constexpr auto operator--(int) noexcept -> strided_iterator;

  /// Compound advance / retreat by `n`.
  NOLA_HOST_DEVICE
  constexpr auto operator+=(difference_type n) noexcept -> strided_iterator&;
  NOLA_HOST_DEVICE
  constexpr auto operator-=(difference_type n) noexcept -> strided_iterator&;

  /// Element `n * stride` base positions away from the current one.
  NOLA_HOST_DEVICE
  constexpr auto operator[](difference_type n) const noexcept -> reference;

  /// Copy of the wrapped base iterator.
  NOLA_HOST_DEVICE
  constexpr auto base() const noexcept -> iterator_type;

  /// The stride (base positions moved per increment).
  NOLA_HOST_DEVICE
  constexpr auto strided() const noexcept -> difference_type;
};

// Deduction guide: `strided_iterator(it, stride)` deduces `strided_iterator<decltype(it)>`.
// The constructor alone cannot drive deduction, because its parameters are spelled through the
// member aliases `iterator_type` / `difference_type`, which are non-deduced contexts. A
// single-argument guide would be useless: there is no single-argument constructor to call.
template <typename Iter>
strided_iterator(Iter, typename std::iterator_traits<Iter>::difference_type)
  -> strided_iterator<Iter>;

// Comparison operators come in two flavours:
//   - a homogeneous one, where both operands wrap the same base iterator type;
//   - a heterogeneous one, where the base iterator types may differ, as long as the base iterators
//     themselves can be compared (e.g. `T*` and `T const*`).
//
// Reference: https://github.com/llvm/llvm-project/commit/4118858b4e4d072ac2ceef6cbc52088438781f39
// Both versions are genuinely needed. If only the heterogeneous two-parameter version existed,
// `x > x` would be ambiguous:
//
//    template<class T, class U> int f(S<T>, S<U>) { return 1; }
//    template<class T> int f(T, T) { return 2; }  // rel_ops
//    S<int> s; f(s,s);  // ambiguous between #1 and #2
//
// Adding the one-template-parameter version resolves the ambiguity:
//
//    template<class T, class U> int f(S<T>, S<U>) { return 1; }
//    template<class T> int f(T, T) { return 2; }  // rel_ops
//    template<class T> int f(S<T>, S<T>) { return 3; }
//    S<int> s; f(s,s);  // #3 beats both #1 and #2

// Equality: the definitions below compare the base iterators only; strides are not compared.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto operator==(strided_iterator<Iter> const& left, strided_iterator<Iter> const& right)
  -> bool;

template <typename Iter1, typename Iter2>
NOLA_HOST_DEVICE
constexpr auto operator==(strided_iterator<Iter1> const& left, strided_iterator<Iter2> const& right)
  -> bool;

// Inequality: negation of equality.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto operator!=(strided_iterator<Iter> const& left, strided_iterator<Iter> const& right)
  -> bool;

template <typename Iter1, typename Iter2>
NOLA_HOST_DEVICE
constexpr auto operator!=(strided_iterator<Iter1> const& left, strided_iterator<Iter2> const& right)
  -> bool;

// Ordering: the definitions below order by the base iterators.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto operator<(strided_iterator<Iter> const& left, strided_iterator<Iter> const& right)
  -> bool;

template <typename Iter1, typename Iter2>
NOLA_HOST_DEVICE
constexpr auto operator<(strided_iterator<Iter1> const& left, strided_iterator<Iter2> const& right)
  -> bool;

template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto operator<=(strided_iterator<Iter> const& left, strided_iterator<Iter> const& right)
  -> bool;

template <typename Iter1, typename Iter2>
NOLA_HOST_DEVICE
constexpr auto operator<=(strided_iterator<Iter1> const& left, strided_iterator<Iter2> const& right)
  -> bool;

template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto operator>(strided_iterator<Iter> const& left, strided_iterator<Iter> const& right)
  -> bool;

template <typename Iter1, typename Iter2>
NOLA_HOST_DEVICE
constexpr auto operator>(strided_iterator<Iter1> const& left, strided_iterator<Iter2> const& right)
  -> bool;

template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto operator>=(strided_iterator<Iter> const& left, strided_iterator<Iter> const& right)
  -> bool;

template <typename Iter1, typename Iter2>
NOLA_HOST_DEVICE
constexpr auto operator>=(strided_iterator<Iter1> const& left, strided_iterator<Iter2> const& right)
  -> bool;

// `x + n` and `n + x`: a copy of `x` advanced by `n` through `operator+=`.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto operator+(
  strided_iterator<Iter> const& x, typename strided_iterator<Iter>::difference_type n)
  -> strided_iterator<Iter>;

template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto operator+(
  typename strided_iterator<Iter>::difference_type n, strided_iterator<Iter> const& x)
  -> strided_iterator<Iter>;

// Distance between two strided iterators (possibly over different base iterator types); the
// result is expressed in the left operand's difference_type.
template <typename Iter1, typename Iter2>
NOLA_HOST_DEVICE
constexpr auto operator-(strided_iterator<Iter1> const& left, strided_iterator<Iter2> const& right)
  -> typename strided_iterator<Iter1>::difference_type;

// -------------------------------------------------------------------------------------------------
// Implementation
// -------------------------------------------------------------------------------------------------

// Converting constructor: adopts the position of `u` (whose base iterator type `U` converts
// implicitly to `Iter`) and uses the explicitly supplied `stride`.
template <typename Iter>
template <typename U>
  requires std::is_convertible_v<U, Iter>
NOLA_HOST_DEVICE
constexpr strided_iterator<Iter>::strided_iterator(
  strided_iterator<U> const& u, difference_type stride) noexcept
: iter_(u.base())
, stride_(stride)
{
}

// Primary constructor: wraps base iterator `it` and remembers how far each step moves it.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr strided_iterator<Iter>::strided_iterator(iterator_type it, difference_type stride) noexcept
: iter_(it)
, stride_(stride)
{
}

// Dereference: yields the element the wrapped base iterator currently refers to.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto strided_iterator<Iter>::operator*() const noexcept -> reference
{
  return *this->iter_;
}

// Member access: raw pointer to the current element, obtained through std::to_address so that
// both plain pointers and fancy/class-type iterators are supported.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto strided_iterator<Iter>::operator->() const noexcept -> pointer
{
  return std::to_address(this->iter_);
}

// Pre-increment: moves the base iterator forward by one stride.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto strided_iterator<Iter>::operator++() noexcept -> strided_iterator&
{
  iter_ += stride_;
  return *this;
}

// Pre-decrement: moves the base iterator backward by one stride.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto strided_iterator<Iter>::operator--() noexcept -> strided_iterator&
{
  iter_ -= stride_;
  return *this;
}

// Post-increment: advances by one stride and returns the position held before the step.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto strided_iterator<Iter>::operator++(int) noexcept -> strided_iterator
{
  auto previous = *this;
  // Delegate to pre-increment so that "one step" is defined in exactly one place.
  ++(*this);
  return previous;
}

// Post-decrement: retreats by one stride and returns the position held before the step.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto strided_iterator<Iter>::operator--(int) noexcept -> strided_iterator
{
  auto previous = *this;
  --(*this);
  return previous;
}

// Advances by `n` strides, i.e. by `n * stride` positions of the base iterator. This makes
// `it += n` equivalent to applying `++it` n times (a random-access iterator requirement) and
// consistent with operator[], which also scales its index by the stride.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto strided_iterator<Iter>::operator+=(difference_type n) noexcept -> strided_iterator&
{
  iter_ += n * stride_;
  return *this;
}

// Retreats by `n` strides, i.e. by `n * stride` positions of the base iterator.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto strided_iterator<Iter>::operator-=(difference_type n) noexcept -> strided_iterator&
{
  iter_ -= n * stride_;
  return *this;
}

// Subscript: the element located `n * stride` base positions away from the current one.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto strided_iterator<Iter>::operator[](difference_type n) const noexcept -> reference
{
  auto const offset = n * stride_;
  return iter_[offset];
}

// Returns a copy of the wrapped base iterator, i.e. the current position in the base sequence.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto strided_iterator<Iter>::base() const noexcept -> iterator_type
{
  return this->iter_;
}

// Returns the stride: the number of base positions a single increment moves.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto strided_iterator<Iter>::strided() const noexcept -> difference_type
{
  return this->stride_;
}

// Equality for strided iterators over the same base iterator type: true when the base iterators
// are equal. The strides are not compared.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto operator==(strided_iterator<Iter> const& left, strided_iterator<Iter> const& right)
  -> bool
{
  return left.base() == right.base();
}

// Equality for strided iterators over different but mutually comparable base iterator types;
// same rule: only the base iterators are compared.
template <typename Iter1, typename Iter2>
NOLA_HOST_DEVICE
constexpr auto operator==(strided_iterator<Iter1> const& left, strided_iterator<Iter2> const& right)
  -> bool
{
  return left.base() == right.base();
}

// Inequality for strided iterators over the same base iterator type: negation of operator==.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto operator!=(strided_iterator<Iter> const& left, strided_iterator<Iter> const& right)
  -> bool
{
  return (not(left == right));
}

// Inequality for strided iterators over different base iterator types: negation of operator==.
//
// This must delegate to operator== rather than being written as `left != right`: for
// Iter1 != Iter2 that expression resolves back to this very overload and recurses without bound.
template <typename Iter1, typename Iter2>
NOLA_HOST_DEVICE
constexpr auto operator!=(strided_iterator<Iter1> const& left, strided_iterator<Iter2> const& right)
  -> bool
{
  return (not(left == right));
}

// Relational operators. All of them order strided iterators by their base iterators; the strides
// are not consulted.
//
// NOTE(review): base order equals traversal order only for a positive stride. With a negative
// stride, `++it` moves the base iterator backwards, so the incremented iterator compares *less*
// than the original. Confirm that negative strides are unsupported, or make these stride-aware.

// `left < right` (same base iterator type).
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto operator<(strided_iterator<Iter> const& left, strided_iterator<Iter> const& right)
  -> bool
{
  return left.base() < right.base();
}

// `left < right` (different base iterator types).
template <typename Iter1, typename Iter2>
NOLA_HOST_DEVICE
constexpr auto operator<(strided_iterator<Iter1> const& left, strided_iterator<Iter2> const& right)
  -> bool
{
  return left.base() < right.base();
}

// `left <= right` (same base iterator type).
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto operator<=(strided_iterator<Iter> const& left, strided_iterator<Iter> const& right)
  -> bool
{
  return left.base() <= right.base();
}

// `left <= right` (different base iterator types).
template <typename Iter1, typename Iter2>
NOLA_HOST_DEVICE
constexpr auto operator<=(strided_iterator<Iter1> const& left, strided_iterator<Iter2> const& right)
  -> bool
{
  return left.base() <= right.base();
}

// `left > right` (same base iterator type).
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto operator>(strided_iterator<Iter> const& left, strided_iterator<Iter> const& right)
  -> bool
{
  return left.base() > right.base();
}

// `left > right` (different base iterator types).
template <typename Iter1, typename Iter2>
NOLA_HOST_DEVICE
constexpr auto operator>(strided_iterator<Iter1> const& left, strided_iterator<Iter2> const& right)
  -> bool
{
  return left.base() > right.base();
}

// `left >= right` (same base iterator type).
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto operator>=(strided_iterator<Iter> const& left, strided_iterator<Iter> const& right)
  -> bool
{
  return left.base() >= right.base();
}

// `left >= right` (different base iterator types).
template <typename Iter1, typename Iter2>
NOLA_HOST_DEVICE
constexpr auto operator>=(strided_iterator<Iter1> const& left, strided_iterator<Iter2> const& right)
  -> bool
{
  return left.base() >= right.base();
}

// `x + n`: returns a copy of `x` advanced by `n` through operator+=.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto operator+(
  strided_iterator<Iter> const& x, typename strided_iterator<Iter>::difference_type n)
  -> strided_iterator<Iter>
{
  strided_iterator<Iter> advanced(x);
  advanced += n;
  return advanced;
}

// `n + x`: commutative counterpart of `x + n`.
template <typename Iter>
NOLA_HOST_DEVICE
constexpr auto operator+(
  typename strided_iterator<Iter>::difference_type n, strided_iterator<Iter> const& x)
  -> strided_iterator<Iter>
{
  return x + n;
}

// Distance between two strided iterators, measured in strides: the number of increments (each
// moving the base iterator by `stride` positions) needed to get from `right` to `left`.
//
// Preconditions: both iterators use the same non-zero stride, and the distance between their
// base iterators is a multiple of that stride (i.e. one is reachable from the other by stepping).
template <typename Iter1, typename Iter2>
NOLA_HOST_DEVICE
constexpr auto operator-(strided_iterator<Iter1> const& left, strided_iterator<Iter2> const& right)
  -> typename strided_iterator<Iter1>::difference_type
{
  // The raw base distance counts elements of the underlying sequence; scale it back down to a
  // count of strided steps. The division is exact under the preconditions above, and it also
  // yields the right sign for negative strides.
  return (left.base() - right.base()) / left.strided();
}

} // namespace nola

#endif // NOLA_STRIDED_ITERATOR_HPP
