/*
** A test wrapper for Homework 3, exercise 1.
** Rolf Riesen, 2008
*/
#include <stdio.h>
#include <string.h>     /* For memset() */
#include <mpi.h>

/* User-defined reduction under test; defined outside this file.
** Its parameter list matches MPI_User_function so that main() can
** register it with MPI_Op_create(). */
void selective_sum(void *invec, void *inoutvec, int *len, MPI_Datatype *dt);


/* Boolean values stored in the .flag member of the pair structs and
** passed as the "commute" argument of MPI_Op_create(). */
#define FALSE   0
#define TRUE    1

/* Number of value/flag pairs each rank contributes to every reduction. */
#define COUNT   10

/* Value/flag pair reduced with the MPI_2INT datatype in main(). */
struct intint_t {
    int value;      /* the number being combined */
    int flag;       /* TRUE or FALSE, as set by main() */
};
typedef struct intint_t intint_t;

/* Value/flag pair reduced with the MPI_FLOAT_INT datatype in main(). */
struct floatint_t {
    float value;    /* the number being combined */
    int flag;       /* TRUE or FALSE, as set by main() */
};
typedef struct floatint_t floatint_t;

/* Value/flag pair reduced with the MPI_DOUBLE_INT datatype in main(). */
struct doubleint_t {
    double value;   /* the number being combined */
    int flag;       /* TRUE or FALSE, as set by main() */
};
typedef struct doubleint_t doubleint_t;



int
main(int argc, char *argv[])
{

int my_rank, nproc;
int root;
int i;
MPI_Op ssum;

intint_t intint[COUNT];
floatint_t floatint[COUNT];
doubleint_t doubleint[COUNT];

intint_t res_intint[COUNT];
floatint_t res_floatint[COUNT];
doubleint_t res_doubleint[COUNT];


    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &nproc);
    MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);


    /* Register the function */
    MPI_Op_create(selective_sum, TRUE, &ssum);

    /* Initialize the local variables */
    for (i= 0; i < COUNT; i++)   {
        intint[i].value= i;
        intint[i].flag= TRUE;
        floatint[i].value= i;
        floatint[i].flag= TRUE;
        doubleint[i].value= i;
        doubleint[i].flag= TRUE;

        res_intint[i].value= 0;
        res_intint[i].flag= 0;
        res_floatint[i].value= 0;
        res_floatint[i].flag= 0;
        res_doubleint[i].value= 0;
        res_doubleint[i].flag= 0;
    }

    root= 0;
    MPI_Reduce(intint, res_intint, COUNT, MPI_2INT, ssum, root, MPI_COMM_WORLD);

    if (my_rank == root)   {
        printf("[%2d] Test 0: ", my_rank);
        for (i= 0; i < COUNT; i++)   {
            printf("  %4d", res_intint[i].value);
        }
        printf("\n");
    }

    /* Test 1 */
    memset(res_intint, 0, COUNT * sizeof(intint_t));
    for (i= 0; i < COUNT; i++)   {
        intint[i].value= 1;
        intint[i].flag= TRUE;
    }


    MPI_Reduce(intint, res_intint, COUNT, MPI_2INT, ssum, root, MPI_COMM_WORLD);

    if (my_rank == root)   {
        printf("[%2d] Test 1: ", my_rank);
        for (i= 0; i < COUNT; i++)   {
            printf("  %4d", res_intint[i].value);
        }
        printf("\n");
    }


    /* Test 2 */
    for (i= 1; i < COUNT; i= i + 2)   {
        floatint[i].flag= FALSE;
    }


    root= nproc - 1;
    MPI_Reduce(floatint, res_floatint, COUNT, MPI_FLOAT_INT, ssum, root, MPI_COMM_WORLD);

    if (my_rank == root)   {
        printf("[%2d] Test 2: ", my_rank);
        for (i= 0; i < COUNT; i++)   {
            printf("  %4.1f", res_floatint[i].value);
        }
        printf("\n");
    }


    /* Test 3 */
    if (my_rank < COUNT)   {
        doubleint[my_rank].flag= FALSE;
    }


    root= nproc / 2;
    MPI_Reduce(doubleint, res_doubleint, COUNT, MPI_DOUBLE_INT, ssum, root, MPI_COMM_WORLD);

    if (my_rank == root)   {
        printf("[%2d] Test 3: ", my_rank);
        for (i= 0; i < COUNT; i++)   {
            printf("  %4.1f", res_doubleint[i].value);
        }
        printf("\n");
    }


    /* Test 4 */
    if (my_rank == root)   {
        printf("What happens if I use an illegal type?\n");
        fflush(stdout);
    }
    MPI_Barrier(MPI_COMM_WORLD);
    MPI_Reduce(intint, res_intint, COUNT, MPI_LONG_INT, ssum, root, MPI_COMM_WORLD);


    MPI_Finalize();
    return 0;

}  /* end of main() */
