/*******************************************************************************
* Copyright 2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*   Content : Intel(R) oneAPI Math Kernel Library (oneMKL) Sparse BLAS C OpenMP
*             offload example for mkl_sparse_order with async execution.
*
********************************************************************************
*
* Consider the matrix A (see 'Sparse Storage Formats for Sparse BLAS Level 2
* and Level 3' in the Intel oneMKL Reference Manual)
*
*                 |   1       -1      0   -3     0   |
*                 |  -2        5      0    0     0   |
*   A    =        |   0        0      4    6     4   |,
*                 |  -4        0      2    7     0   |
*                 |   0        8      0    0    -5   |
*
*  The matrix A is represented in a one-based compressed sparse row (CSR) storage
*  scheme with three arrays (see 'Sparse Matrix Storage Schemes' in the
*  Intel oneMKL Reference Manual) with arrays sorted by the column
*  indices over each row as follows:
*
*         values    = ( 1 -1 -3 -2  5  4  6  4 -4  2  7  8 -5 )
*         columns   = ( 1  2  4  1  2  3  4  5  1  3  4  2  5 )
*         row_index = ( 1        4     6        9       12    14 )
*
*  Two of the many unsorted CSR representations of the above matrix may be as
*  follows:
*
*         values    = (-1  1 -3  5 -2  4  4  6  2  7 -4  8 -5 )
*         columns   = ( 2  1  4  2  1  5  3  4  3  4  1  2  5 )
*         row_index = ( 1        4     6        9       12    14 )
*
*  and:
*
*         values    = (-3 -1  1 -2  5  6  4  4  2 -4  7 -5  8 )
*         columns   = ( 4  2  1  1  2  4  5  3  3  1  4  5  2 )
*         row_index = ( 1        4     6        9       12    14 )
*
*  Note that `columns` and `values` arrays above are not sorted over each row;
*  the CSR format by definition has its `row_index` array sorted.
*
*  This test performs the following operations:
*
*       Sort the `columns` and `values` arrays of two CSR matrices using
*       mkl_sparse_order omp offload with async execution.
*
*       In Sparse BLAS, we use the inspector/executor paradigm and the execution
*       stage can be performed asynchronously provided that any data dependencies
*       between execution tasks are also respected. However, the creation
*       of the sparse matrix handle, the inspection/analysis stage and the
*       destruction of the matrix handle must be done synchronously. Additionally,
*       if the execution tasks are performed asynchronously, a
*       "#pragma omp taskwait" must be added before the call to mkl_sparse_destroy
*       to ensure all execution tasks are completed before destroying the handle.
*       Adding the nowait clause to the mkl_sparse_?_create_csr() or
*       mkl_sparse_destroy() calls may result in incorrect behavior.
*
********************************************************************************
*/
#include <assert.h>
#include <omp.h>
#include <stdio.h>
#include <stdlib.h>

#include "common_for_sparse_examples.h"
#include "mkl.h"
#include "mkl_omp_offload.h"

/* printf conversion specifications used when dumping MKL_INT and double
 * arrays: a minimum field width of 3 together with the ' ' flag, so that
 * non-negative values get a leading blank in the column a '-' would occupy
 * and printed rows line up. Each specification ends with a separating space.
 * With MKL_ILP64 the integer specification uses the long long length
 * modifier (assumes MKL_INT is long long-sized in that configuration). */
#if defined(MKL_ILP64)
#  define INT_FORMAT "% 3lld "
#else
#  define INT_FORMAT "% 3d "
#endif

#define FP_FORMAT "% 3.1f "

int main()
{
//*******************************************************************************
//     Declaration and initialization of parameters for sparse representation of
//     the matrix A in the compressed sparse row format:
//*******************************************************************************
#define M 5
#define N 5
#define NNZ 13

    // // Structures with sparse matrices stored in CSR format
    sparse_matrix_t csr, csrA, csrB;
    //*******************************************************************************
    //    Declaration of local variables:
    //*******************************************************************************
    MKL_INT i;

    double *values_sorted     = (double *)mkl_malloc(sizeof(double) * NNZ, 64);
    MKL_INT *columns_sorted   = (MKL_INT *)mkl_malloc(sizeof(MKL_INT) * NNZ, 64);
    MKL_INT *row_index_sorted = (MKL_INT *)mkl_malloc(sizeof(MKL_INT) * (N + 1), 64);

    double *values_A     = (double *)mkl_malloc(sizeof(double) * NNZ, 64);
    MKL_INT *columns_A   = (MKL_INT *)mkl_malloc(sizeof(MKL_INT) * NNZ, 64);
    MKL_INT *row_index_A = (MKL_INT *)mkl_malloc(sizeof(MKL_INT) * (N + 1), 64);

    double *values_B     = (double *)mkl_malloc(sizeof(double) * NNZ, 64);
    MKL_INT *columns_B   = (MKL_INT *)mkl_malloc(sizeof(MKL_INT) * NNZ, 64);
    MKL_INT *row_index_B = (MKL_INT *)mkl_malloc(sizeof(MKL_INT) * (N + 1), 64);

    const int num_pointers = 9;
    void *pointer_array[num_pointers];
    pointer_array[0] = values_sorted;
    pointer_array[1] = columns_sorted;
    pointer_array[2] = row_index_sorted;
    pointer_array[3] = values_A;
    pointer_array[4] = columns_A;
    pointer_array[5] = row_index_A;
    pointer_array[6] = values_B;
    pointer_array[7] = columns_B;
    pointer_array[8] = row_index_B;

    if (!values_sorted || !columns_sorted || !row_index_sorted ||
        !values_A      || !columns_A      || !row_index_A      ||
        !values_B      || !columns_B      || !row_index_B
       ) {
        free_allocated_memories(pointer_array, num_pointers);
        return 1;
    }

    //*******************************************************************************
    //    Sparse representation of the sorted matrix and unsorted A and B matrices
    //*******************************************************************************
    {
        double init_values_sorted[NNZ]       = { 1,-1,-3,-2, 5, 4, 6, 4,-4, 2, 7, 8,-5};
        MKL_INT init_columns_sorted[NNZ]     = { 1, 2, 4, 1, 2, 3, 4, 5, 1, 3, 4, 2, 5};
        MKL_INT init_row_index_sorted[N + 1] = { 1,       4,    6,       9,      12,   14};
        double init_values_A[NNZ]            = {-1, 1,-3, 5,-2, 4, 4, 6, 2, 7,-4, 8,-5};
        MKL_INT init_columns_A[NNZ]          = { 2, 1, 4, 2, 1, 5, 3, 4, 3, 4, 1, 2, 5};
        MKL_INT init_row_index_A[N + 1]      = { 1,       4,    6,       9,      12,   14};
        double init_values_B[NNZ]            = {-3,-1, 1,-2, 5, 6, 4, 4, 2,-4, 7,-5, 8};
        MKL_INT init_columns_B[NNZ]          = { 4, 2, 1, 1, 2, 4, 5, 3, 3, 1, 4, 5, 2};
        MKL_INT init_row_index_B[N + 1]      = { 1,       4,    6,       9,      12,   14};

        for (i = 0; i < NNZ; i++) {
            values_sorted[i]  = init_values_sorted[i];
            columns_sorted[i] = init_columns_sorted[i];
            values_A[i]  = init_values_A[i];
            columns_A[i] = init_columns_A[i];
            values_B[i]  = init_values_B[i];
            columns_B[i] = init_columns_B[i];
        }
        for (i = 0; i < N + 1; i++) {
            row_index_sorted[i] = init_row_index_sorted[i];
            row_index_A[i] = init_row_index_A[i];
            row_index_B[i] = init_row_index_B[i];
        }

    }

    printf("\n EXAMPLE PROGRAM FOR mkl_sparse_order omp_offload async\n");
    printf("---------------------------------------------------\n");
    printf("\n");
    printf("   INPUT DATA FOR mkl_sparse_order omp offload async\n");
    printf("   WITH GENERAL SPARSE MATRICES\n");
    printf("   SORTED csr:\n");
    printf("   row_index_sorted:\n        ");
    for (i = 0; i < N+1; i++) printf(INT_FORMAT, row_index_sorted[i]);
    printf("\n   columns_sorted:\n        ");
    for (i = 0; i < NNZ; i++) printf(INT_FORMAT, columns_sorted[i]);
    printf("\n   values_sorted:\n        ");
    for (i = 0; i < NNZ; i++) printf(FP_FORMAT, values_sorted[i]);
    printf("\n   UNSORTED csrA:                 \n");
    printf("   row_index_A:\n        ");
    for (i = 0; i < N+1; i++) printf(INT_FORMAT, row_index_A[i]);
    printf("\n   columns_A:\n        ");
    for (i = 0; i < NNZ; i++) printf(INT_FORMAT, columns_A[i]);
    printf("\n   values_A:\n        ");
    for (i = 0; i < NNZ; i++) printf(FP_FORMAT, values_A[i]);
    printf("\n   UNSORTED csrB:                 \n");
    printf("   row_index_B:\n        ");
    for (i = 0; i < N+1; i++) printf(INT_FORMAT, row_index_B[i]);
    printf("\n   columns_B:\n        ");
    for (i = 0; i < NNZ; i++) printf(INT_FORMAT, columns_B[i]);
    printf("\n   values_B:\n        ");
    for (i = 0; i < NNZ; i++) printf(FP_FORMAT, values_B[i]);
    printf("\n---------------------------------------------------\n");
    fflush(stdout);

    sparse_status_t ie_status;

    // Create handle with matrix stored in CSR format
    printf("Create CSR matrix\n");
    ie_status = mkl_sparse_d_create_csr(&csr, SPARSE_INDEX_BASE_ONE,
                                        N, // number of rows
                                        M, // number of cols
                                        row_index_sorted, row_index_sorted + 1,
                                        columns_sorted, values_sorted);
    if (ie_status != SPARSE_STATUS_SUCCESS) {
        printf(" Error in mkl_sparse_d_create_csr: %d\n", ie_status);
        free_allocated_memories(pointer_array, num_pointers);
        return ie_status;
    }

    // Sort column indices and values
    printf("Call mkl_sparse_order()\n");
    ie_status = mkl_sparse_order(csr);
    if (ie_status != SPARSE_STATUS_SUCCESS) {
        printf(" Error in mkl_sparse_order: %d\n", ie_status);
        free_allocated_memories(pointer_array, num_pointers);
        return ie_status;
    }

    // Release matrix handle and deallocate matrix
    printf("Destroy csr matrix\n");
    ie_status = mkl_sparse_destroy(csr);
    if (ie_status != SPARSE_STATUS_SUCCESS) {
        printf(" Error in mkl_sparse_destroy: %d\n", ie_status);
        free_allocated_memories(pointer_array, num_pointers);
        return ie_status;
    }

    printf("                                   \n");
    printf("   OUTPUT DATA FOR mkl_sparse_order()\n");
    printf("   row_index_sorted:\n        ");
    for (i = 0; i < N+1; i++) printf(INT_FORMAT, row_index_sorted[i]);
    printf("\n   columns_sorted:\n        ");
    for (i = 0; i < NNZ; i++) printf(INT_FORMAT, columns_sorted[i]);
    printf("\n   values_sorted:\n        ");
    for (i = 0; i < NNZ; i++) printf(FP_FORMAT, values_sorted[i]);
    printf("\n---------------------------------------------------\n");
    fflush(stdout);

    const int devNum = 0;

#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
    printf("Using OpenMP 5.1. dispatch construct.\n"); fflush(0);
#else
    printf("Using OpenMP offload 'target variant dispatch' construct.\n"); fflush(0);
#endif

    sparse_matrix_t csrA_gpu, csrB_gpu;

    sparse_status_t status_create_A, status_create_B;
    sparse_status_t status_order_A, status_order_B;
    sparse_status_t status_destroy_A, status_destroy_B;

// call create_csr/order/destroy via omp_offload.
#pragma omp target data map(to: row_index_A[0:N+1], row_index_B[0:N+1]) \
                        map(tofrom: columns_A[0:NNZ], values_A[0:NNZ],  \
                                    columns_B[0:NNZ], values_B[0:NNZ])  \
                        device(devNum)
    {
        printf("Create CSR matrices via omp_offload\n");
        fflush(stdout);

#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(devNum)
#else
#pragma omp target variant dispatch device(devNum) use_device_ptr(row_index_A, columns_A, values_A)
#endif
        status_create_A = mkl_sparse_d_create_csr(&csrA_gpu, SPARSE_INDEX_BASE_ONE, N, M,
                                                  row_index_A, row_index_A + 1, columns_A, values_A);

#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(devNum)
#else
#pragma omp target variant dispatch device(devNum) use_device_ptr(row_index_B, columns_B, values_B)
#endif
        status_create_B = mkl_sparse_d_create_csr(&csrB_gpu, SPARSE_INDEX_BASE_ONE, N, M,
                                                  row_index_B, row_index_B + 1, columns_B, values_B);

        printf("Compute mkl_sparse_order via omp_offload\n");
        fflush(stdout);

#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(devNum) nowait
#else
#pragma omp target variant dispatch device(devNum) nowait
#endif
        status_order_A = mkl_sparse_order(csrA_gpu);

#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(devNum) nowait
#else
#pragma omp target variant dispatch device(devNum) nowait
#endif
        status_order_B = mkl_sparse_order(csrB_gpu);

#pragma omp taskwait

        printf("Destroy the CSR matrices via omp_offload\n");
        fflush(stdout);

#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(devNum)
#else
#pragma omp target variant dispatch device(devNum)
#endif
        status_destroy_A = mkl_sparse_destroy(csrA_gpu);

#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(devNum)
#else
#pragma omp target variant dispatch device(devNum)
#endif
        status_destroy_B = mkl_sparse_destroy(csrB_gpu);
    }

    int flps_per_value = 0;
    int status1 = 0, status2 = 0, status3 = 0;
    int status4 = 0, status5 = 0, status6 = 0;

    int status_offload = status_create_A  | status_create_B |
                         status_order_A   | status_order_B  |
                         status_destroy_A | status_destroy_B;
    if (status_offload != 0) {
        printf("\tERROR: status_create_A = %d, status_create_B = %d, "
                        "status_order_A = %d, status_order_B = %d, "
                        "status_destroy_A = %d\n, status_destroy_B = %d",
               status_create_A,  status_create_B,
               status_order_A,   status_order_B,
               status_destroy_A, status_destroy_B);
        goto cleanup;
    }

    printf("   OUTPUT DATA FOR csrA mkl_sparse_order_omp_offload async execution.\n");
    printf("   sorted row_index_A:\n        ");
    for (i = 0; i < N+1; i++) printf(INT_FORMAT, row_index_A[i]);
    printf("\n   sorted columns_A:\n        ");
    for (i = 0; i < NNZ; i++) printf(INT_FORMAT, columns_A[i]);
    printf("\n   sorted values_A:\n        ");
    for (i = 0; i < NNZ; i++) printf(FP_FORMAT, values_A[i]);
    printf("\n   OUTPUT DATA FOR csrB mkl_sparse_order_omp_offload async execution.\n");
    printf("   sorted row_index_B:\n        ");
    for (i = 0; i < N+1; i++) printf(INT_FORMAT, row_index_B[i]);
    printf("\n   sorted columns_B:\n        ");
    for (i = 0; i < NNZ; i++) printf(INT_FORMAT, columns_B[i]);
    printf("\n   sorted values_B:\n        ");
    for (i = 0; i < NNZ; i++) printf(FP_FORMAT, values_B[i]);
    printf("\n---------------------------------------------------\n");
    fflush(stdout);

    status1 = validation_result_integer(row_index_sorted, row_index_A, N+1);
    status2 = validation_result_integer(row_index_sorted, row_index_B, N+1);
    status3 = validation_result_integer(columns_sorted, columns_A, NNZ);
    status4 = validation_result_integer(columns_sorted, columns_B, NNZ);
    status5 = validation_result_double(values_sorted, values_A, NNZ, flps_per_value);
    status6 = validation_result_double(values_sorted, values_B, NNZ, flps_per_value);

cleanup:
    free_allocated_memories(pointer_array, num_pointers);

    const int status_all = status1 | status2 | status3 |
                           status4 | status5 | status6 |
                           status_offload;
    printf("Test %s\n", status_all == 0 ? "PASSED" : "FAILED");
    fflush(stdout);

    return status_all;
}
