/*
** A reduce example with a user-defined function
** Copyright Rolf Riesen 2008
**
*/
#include <stdio.h>
#include <mpi.h>

#define TRUE	(1)

/*
** Local functions
*/
void my_func(void *invec, void *inoutvec, int *len, MPI_Datatype *dt);


/*
** Entry point: sum every rank's number onto the root rank twice -- once as
** MPI_INT and once as MPI_FLOAT -- using the user-defined operator my_func.
** Every rank prints the contents of its receive buffers; the root rank then
** prints the value the reductions should have produced, computed serially
** as 0 + 1 + ... + (nproc - 1).
*/
int
main(int argc, char *argv[])
{
    const int root= 0;      /* rank that receives the reduced values */
    const int count= 1;     /* number of elements each rank contributes */
    int rank;
    int nranks;
    int ival;
    int isum= 0;
    int iexpected= 0;
    float fval;
    float fsum= 0.0f;
    float fexpected= 0.0f;
    int k;
    MPI_Op sum_op;

    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &nranks);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    /* Each rank's contribution is simply its own rank number */
    ival= rank;
    fval= (float)rank;

    /* Register my_func as a reduction operator; TRUE marks it commutative */
    MPI_Op_create(my_func, TRUE, &sum_op);

    /*
    ** Per MPI_Reduce semantics the receive buffers are only significant on
    ** the root; they were zero-initialized above so that every rank prints
    ** a well-defined value below.
    */
    MPI_Reduce(&ival, &isum, count, MPI_INT, sum_op, root,
            MPI_COMM_WORLD);
    MPI_Reduce(&fval, &fsum, count, MPI_FLOAT, sum_op, root,
            MPI_COMM_WORLD);

    printf("Node %3d: My result is %d and %f\n", rank, isum, fsum);

    if (rank == root)   {
        /* Recompute both sums serially, accumulating in ascending order */
        for (k= 0; k < nranks; k++)   {
            iexpected += k;
            fexpected += k;
        }
        printf("Expected int   result on node %d is %d.\n", root, iexpected);
        printf("Expected float result on node %d is %f.\n", root, fexpected);
    }

    MPI_Op_free(&sum_op);
    MPI_Finalize();
    return 0;

}  /* end of main() */



/*
** User-defined reduction operator: element-wise sum.
**
** MPI invokes this with two buffers holding *len elements of datatype *dt
** and expects the combined result to be left in inoutvec, i.e.
**     inoutvec[i] = invec[i] + inoutvec[i]   for 0 <= i < *len.
**
** Supported datatypes: MPI_INT, MPI_FLOAT and MPI_DOUBLE.  Any other
** datatype is a programming error: a diagnostic is written to stderr
** (so the failure is not silent) and the whole job is aborted.
*/
void
my_func(void *invec, void *inoutvec, int *len, MPI_Datatype *dt)
{
    int i;

    if (*dt == MPI_INT)   {
        const int *in= invec;
        int *inout= inoutvec;

        for (i= 0; i < *len; i++)   {
            inout[i] += in[i];
        }
    } else if (*dt == MPI_FLOAT)   {
        const float *in= invec;
        float *inout= inoutvec;

        for (i= 0; i < *len; i++)   {
            inout[i] += in[i];
        }
    } else if (*dt == MPI_DOUBLE)   {
        const double *in= invec;
        double *inout= inoutvec;

        for (i= 0; i < *len; i++)   {
            inout[i] += in[i];
        }
    } else   {
        /* A data type we cannot handle: say so before tearing the job down */
        fprintf(stderr, "my_func: unsupported MPI datatype; aborting\n");
        MPI_Abort(MPI_COMM_WORLD, -1);
    }

}  /* end of my_func() */
