/*******************************************************************************
* Copyright 2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/
/*
*
*  Content:
*       This example demonstrates use of the oneapi::mkl::lapack::getrfnp_batch
*       group API to perform batched LU factorization without pivoting on a
*       SYCL device (CPU, GPU).
*
*
*       The supported floating point data types for matrix data are:
*           float
*           double
*           std::complex<float>
*           std::complex<double>
*
*
*******************************************************************************/

#include <CL/sycl.hpp>
#include "oneapi/mkl.hpp"
#include "common_for_examples.hpp"

//
// Main example for batched LU group API, consisting of initialization of
// multiple groups of general dense A matrices. The matrices in the batch
// are designated into groups, where each matrix in a given group has the
// same matrix dimensions (m, n, lda).
// This example demonstrates a way to set up the groups of matrices on the host and
// transfer the data to device memory to perform the batched LU operation:
// A = L * U  (getrfnp performs no pivoting, so there is no permutation matrix P)
//
template <typename data_t>
void run_getrfnp_batch_group_example(const sycl::device &dev) {
    // Overall status of the run: 0 on success, otherwise the LAPACK info value
    // or -1. Written by the asynchronous handler and by the synchronous catch
    // blocks below. exec_queue is declared after `info`, so it is destroyed
    // first and the by-reference capture in error_handler cannot dangle.
    std::int64_t info = 0;

    // Asynchronous error handler: reports (and records in `info`) exceptions
    // delivered to the queue, e.g. from wait_and_throw().
    auto error_handler = [&] (sycl::exception_list exceptions) {
        for (auto const& eptr : exceptions) {
            try {
                std::rethrow_exception(eptr);
            } catch(oneapi::mkl::lapack::exception const &e) {
                // Handle LAPACK related exceptions happened during asynchronous call
                info = e.info();
                std::cout << "Unexpected exception caught during asynchronous LAPACK operation:\n" << e.what() << "\ninfo: " << e.info() << std::endl;
            } catch(sycl::exception const &e) {
                // Handle not LAPACK related exceptions happened during asynchronous call
                std::cout << "Unexpected exception caught during asynchronous operation:\n" << e.what() << std::endl;
                info = -1;
            }
        }
    };

    sycl::queue exec_queue(dev, error_handler);

    // Number of groups in the batch
    static constexpr int64_t GROUP_COUNT = 3;

    // Number of matrices in each group
    int64_t group_size[GROUP_COUNT] = {1, 4, 8};

    // Dimensions of matrices in each group. Each matrix is stored in
    // lda * n elements (LAPACK requires lda >= max(1, m)).
    int64_t m[GROUP_COUNT] = {32, 48, 64};
    int64_t n[GROUP_COUNT] = {48, 48, 48};
    int64_t lda[GROUP_COUNT] = {32, 50, 72};

    // batch_size: total number of matrices across all groups.
    // size_a:     total number of elements when all matrices are stored
    //             back-to-back in one allocation.
    int64_t batch_size = 0;
    int64_t size_a = 0;
    std::cout << "\tgetrfnp_batch group details: " << std::endl;
    std::cout << "\t---------------------------- " << std::endl;
    std::cout << "\tgroup_count: " << GROUP_COUNT << std::endl;
    for (int64_t igrp = 0; igrp < GROUP_COUNT; igrp++) {
        batch_size += group_size[igrp];
        size_a += group_size[igrp] * lda[igrp] * n[igrp];
        std::cout << "\tgroup "<< igrp << ":   m=" << m[igrp] << " n=" << n[igrp]
            << " lda=" << lda[igrp] << " group_size=" << group_size[igrp] << std::endl;
    }
    std::cout << "\t---------------------------- " << std::endl;

    // Set up host data. The running offset is int64_t so that it cannot
    // overflow when the batch grows beyond the range of int.
    data_t *A_h = sycl::malloc_host<data_t>(size_a, exec_queue);
    for (int64_t igrp = 0, off = 0; igrp < GROUP_COUNT; igrp++)
        for (int64_t j = 0; j < group_size[igrp]; j++, off += lda[igrp] * n[igrp])
            rand_getrfnp_matrix(&A_h[off], m[igrp], n[igrp], lda[igrp]);

    // Transfer host data to the device
    data_t *A_d = sycl::malloc_device<data_t>(size_a, exec_queue);
    exec_queue.copy(A_h, A_d, size_a).wait();

    // Set up array of pointers (one per matrix, pointing into A_d) in a
    // temporary host-side array
    data_t **temp_ptrs = sycl::malloc_host<data_t*>(batch_size, exec_queue);
    data_t *temp_d = A_d;
    for (int64_t igrp = 0, idx = 0; igrp < GROUP_COUNT; igrp++)
        for (int64_t j = 0; j < group_size[igrp]; j++, idx++, temp_d += lda[igrp] * n[igrp])
            temp_ptrs[idx] = temp_d;

    // Transfer array of pointers to the device
    data_t **A_ptrs_d = sycl::malloc_device<data_t*>(batch_size, exec_queue);
    exec_queue.copy(temp_ptrs, A_ptrs_d, batch_size).wait();

    // Call oneapi::mkl::lapack::getrfnp_batch API.
    // The scratchpad pointer starts as nullptr so the cleanup below is
    // well-defined even if an exception is thrown before it is allocated.
    data_t *getrfnp_batch_scratchpad = nullptr;
    try {
        int64_t getrfnp_batch_scratchpad_size =
            oneapi::mkl::lapack::getrfnp_batch_scratchpad_size<data_t>(exec_queue, m, n, lda, GROUP_COUNT, group_size);

        getrfnp_batch_scratchpad = sycl::malloc_device<data_t>(getrfnp_batch_scratchpad_size, exec_queue);

        oneapi::mkl::lapack::getrfnp_batch(exec_queue, m, n, A_ptrs_d, lda, GROUP_COUNT, group_size,
                getrfnp_batch_scratchpad, getrfnp_batch_scratchpad_size).wait_and_throw();
    } catch(oneapi::mkl::lapack::exception const &e) {
        // Handle LAPACK-specific exceptions that happened during synchronous call
        std::cout << "Unexpected exception caught during synchronous call to LAPACK API:\nreason: " << e.what()
            << "\ninfo: " << e.info()
            << "\ndetail: " << e.detail() << std::endl;
        info = e.info();
    } catch(std::exception const &e) {
        // Handle other related exceptions that happened during synchronous call
        std::cout << "Unexpected exception caught during synchronous call to SYCL API:\n" << e.what() << std::endl;
        info = -1;
    }

    // Bring factorized data back to host only if the factorization succeeded
    if (info == 0)
        exec_queue.copy(A_d, A_h, size_a).wait();

    // Check that the computations completed successfully
    std::cout << "getrfnp_batch " << ((info == 0) ? "ran OK" : "FAILED") << std::endl;

    // Clean up (freeing nullptr is a no-op for USM free)
    sycl::free(A_h, exec_queue);
    sycl::free(A_d, exec_queue);
    sycl::free(A_ptrs_d, exec_queue);
    sycl::free(temp_ptrs, exec_queue);
    sycl::free(getrfnp_batch_scratchpad, exec_queue);
}


//
// Description of example setup, apis used and supported floating point type precisions
//
void print_example_banner() {
    // Banner text, one entry per output line. The leading and trailing empty
    // entries produce the blank lines that frame the banner.
    static const char *const banner_lines[] = {
        "",
        "########################################################################",
        "# Batched LU Factorization Example: ",
        "# ",
        "# Computes LU Factorization A = L * U (without pivoting)",
        "# for multiple groups of general dense matrices A.",
        "# ",
        "# Using apis:",
        "#   getrfnp_batch (group API)",
        "# ",
        "# Supported floating point type precisions:",
        "#   float",
        "#   double",
        "#   std::complex<float>",
        "#   std::complex<double>",
        "# ",
        "########################################################################",
        "",
    };

    // Emit each line followed by std::endl (newline + flush).
    for (const char *text : banner_lines)
        std::cout << text << std::endl;
}


//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU device
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU device
// -DSYCL_DEVICES_all (default) -- runs on all: CPU and GPU devices
//
//  For each selected device, the getrfnp_batch group example runs with every
//  supported data type (double-precision types only if the device supports them).
//
int main(int argc, char **argv) {
    // Prints the banner, then runs the getrfnp_batch group example for every
    // requested device type that is present, for each supported data type.

    print_example_banner();

    // Find list of devices
    std::list<my_sycl_device_types> listOfDevices;
    set_list_of_devices(listOfDevices);

    for(auto &deviceType : listOfDevices) {
        sycl::device myDev;
        bool myDevIsFound = false;
        get_sycl_device(myDev, myDevIsFound, deviceType);

        if(myDevIsFound) {
            std::cout << std::endl << "Running getrfnp_batch examples on " << sycl_device_names[deviceType] << "." << std::endl;

            std::cout << "Running with single precision real data type:" << std::endl;
            run_getrfnp_batch_group_example<float>(myDev);

            std::cout << "Running with single precision complex data type:" << std::endl;
            run_getrfnp_batch_group_example<std::complex<float>>(myDev);

            // Double-precision types are only run on devices that support them.
            if (isDoubleSupported(myDev)) {
                std::cout << "Running with double precision real data type:" << std::endl;
                run_getrfnp_batch_group_example<double>(myDev);

                std::cout << "Running with double precision complex data type:" << std::endl;
                run_getrfnp_batch_group_example<std::complex<double>>(myDev);
            } else {
                std::cout << "Double precision not supported on this device " << std::endl;
                std::cout << std::endl;
            }

        } else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[deviceType] << " devices found; Fail on missing devices is enabled.\n";
            // Release oneMKL's internal buffers on this early exit as well,
            // matching the normal exit path below.
            mkl_free_buffers();
            return 1;
#else
            std::cout << "No " << sycl_device_names[deviceType] << " devices found; skipping " << sycl_device_names[deviceType] << " tests.\n";
#endif
        }
    }
    // Release memory buffers held by oneMKL's internal allocator before exiting.
    mkl_free_buffers();
    return 0;
}
