/*
 *  Copyright 2008-2013 NVIDIA Corporation
 *  Modifications Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved.
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

#include <thrust/functional.h>
#include <thrust/transform.h>

#include <algorithm>
#include <cstdint>
#include <functional>

#include "test_param_fixtures.hpp"
#include "test_utils.hpp"

// Element count of every randomly generated operand vector and of each result vector.
constexpr size_t NUM_SAMPLES = 10000;

template <class InputVector, class OutputVector, class Operator, class ReferenceOperator>
void TestBinaryFunctional()
{
  using InputType  = typename InputVector::value_type;
  using OutputType = typename OutputVector::value_type;

  thrust::host_vector<InputType> std_input1 = random_samples<InputType>(NUM_SAMPLES);
  thrust::host_vector<InputType> std_input2 = random_samples<InputType>(NUM_SAMPLES);
  thrust::host_vector<OutputType> std_output(NUM_SAMPLES);

  InputVector input1 = std_input1;
  InputVector input2 = std_input2;
  OutputVector output(NUM_SAMPLES);

  thrust::transform(input1.begin(), input1.end(), input2.begin(), output.begin(), Operator());
  thrust::transform(std_input1.begin(), std_input1.end(), std_input2.begin(), std_output.begin(), ReferenceOperator());

  // Note: FP division is not bit-equal, even when nvcc is invoked with --prec-div
  ASSERT_EQ(output, std_output);
}

// XXX add bool to list
// Expands `Expand(container_name, functor_name, T)` once for each fixed-width integer type T,
// in the order int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, int64_t, uint64_t.
// The first two arguments are forwarded to `Expand` untouched.
// clang-format off
#define INSTANTIATE_INTEGER_TYPES(Expand, container_name, functor_name) \
  Expand(container_name, functor_name, int8_t)                          \
  Expand(container_name, functor_name, uint8_t)                         \
  Expand(container_name, functor_name, int16_t)                         \
  Expand(container_name, functor_name, uint16_t)                        \
  Expand(container_name, functor_name, int32_t)                         \
  Expand(container_name, functor_name, uint32_t)                        \
  Expand(container_name, functor_name, int64_t)                         \
  Expand(container_name, functor_name, uint64_t)
// clang-format on

// Expands `Expand(container_name, functor_name, T)` for every type listed by
// INSTANTIATE_INTEGER_TYPES, followed by one more expansion with T = float.
#define INSTANTIATE_ALL_TYPES(Expand, container_name, functor_name) \
  INSTANTIATE_INTEGER_TYPES(Expand, container_name, functor_name)   \
  Expand(container_name, functor_name, float)

// XXX revert OutputVector<T> back to bool
// op(T,T) -> bool
// Runs TestBinaryFunctional for a single element type, checking
// thrust::functor_name<element_type> against std::functor_name<element_type>.
// Operands and results are both stored in thrust::container_name<element_type>.
#define INSTANTIATE_BINARY_LOGICAL_FUNCTIONAL_TEST(container_name, functor_name, element_type) \
  TestBinaryFunctional<thrust::container_name<element_type>,                                   \
                       thrust::container_name<element_type>,                                   \
                       thrust::functor_name<element_type>,                                     \
                       std::functor_name<element_type>>();

// op(T,T) -> bool
// Defines two gtest cases in the FunctionalLogicalTests suite for one functor:
//   Test<TestLabel>FunctionalHost   - runs the check with thrust::host_vector containers
//   Test<TestLabel>FunctionalDevice - runs the check with thrust::device_vector containers
// Each case records the value returned by test::set_device_from_ctest() in a scoped trace and
// then runs the check for every element type listed by INSTANTIATE_ALL_TYPES.
#define DECLARE_BINARY_LOGICAL_FUNCTIONAL_UNITTEST(functor_name, TestLabel)                         \
  TEST(FunctionalLogicalTests, Test##TestLabel##FunctionalHost)                                     \
  {                                                                                                 \
    SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());        \
    INSTANTIATE_ALL_TYPES(INSTANTIATE_BINARY_LOGICAL_FUNCTIONAL_TEST, host_vector, functor_name);   \
  }                                                                                                 \
  TEST(FunctionalLogicalTests, Test##TestLabel##FunctionalDevice)                                   \
  {                                                                                                 \
    SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());        \
    INSTANTIATE_ALL_TYPES(INSTANTIATE_BINARY_LOGICAL_FUNCTIONAL_TEST, device_vector, functor_name); \
  }

// Create the unit tests.
// Each line defines a Host and a Device test case that checks thrust::<functor> against
// std::<functor> for every element type listed by INSTANTIATE_ALL_TYPES.

// Comparison functors
DECLARE_BINARY_LOGICAL_FUNCTIONAL_UNITTEST(equal_to, EqualTo);
DECLARE_BINARY_LOGICAL_FUNCTIONAL_UNITTEST(not_equal_to, NotEqualTo);
DECLARE_BINARY_LOGICAL_FUNCTIONAL_UNITTEST(greater, Greater);
DECLARE_BINARY_LOGICAL_FUNCTIONAL_UNITTEST(less, Less);
DECLARE_BINARY_LOGICAL_FUNCTIONAL_UNITTEST(greater_equal, GreaterEqual);
DECLARE_BINARY_LOGICAL_FUNCTIONAL_UNITTEST(less_equal, LessEqual);
// Logical functors
DECLARE_BINARY_LOGICAL_FUNCTIONAL_UNITTEST(logical_and, LogicalAnd);
DECLARE_BINARY_LOGICAL_FUNCTIONAL_UNITTEST(logical_or, LogicalOr);
