/*
 *  Copyright 2008-2013 NVIDIA Corporation
 *  Modifications Copyright© 2019-2025 Advanced Micro Devices, Inc. All rights reserved.
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

#pragma once

#include <thrust/detail/event_error.h>
#include <thrust/execution_policy.h>
#include <thrust/host_vector.h>
#include <thrust/limits.h>
#include <thrust/mr/allocator.h>
#include <thrust/random.h>

#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <memory>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <vector>
#if !_THRUST_HAS_DEVICE_SYSTEM_STD
// Use rocprim::numeric_limits if thrust/detail/type_traits.h uses rocprim::arithmetic
#  include <limits>
#endif

#include "test_seed.hpp"

// Expands to test_event_wait(e): waits until the event `e` reports ready() and asserts
// (GoogleTest) that its stream is valid before and after waiting.
#define TEST_EVENT_WAIT(e) test_event_wait(e)

// for demangling the result of type_info.name()
// with msvc, type_info.name() is already demangled
#ifdef __GNUC__
#  include <cxxabi.h>
#endif // __GNUC__

#include <cstdlib>
#include <string>

// HIP API
#if THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_HIP
#  include <hip/hip_runtime.h>
#  include <hip/hip_runtime_api.h>

// GoogleTest-compatible HIP_CHECK macro. Evaluates `condition` (an expression yielding a
// hipError_t) once. If the result is not hipSuccess, the error code and the text from
// hipGetErrorString() are reported through FAIL() to log the Google Test trace, and the
// process is then terminated with exit(error).
// The lambda is invoked immediately because assertions that generate a fatal failure can
// only be used in void-returning functions.
// The definition is skipped if HIP_CHECK is already defined.
#  ifndef HIP_CHECK
#    define HIP_CHECK(condition)                                                 \
      do                                                                         \
      {                                                                          \
        hipError_t error = condition;                                            \
        if (error != hipSuccess)                                                 \
        {                                                                        \
          [error]() {                                                            \
            FAIL() << "HIP error " << error << ": " << hipGetErrorString(error); \
          }();                                                                   \
          exit(error);                                                           \
        }                                                                        \
      } while (0)
#  endif

#endif // THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_HIP


namespace test
{

/// Returns the value of the environment variable `name`, or nullptr if it is not set.
/// The result must be released with clean_env(). With MSVC it is a heap copy made by
/// _dupenv_s (nullptr is also returned if that call fails); with other compilers it
/// points into the process environment and must not be freed.
inline char* get_env(const char* name)
{
  char* env = nullptr;
#ifdef _MSC_VER
  size_t len  = 0;
  errno_t err = _dupenv_s(&env, &len, name);
  if (err)
  {
    return nullptr;
  }
#else
  env = std::getenv(name);
#endif
  return env;
}

/// Releases a value returned by get_env(). Frees the buffer with MSVC; a no-op elsewhere.
/// Accepts nullptr.
inline void clean_env(char* name)
{
#ifdef _MSC_VER
  if (name != nullptr)
  {
    free(name);
  }
#endif
  (void) name;
}

/// Selects the HIP device assigned to this test by CTest resource allocation.
///
/// If CTEST_RESOURCE_GROUP_0 is set, its value (the resource type, e.g. "amdgpu") is
/// upper-cased and CTEST_RESOURCE_GROUP_0_<TYPE> is read. That value is expected to
/// follow CTest's "id:<id>,slots:<n>" format: the text between the first ':' and the
/// first ',' is parsed with std::atoi as the device id. If the per-type variable is
/// missing, the id stays 0. In both cases hipSetDevice(device) is then called through
/// HIP_CHECK.
/// Returns the selected device id (0 when no resource group is assigned).
inline int set_device_from_ctest()
{
  static const std::string rg0 = "CTEST_RESOURCE_GROUP_0";
  char* env                    = get_env(rg0.c_str());
  int device                   = 0;
  if (env != nullptr)
  {
    std::string amdgpu_target(env);
    std::transform(
      amdgpu_target.cbegin(),
      amdgpu_target.cend(),
      amdgpu_target.begin(),
      // Feeding std::toupper plainly results in implicitly truncating conversions between int and char triggering
      // warnings.
      [](unsigned char c) {
        return static_cast<char>(std::toupper(c));
      });
    char* env_reqs = get_env((rg0 + "_" + amdgpu_target).c_str());
    // Constructing a std::string from nullptr is undefined behavior, so the
    // requirements are parsed only when the variable actually exists.
    if (env_reqs != nullptr)
    {
      const std::string reqs(env_reqs);
      // If ':' is absent, npos + 1 wraps to 0 and parsing starts at the beginning.
      const std::string::size_type id_begin = reqs.find(':') + 1;
      device = std::atoi(reqs.substr(id_begin, reqs.find(',') - id_begin).c_str());
    }
    clean_env(env_reqs);
    HIP_CHECK(hipSetDevice(device));
  }
  clean_env(env);
  return device;
}
} // namespace test

// If enabled, set up the database for inter-run bitwise reproducibility testing.
// Inter-run testing is enabled through the following environment variables:
// ROCTHRUST_BWR_PATH - path to the database (or where it should be created)
// ROCTHRUST_BWR_GENERATE - if set to an integer greater than 0, info about any function
// calls not found in the database will be inserted. No errors will be reported in this
// mode. Setting it without ROCTHRUST_BWR_PATH is an error.
namespace inter_run_bwr
{
// Whether inter-run testing is active. Disabled by default; create_db() switches it on
// when ROCTHRUST_BWR_PATH is set.
// The globals in this namespace are C++17 `inline` variables (and create_db() is an
// inline function) because this is a header: plain definitions would be duplicated in
// every translation unit that includes it, violating the one-definition rule.
inline bool enabled = false;

// Names of the environment variables read by create_db(). They are `inline` and
// defined before `db`, so their initialization is ordered before db's initializer,
// which reads them.
inline const std::string path_env     = "ROCTHRUST_BWR_PATH";
inline const std::string generate_env = "ROCTHRUST_BWR_GENERATE";

// Check the environment variables to see if the database should be
// instantiated, and if so, what mode it should be in.
// Returns nullptr (and leaves `enabled` false) when ROCTHRUST_BWR_PATH is unset.
// Throws std::runtime_error if ROCTHRUST_BWR_GENERATE is set without a path; std::stoi
// throws if ROCTHRUST_BWR_GENERATE is not an integer.
inline std::unique_ptr<BitwiseReproDB> create_db()
{
  // Get the path to the database from an environment variable.
  const char* db_path = std::getenv(path_env.c_str());
  const char* db_mode = std::getenv(generate_env.c_str());
  if (db_path)
  {
    // Check if we are allowed to insert rows into the database if
    // we encounter calls that aren't already recorded.
    BitwiseReproDB::Mode mode = BitwiseReproDB::Mode::test_mode;
    if (db_mode && std::stoi(db_mode) > 0)
    {
      mode = BitwiseReproDB::Mode::generate_mode;
    }

    enabled = true;
    return std::make_unique<BitwiseReproDB>(db_path, mode);
  }
  else if (db_mode)
  {
    throw std::runtime_error("ROCTHRUST_BWR_GENERATE is defined, but no database path was given.\n"
                             "Please set ROCTHRUST_BWR_PATH to the database path.");
  }

  return nullptr;
}

// Create/open the run-to-run bitwise reproducibility database.
inline std::unique_ptr<BitwiseReproDB> db = create_db();
} // namespace inter_run_bwr

#ifdef __GNUC__
/// Returns the human-readable form of a mangled name, such as the one returned by
/// std::type_info::name(). If `name` cannot be demangled, it is returned unchanged
/// (an empty string is returned for a null `name`).
inline std::string demangle(const char* name)
{
  if (name == nullptr)
  {
    return std::string();
  }

  int status     = 0;
  char* realname = abi::__cxa_demangle(name, nullptr, nullptr, &status);
  if (status != 0 || realname == nullptr)
  {
    // On failure __cxa_demangle returns nullptr. Constructing a std::string from it
    // would be undefined behavior, so fall back to the raw name.
    std::free(realname);
    return std::string(name);
  }

  // __cxa_demangle allocates the result with malloc; the caller must free it.
  std::string result(realname);
  std::free(realname);

  return result;
}
#else
/// With MSVC, type_info::name() is already human-readable, so `name` is returned as-is
/// (an empty string is returned for a null `name`).
inline std::string demangle(const char* name)
{
  return name != nullptr ? std::string(name) : std::string();
}
#endif

/// Safe sign-mixed comparisons: a negative value always compares less than any value
/// of an unsigned type (in contrast to the built-in comparison operator, which first
/// converts the signed operand to unsigned).
/// This is a backport of the C++20 standard library std::cmp_less / std::cmp_greater
/// to C++14.
///
/// The three cmp_less overloads are mutually exclusive:
///  1. same signedness, or at least one operand of non-integral type: plain `t < u`;
///  2. signed integral T and unsigned integral U;
///  3. unsigned integral T and signed integral U.
/// Overloads 2 and 3 require *both* types to be integral. Otherwise a pair such as
/// (int, some class type) would match overloads 1 and 2, making the call ambiguous.
template <class T, class U>
constexpr auto cmp_less(T t, U u) noexcept -> std::enable_if_t<
  std::is_signed<T>::value == std::is_signed<U>::value || !std::is_integral<T>::value || !std::is_integral<U>::value,
  bool>
{
  return t < u;
}

template <class T, class U>
constexpr auto cmp_less(T t, U u) noexcept -> std::enable_if_t<
  std::is_signed<T>::value && !std::is_signed<U>::value && std::is_integral<T>::value && std::is_integral<U>::value,
  bool>
{
  // T is signed, U is unsigned: a negative t is less than every u.
  return t < 0 || std::make_unsigned_t<T>(t) < u;
}

template <class T, class U>
constexpr auto cmp_less(T t, U u) noexcept -> std::enable_if_t<
  !std::is_signed<T>::value && std::is_signed<U>::value && std::is_integral<T>::value && std::is_integral<U>::value,
  bool>
{
  // T is unsigned, U is signed: no t is less than a negative u.
  return u >= 0 && t < std::make_unsigned_t<U>(u);
}

/// Sign-safe `t > u`; see cmp_less.
template <class T, class U>
constexpr bool cmp_greater(T t, U u) noexcept
{
  return cmp_less(u, t);
}

// Backport of saturate_cast from C++26 to C++14
// From
// https://github.com/llvm/llvm-project/blob/52b18430ae105566f26152c0efc63998301b1134/libcxx/include/__numeric/saturation_arithmetic.h#L97
// licensed under the MIT license
//
// Converts x to Res, clamping it to the representable range of Res.
// The lower bound is numeric_limits<Res>::lowest(). For integral Res this equals min(),
// but for floating-point Res min() is the smallest *positive* normal value, which would
// wrongly clamp zero and all negative inputs upwards.
template <typename Res, typename T>
constexpr Res saturate_cast(T x) noexcept
{
  // Handle overflow
  if (cmp_less(x, std::numeric_limits<Res>::lowest()))
  {
    return std::numeric_limits<Res>::lowest();
  }
  if (cmp_greater(x, std::numeric_limits<Res>::max()))
  {
    return std::numeric_limits<Res>::max();
  }
  // No overflow
  return static_cast<Res>(x);
}

/// Base class of the exceptions thrown by the unit-test helpers.
/// Holds a human-readable message. Values can be appended to it with operator<<
/// (formatted as an std::ostream would format them), and the whole exception can be
/// written to an std::ostream.
class UnitTestException
{
public:
  /// Text accumulated so far.
  std::string message;

  UnitTestException()
      : message()
  {}

  UnitTestException(const std::string& msg)
      : message(msg)
  {}

  /// Writes the accumulated message to `os`.
  friend std::ostream& operator<<(std::ostream& os, const UnitTestException& e)
  {
    os << e.message;
    return os;
  }

  /// Formats `t` through a std::ostringstream and appends the result to the message.
  template <typename T>
  UnitTestException& operator<<(const T& t)
  {
    std::ostringstream formatted;
    formatted << t;
    message.append(formatted.str());
    return *this;
  }
};

class UnitTestError : public UnitTestException
{
public:
  UnitTestError() {}
  UnitTestError(const std::string& msg)
      : UnitTestException(msg)
  {}
};

class UnitTestKnownFailure : public UnitTestException
{
public:
  UnitTestKnownFailure() {}
  UnitTestKnownFailure(const std::string& msg)
      : UnitTestException(msg)
  {}
};

class UnitTestFailure : public UnitTestException
{
public:
  UnitTestFailure() {}
  UnitTestFailure(const std::string& msg)
      : UnitTestException(msg)
  {}
};

/// Returns the demangled name of type T, as reported by typeid(T).
/// typeid ignores top-level cv-qualifiers and references, so those are not reflected.
template <typename T>
std::string type_name(void)
{
  const std::type_info& info = typeid(T);
  return demangle(info.name());
}

/// Blocks until the event `e` reports ready() and asserts (GoogleTest ASSERT_*) that
/// its stream is valid both before and after waiting, and that it ends up ready.
template <typename Event>
__host__ void test_event_wait(Event&& e)
{
  ASSERT_EQ(true, e.valid_stream());

  // wait() is always called at least once before ready() is queried, then repeated
  // until the event reports completion.
  do
  {
    e.wait();
  } while (!e.ready());

  ASSERT_EQ(true, e.valid_stream());
  ASSERT_EQ(true, e.ready());
}

/// Returns the input sizes exercised by size-parameterized tests: small edge cases
/// (0, 1, 2), some powers of two and several non-power-of-two sizes up to about 2^20.
/// Declared `inline` because it is defined in a header: a plain definition would be
/// duplicated in every translation unit that includes this file.
inline std::vector<size_t> get_sizes()
{
  std::vector<size_t> sizes = {
    0, 1, 2, 12, 63, 64, 211, 256, 344, 1024, 2048, 5096, 34567, (1 << 17) - 1220, 1000000, (1 << 20) - 123};
  return sizes;
}

/// Returns the seeds randomized tests should run with: the fixed seeds in `prng_seeds`
/// followed by `rng_seed_count` seeds drawn from std::random_device. The drawn seeds
/// differ from run to run.
/// Declared `inline` because it is defined in a header: a plain definition would be
/// duplicated in every translation unit that includes this file.
inline std::vector<seed_type> get_seeds()
{
  std::vector<seed_type> seeds;
  std::random_device rng;
  std::copy(prng_seeds.begin(), prng_seeds.end(), std::back_inserter(seeds));
  std::generate_n(std::back_inserter(seeds), rng_seed_count, [&]() {
    return rng();
  });
  return seeds;
}

/// Default value range used when generating random test input of type T.
/// Only integral and floating-point types have a definition; for any other type the
/// primary template is left undefined, so using it fails to compile.
template <class T, class enable = void>
struct get_default_limits;

/// Integral types: the full representable range of T.
template <class T>
struct get_default_limits<T, std::enable_if_t<std::is_integral<T>::value>>
{
  // The extra parentheses keep min/max function-like macros from expanding.
  static inline T min()
  {
    return (std::numeric_limits<T>::min)();
  }
  static inline T max()
  {
    return (std::numeric_limits<T>::max)();
  }
};

/// Floating-point types: the interval [-1, 1].
template <class T>
struct get_default_limits<T, std::enable_if_t<std::is_floating_point<T>::value>>
{
  static inline T min()
  {
    return static_cast<T>(-1);
  }
  static inline T max()
  {
    return static_cast<T>(1);
  }
};

/// bool overload: returns `size` values, each true with probability 0.5.
/// The min/max arguments are ignored. The engine is seeded only from `seed`, so the
/// same seed reproduces the same data with a given standard library. No
/// std::random_device is constructed: its output would be replaced by the seed anyway,
/// and constructing one can throw on systems without an entropy source.
template <class T>
inline auto get_random_data(size_t size, T, T, seed_type seed) ->
  typename std::enable_if_t<std::is_same_v<T, bool>, thrust::host_vector<T>>
{
  std::default_random_engine gen(seed);
  std::bernoulli_distribution distribution(0.5);
  thrust::host_vector<T> data(size);
  std::generate(data.begin(), data.end(), [&]() {
    return distribution(gen);
  });
  return data;
}

#if defined(_MSC_VER)
/// MSVC overload for signed character types: returns `size` values drawn uniformly from
/// [min, max]. The values are drawn through an `int` distribution because the standard
/// does not allow character types as the IntType of std::uniform_int_distribution (and
/// MSVC enforces this). The engine is seeded only from `seed`, so results are
/// reproducible for a given seed.
template <class T>
inline auto get_random_data(size_t size, T min, T max, seed_type seed) ->
  typename std::enable_if_t<std::is_same_v<T, signed char> || (std::is_same_v<T, char> && std::is_signed_v<char>),
                            thrust::host_vector<T>>
{
  std::default_random_engine gen(seed);
  std::uniform_int_distribution<int> distribution(static_cast<int>(min), static_cast<int>(max));
  thrust::host_vector<T> data(size);
  std::generate(data.begin(), data.end(), [&]() {
    return static_cast<T>(distribution(gen));
  });
  return data;
}

/// MSVC overload for unsigned character types: returns `size` values drawn uniformly
/// from [min, max]. The values are drawn through an `int` distribution because the
/// standard does not allow character types as the IntType of
/// std::uniform_int_distribution (and MSVC enforces this). The bounds are converted
/// straight to `int`, the distribution's value type; every unsigned char value fits.
/// The engine is seeded only from `seed`, so results are reproducible for a given seed.
template <class T>
inline auto get_random_data(size_t size, T min, T max, seed_type seed) ->
  typename std::enable_if_t<std::is_same_v<T, unsigned char> || (std::is_same_v<T, char> && std::is_unsigned_v<char>),
                            thrust::host_vector<T>>
{
  std::default_random_engine gen(seed);
  std::uniform_int_distribution<int> distribution(static_cast<int>(min), static_cast<int>(max));
  thrust::host_vector<T> data(size);
  std::generate(data.begin(), data.end(), [&]() {
    return static_cast<T>(distribution(gen));
  });
  return data;
}
#endif

/// Integral (non-bool) overload: returns `size` values drawn uniformly from [min, max].
/// On MSVC, character types are excluded here and handled by the dedicated overloads.
/// NOTE(review): on other toolchains char types still reach
/// std::uniform_int_distribution<T>, which the standard does not sanction; confirm this
/// is acceptable for the supported standard libraries.
/// The engine is seeded only from `seed`, so results are reproducible for a given seed.
template <class T>
inline auto get_random_data(size_t size, T min, T max, seed_type seed) -> typename std::enable_if_t<
  rocprim::is_integral<T>::value && !std::is_same_v<T, bool>
#if defined(_MSC_VER)
    && !std::is_same_v<T, signed char> && !std::is_same_v<T, unsigned char> && !std::is_same_v<T, char>
#endif
  ,
  thrust::host_vector<T>>
{
  std::default_random_engine gen(seed);
  std::uniform_int_distribution<T> distribution(saturate_cast<T>(min), saturate_cast<T>(max));
  thrust::host_vector<T> data(size);
  std::generate(data.begin(), data.end(), [&]() {
    return distribution(gen);
  });
  return data;
}

/// Floating-point overload: returns `size` values drawn uniformly from the half-open
/// interval [min, max) by std::uniform_real_distribution.
/// The engine is seeded only from `seed`, so results are reproducible for a given seed.
template <class T>
inline auto get_random_data(size_t size, T min, T max, seed_type seed) ->
  typename std::enable_if_t<rocprim::is_floating_point<T>::value, thrust::host_vector<T>>
{
  std::default_random_engine gen(seed);
  std::uniform_real_distribution<T> distribution(min, max);
  thrust::host_vector<T> data(size);
  std::generate(data.begin(), data.end(), [&]() {
    return distribution(gen);
  });
  return data;
}

/// User-defined "less than" comparator (host and device) that simply applies T's
/// operator<.
template <class T>
struct custom_compare_less
{
  __host__ __device__ bool operator()(const T& lhs, const T& rhs) const
  {
    const bool lhs_is_less = lhs < rhs;
    return lhs_is_less;
  }
};

struct user_swappable
{
  inline __host__ __device__ user_swappable(bool swapped = false)
      : was_swapped(swapped)
  {}

  bool was_swapped;
};

/// Two user_swappable objects are equal when their was_swapped flags match.
inline __host__ __device__ bool operator==(const user_swappable& x, const user_swappable& y)
{
  const bool same_flag = (x.was_swapped == y.was_swapped);
  return same_flag;
}

// Free swap() for user_swappable, found by argument-dependent lookup.
// It does not exchange the two values: it marks `x` as swapped (true) and clears `y`.
// The statement order matters if `x` and `y` refer to the same object (the flag then
// ends up false).
inline __host__ __device__ void swap(user_swappable& x, user_swappable& y)
{
  x.was_swapped = true;
  y.was_swapped = false;
}

/// Custom device execution policy used to check dispatch to a user-provided system.
/// Each copy-construction increments a copy counter and resets the dispatch flag;
/// validate_dispatch() records whether it was called on an object that has never been
/// copy-constructed (copy count 0), and is_valid() reports that result.
class my_system : public thrust::device_execution_policy<my_system>
{
public:
  /// The int argument is unused.
  my_system(int)
      : correctly_dispatched(false)
      , num_copies(0)
  {}

  my_system(const my_system& other)
      : correctly_dispatched(false)
      , num_copies(other.num_copies + 1)
  {}

  // Default construction is disallowed. `= delete` rejects any attempt at compile time
  // with a clear diagnostic, instead of relying on a private, undefined declaration.
  my_system() = delete;

  void validate_dispatch()
  {
    correctly_dispatched = (num_copies == 0);
  }

  /// Returns whether validate_dispatch() was called on an uncopied object.
  /// Const-qualified because it only reads state.
  bool is_valid() const
  {
    return correctly_dispatched;
  }

private:
  bool correctly_dispatched;

  // count the number of copies so that we can validate
  // that dispatch does not introduce any
  unsigned int num_copies;
};

/// Empty custom device execution policy tag (CRTP over thrust's device policy).
struct my_tag : public thrust::device_execution_policy<my_tag>
{
};

/// Fixed-size array of N values of type T, usable on host and device.
/// Provides element-wise addition, lexicographic ordering (operator<) and element-wise
/// equality (operator==). The loops are marked `#pragma nounroll`.
template <typename T, unsigned int N>
struct FixedVector
{
  T data[N];

  /// Sets every element to T().
  __host__ __device__ FixedVector()
  {
#pragma nounroll
    for (unsigned int idx = 0; idx < N; ++idx)
    {
      data[idx] = T();
    }
  }

  /// Sets every element to `init`.
  __host__ __device__ FixedVector(T init)
  {
#pragma nounroll
    for (unsigned int idx = 0; idx < N; ++idx)
    {
      data[idx] = init;
    }
  }

  /// Element-wise sum of *this and `bs`.
  __host__ __device__ FixedVector operator+(const FixedVector& bs) const
  {
    FixedVector sum;
#pragma nounroll
    for (unsigned int idx = 0; idx < N; ++idx)
    {
      sum.data[idx] = data[idx] + bs.data[idx];
    }
    return sum;
  }

  /// Lexicographic comparison: the first position where the elements differ (by
  /// operator<) decides; all-equivalent vectors are not less than each other.
  __host__ __device__ bool operator<(const FixedVector& bs) const
  {
#pragma nounroll
    for (unsigned int idx = 0; idx < N; ++idx)
    {
      if (data[idx] < bs.data[idx])
      {
        return true;
      }
      if (bs.data[idx] < data[idx])
      {
        return false;
      }
    }
    return false;
  }

  /// True when every pair of corresponding elements compares equal with operator==.
  __host__ __device__ bool operator==(const FixedVector& bs) const
  {
#pragma nounroll
    for (unsigned int idx = 0; idx < N; ++idx)
    {
      if (!(data[idx] == bs.data[idx]))
      {
        return false;
      }
    }
    return true;
  }
};

/// Key/value pair whose ordering (operator<, operator>) looks at the key only, while
/// equality (operator==, operator!=) compares both key and value.
template <typename Key, typename Value>
struct key_value
{
  using key_type   = Key;
  using value_type = Value;

  /// Value-initializes both members.
  __host__ __device__ key_value(void)
      : key()
      , value()
  {}

  __host__ __device__ key_value(key_type k, value_type v)
      : key(k)
      , value(v)
  {}

  __host__ __device__ bool operator<(const key_value& rhs) const
  {
    const bool key_less = key < rhs.key;
    return key_less;
  }

  __host__ __device__ bool operator>(const key_value& rhs) const
  {
    const bool key_greater = key > rhs.key;
    return key_greater;
  }

  __host__ __device__ bool operator==(const key_value& rhs) const
  {
    if (!(key == rhs.key))
    {
      return false;
    }
    return value == rhs.value;
  }

  __host__ __device__ bool operator!=(const key_value& rhs) const
  {
    return !(*this == rhs);
  }

  /// Prints the pair as "(key, value)".
  friend std::ostream& operator<<(std::ostream& os, const key_value& kv)
  {
    os << "(" << kv.key << ", " << kv.value << ")";
    return os;
  }

  key_type key;
  value_type value;
};

/// Scrambles a 32-bit integer with six add/xor/shift rounds; used below to derive a
/// per-index seed for the random engines. (It appears to be Robert Jenkins' 32-bit
/// integer hash.) All arithmetic is on unsigned int, so overflow wraps.
inline unsigned int hash(unsigned int a)
{
  unsigned int h = a;
  h = h + 0x7ed55d16u + (h << 12);
  h = h ^ 0xc761c23cu ^ (h >> 19);
  h = h + 0x165667b1u + (h << 5);
  h = (h + 0xd3a2646cu) ^ (h << 9);
  h = h + 0xfd7046c5u + (h << 3);
  h = h ^ 0xb55a4f09u ^ (h >> 16);
  return h;
}

template <typename T, typename = void>
struct generate_random_integer;

template <typename T>
struct generate_random_integer<
  T,
  typename thrust::detail::disable_if<thrust::detail::is_non_bool_arithmetic<T>::value>::type>
{
  T operator()(unsigned int i) const
  {
    thrust::default_random_engine rng(hash(i));

    return static_cast<T>(rng());
  }
};

template <typename T>
struct generate_random_integer<T, typename ::std::enable_if<thrust::detail::is_non_bool_integral<T>::value>::type>
{
  T operator()(unsigned int i) const
  {
    thrust::default_random_engine rng(hash(i));
    thrust::uniform_int_distribution<T> dist;

    return static_cast<T>(dist(rng));
  }
};

template <typename T>
struct generate_random_integer<T, typename ::std::enable_if<::std::is_floating_point<T>::value>::type>
{
  T operator()(unsigned int i) const
  {
    T const min = std::numeric_limits<T>::min();
    T const max = std::numeric_limits<T>::max();

    thrust::default_random_engine rng(hash(i));
    thrust::uniform_real_distribution<T> dist(min, max);

    return static_cast<T>(dist(rng));
  }
};

template <>
struct generate_random_integer<bool>
{
  bool operator()(unsigned int i) const
  {
    thrust::default_random_engine rng(hash(i));
    thrust::uniform_int_distribution<unsigned int> dist(0, 1);

    return dist(rng) == 1;
  }
};

template <typename T>
struct generate_random_sample
{
  T operator()(unsigned int i) const
  {
    thrust::default_random_engine rng(hash(i));
    thrust::uniform_int_distribution<unsigned int> dist(0, 20);

    return static_cast<T>(dist(rng));
  }
};

/// Returns N deterministic pseudo-random values of type T; element k is
/// generate_random_integer<T>()(k).
template <typename T>
thrust::host_vector<T> random_integers(const size_t N)
{
  using index_iterator = thrust::counting_iterator<size_t>;

  thrust::host_vector<T> result(N);
  thrust::transform(index_iterator(0), index_iterator(N), result.begin(), generate_random_integer<T>());
  return result;
}

template <typename T>
T random_integer()
{
  return generate_random_integer<T>()(0);
}

/// Returns N deterministic pseudo-random values in [0, 20] converted to T; element k is
/// generate_random_sample<T>()(k).
template <typename T>
thrust::host_vector<T> random_samples(const size_t N)
{
  using index_iterator = thrust::counting_iterator<size_t>;

  thrust::host_vector<T> result(N);
  thrust::transform(index_iterator(0), index_iterator(N), result.begin(), generate_random_sample<T>());
  return result;
}

// Use this with counting_iterator to avoid generating a range larger than we
// can represent.
// Non-floating-point T: returns n clamped to numeric_limits<T>::max(), converted to T.
template <typename T>
typename THRUST_NS_QUALIFIER::detail::disable_if<_THRUST_STD::is_floating_point<T>::value, T>::type
truncate_to_max_representable(std::size_t n)
{
  // _THRUST_STD::numeric_limits may resolve to rocprim::numeric_limits when
  // thrust/detail/type_traits.h uses rocprim::arithmetic.
  const auto largest = static_cast<std::size_t>(_THRUST_STD::numeric_limits<T>::max());
  const std::size_t clamped = THRUST_NS_QUALIFIER::min<std::size_t>(n, largest);
  return static_cast<T>(clamped);
}

// TODO: This probably won't work for `half`.
// Floating-point T: returns n converted to T, clamped to numeric_limits<T>::max().
template <typename T>
typename _THRUST_STD::enable_if_t<_THRUST_STD::is_floating_point<T>::value, T>
truncate_to_max_representable(std::size_t n)
{
  // _THRUST_STD::numeric_limits may resolve to rocprim::numeric_limits when
  // thrust/detail/type_traits.h uses rocprim::arithmetic.
  const T as_value = static_cast<T>(n);
  const T largest  = _THRUST_STD::numeric_limits<T>::max();
  return THRUST_NS_QUALIFIER::min<T>(as_value, largest);
}

/// Outcome of an "expect this expression to throw" check; consumed by
/// check_assert_throws(). The enumerator values are spelled out explicitly.
enum threw_status
{
  did_not_throw                    = 0,
  threw_wrong_type                 = 1,
  threw_right_type_but_wrong_value = 2,
  threw_right_type                 = 3
};

/// Turns the outcome `s` of an "expect a throw" check into a test failure.
/// Returns normally for threw_right_type. Otherwise throws a UnitTestFailure whose
/// message starts with "[file_name:line_number]" and names `exception_name`.
/// Declared `inline` because it is defined in a header: a plain definition would be
/// duplicated in every translation unit that includes this file.
inline void check_assert_throws(
  threw_status s, std::string const& exception_name, std::string const& file_name = "unknown", int line_number = -1)
{
  switch (s)
  {
    case did_not_throw: {
      UnitTestFailure f;
      f << "[" << file_name << ":" << line_number << "] did not throw anything";
      throw f;
    }
    case threw_wrong_type: {
      UnitTestFailure f;
      f << "[" << file_name << ":" << line_number << "] did not throw an "
        << "object of type " << exception_name;
      throw f;
    }
    case threw_right_type_but_wrong_value: {
      UnitTestFailure f;
      f << "[" << file_name << ":" << line_number << "] threw an object of the "
        << "correct type (" << exception_name << ") but wrong value";
      throw f;
    }
    case threw_right_type:
      break;
    default: {
      // Reached only if `s` holds a value outside the enumerators.
      UnitTestFailure f;
      f << "[" << file_name << ":" << line_number << "] encountered an "
        << "unknown error";
      throw f;
    }
  }
}

// Exercises the value-retrieval protocol of the future `f` with GoogleTest ASSERT_*
// macros (which return from this function on the first failure):
//  - get() may be called repeatedly; it returns the same value and leaves `f` ready,
//    with a valid stream and valid content;
//  - extract() hands the value out once, after which `f` is no longer ready or valid,
//    and a second extract() throws thrust::event_error.
// The extracted value is stored into `return_value`.
template <typename Future>
__host__ void test_future_value_retrieval(Future&& f, decltype(f.extract())& return_value)
{
  ASSERT_EQ(true, f.valid_stream());
  ASSERT_EQ(true, f.valid_content());

  // get() is non-destructive: two calls must agree and leave `f` usable.
  auto const r0 = f.get();
  auto const r1 = f.get();

  ASSERT_EQ(true, f.ready());
  ASSERT_EQ(true, f.valid_stream());
  ASSERT_EQ(true, f.valid_content());
  ASSERT_EQ(r0, r1);

  // extract() consumes the stored value; a second extract() must throw.
  auto const r2 = f.extract();

  ASSERT_THROW(auto x = f.extract(); // cppcheck-suppress unknownMacro
               THRUST_UNUSED_VAR(x), thrust::event_error);

  ASSERT_EQ(false, f.ready());
  ASSERT_EQ(false, f.valid_stream());
  ASSERT_EQ(false, f.valid_content());
  ASSERT_EQ(r2, r1);
  ASSERT_EQ(r2, r0);

  return_value = r2;
}

namespace
{
// Relative-error thresholds used by test_equality() for floating-point types.
// Non-associative operations (+, -, *) and type conversions may round differently
// depending on evaluation order. Each threshold is 2 / 2^mantissa_bits: it is doubled
// from 1 / 2^mantissa_bits because the tests compare the results of _two_ sequences of
// operations performed in different orders.
// Every other type (e.g. integer arithmetic) uses the default of 0, i.e. exact comparison.
template <class T>
constexpr float precision = 0.0f;

// double: 52 explicit mantissa bits.
template <>
constexpr float precision<double> = 2.0f / static_cast<float>(1ll << 52);

// float: 23 explicit mantissa bits.
template <>
constexpr float precision<float> = 2.0f / static_cast<float>(1ll << 23);

// half: 10 explicit mantissa bits.
template <>
constexpr float precision<rocprim::half> = 2.0f / static_cast<float>(1ll << 10);

// bfloat16: 7 explicit mantissa bits.
template <>
constexpr float precision<rocprim::bfloat16> = 2.0f / static_cast<float>(1ll << 7);

// const-qualified types share the threshold of the unqualified type.
template <class T>
constexpr float precision<const T> = precision<T>;
} // namespace

/// Floating-point overload: asserts (GoogleTest) that the device value `dvalue` matches
/// the host value `hvalue`.
/// Bitwise-identical values always match. Otherwise both values must be finite and
/// differ by at most ops * max(|precision<T> * hvalue|, precision<T>), i.e. a relative
/// tolerance, or an absolute one for small values. `ops` scales the tolerance with the
/// number of operations that produced the values; ops == 0 demands exact equality.
template <class T, typename std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
inline void test_equality(const T& hvalue, const T& dvalue, const size_t ops = 1)
{
  // Check bitwise equality for +NaN, -NaN, +0.0, -0.0, +inf, -inf.
  if (std::memcmp(&hvalue, &dvalue, sizeof(T)) == 0)
  {
    return;
  }

  // Non-finite values must not go through the tolerance check below. For an infinite
  // hvalue the tolerance itself becomes infinite and would accept any finite dvalue.
  // Require plain equality instead; NaN never compares equal, so a NaN mismatch fails.
  if (!std::isfinite(hvalue) || !std::isfinite(dvalue))
  {
    ASSERT_EQ(hvalue, dvalue);
    return;
  }

  // Check value difference based on precision threshold
  // relative difference or absolute difference with small values
  auto tolerance = double(ops) * std::max<T>(std::abs(T(precision<T>) * hvalue), T(precision<T>));
  ASSERT_NEAR(hvalue, dvalue, tolerance);
}

/// Integral overload: asserts (GoogleTest) that host and device values are exactly equal.
/// The operation count is ignored. It defaults to 1, like the floating-point overload,
/// so that generic code can call either overload with just two arguments.
template <class T, typename std::enable_if_t<std::is_integral<T>::value>* = nullptr>
inline void test_equality(const T& hvalue, const T& dvalue, const size_t /*ops*/ = 1)
{
  ASSERT_EQ(hvalue, dvalue);
}

// Compares a host result vector with a device result vector element by element.
// The sizes must match; each element pair is checked with the scalar test_equality()
// (exact for integral types, tolerance-based for floating-point types), using `ops` as
// the operation count for every element.
template <class T>
void test_equality(const thrust::host_vector<T>& hvalue, const thrust::device_vector<T>& dvalue, const size_t ops = 1)
{
  // Bring the device results to the host for the element-wise comparison.
  const thrust::host_vector<T> device_copy(dvalue);
  ASSERT_EQ(hvalue.size(), device_copy.size());
  for (size_t idx = 0; idx != hvalue.size(); ++idx)
  {
    test_equality(hvalue[idx], device_copy[idx], ops);
  }
}

// Like test_equality() for vectors, but intended for scan results: element `idx` is
// checked with an operation count of `idx`, so the tolerance grows along the sequence
// (element 0 must match exactly).
template <class T>
void test_equality_scan(const thrust::host_vector<T>& hvalue, const thrust::device_vector<T>& dvalue)
{
  const thrust::host_vector<T> device_copy(dvalue);
  ASSERT_EQ(hvalue.size(), device_copy.size());
  for (size_t idx = 0; idx != hvalue.size(); ++idx)
  {
    test_equality(hvalue[idx], device_copy[idx], idx);
  }
}

// Compares host and device vectors of pairs element by element. The sizes must match,
// and both `.first` and `.second` of each pair are checked with the scalar
// test_equality() (exact for integral types, tolerance-based for floating-point types),
// using `ops` as the operation count.
template <typename X, typename Y, template <typename, typename> class Pair>
void test_equality(
  const thrust::host_vector<Pair<X, Y>>& hvalue, const thrust::device_vector<Pair<X, Y>>& dvalue, const size_t ops = 1)
{
  const thrust::host_vector<Pair<X, Y>> device_copy(dvalue);
  ASSERT_EQ(hvalue.size(), device_copy.size());
  for (size_t idx = 0; idx != hvalue.size(); ++idx)
  {
    const auto& expected = hvalue[idx];
    const auto& actual   = device_copy[idx];
    test_equality(expected.first, actual.first, ops);
    test_equality(expected.second, actual.second, ops);
  }
}

// Pair-vector variant of test_equality_scan(): both members of element `idx` are
// checked with an operation count of `idx`.
template <typename X, typename Y, template <typename, typename> class Pair>
void test_equality_pair_scan(const thrust::host_vector<Pair<X, Y>>& hvalue,
                             const thrust::device_vector<Pair<X, Y>>& dvalue)
{
  const thrust::host_vector<Pair<X, Y>> device_copy(dvalue);
  ASSERT_EQ(hvalue.size(), device_copy.size());
  for (size_t idx = 0; idx != hvalue.size(); ++idx)
  {
    const auto& expected = hvalue[idx];
    const auto& actual   = device_copy[idx];
    test_equality(expected.first, actual.first, idx);
    test_equality(expected.second, actual.second, idx);
  }
}

// Backend-naming macros, chosen from the detected device compiler:
//  - THRUST_DEVICE_BACKEND expands to `hip` or `cuda`;
//  - THRUST_DEVICE_BACKEND_DETAIL expands to `hip_rocprim` or `cuda_cub`;
//  - SPECIALIZE_DEVICE_RESOURCE_NAME(name) token-pastes `hip` or `cuda` in front of `name`.
// If neither condition below holds, the macros are left undefined.
#if THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_HIP
#  define THRUST_DEVICE_BACKEND                 hip
#  define THRUST_DEVICE_BACKEND_DETAIL          hip_rocprim
#  define SPECIALIZE_DEVICE_RESOURCE_NAME(name) hip##name
#elif defined(__NVCC__) || defined(_NVHPC_CUDA)                                \
  || (defined(__CUDA__) && THRUST_HOST_COMPILER == THRUST_HOST_COMPILER_CLANG) \
  || THRUST_HOST_COMPILER == THRUST_HOST_COMPILER_NVRTC
#  define THRUST_DEVICE_BACKEND                 cuda
#  define THRUST_DEVICE_BACKEND_DETAIL          cuda_cub
#  define SPECIALIZE_DEVICE_RESOURCE_NAME(name) cuda##name
#endif
