/*
** Global summation using recursive doubling
** Copyright Rolf Riesen 2008
**
*/
#include <limits.h>
#include <stdio.h>
#include <mpi.h>


/* Global sum of one int per rank; the total is returned on rank root, 0 elsewhere. */
int gsum(int summand, int root, MPI_Comm comm);




/*
** Calculate 2^n as an int.
**
** Returns 1 for n == 0.  For n < 0, 2^n is fractional and truncates to 0.
** If 2^n does not fit in an int, 0 is returned instead of overflowing
** (signed overflow is undefined behavior).
*/
static int
pow2(int n)
{

int i;
int res;


    if (n < 0)   {
	return 0;
    }

    res= 1;
    for (i= 0; i < n; i++)   {
	if (res > INT_MAX / 2)   {
	    /* Doubling again would overflow int */
	    return 0;
	}
	res= 2 * res;
    }

    return res;

}  /* end of pow2() */


/*
** Integer base-2 logarithm, rounded down: counts how many times n can be
** halved before it reaches 1 or less.  Any n <= 1 yields 0.
*/
static int
log2i(int n)
{

int halvings;
int rest;


    halvings= 0;
    for (rest= n; rest > 1; rest >>= 1)   {
	/* rest is positive here, so >>= 1 is the same as dividing by 2 */
	halvings++;
    }

    return halvings;

}  /* end of log2i() */


/*
** Global sum of one int contribution per rank of comm.
**
** Partial sums are combined pairwise.  In round i (i = 0 .. log2(nproc)-1),
** every rank whose low i bits are all zero takes part:
**   - if bit i of its rank is set, it sends its partial sum to the rank
**     with that bit cleared;
**   - otherwise it receives from the rank with that bit set and adds the
**     value in.
** After the last round rank 0 holds the total.  Rank 0 then hands the
** total to root, or keeps it when root is 0.
**
** summand  this rank's contribution
** root     rank in comm that receives the total
** comm     communicator to sum over; all of its ranks must call gsum()
**          with the same root
**
** Returns the total on root and 0 on every other rank.  Calls MPI_Abort()
** if root is not a valid rank of comm, or if the size of comm is not a
** power of 2.
**
** NOTE(review): messages use tag 111 directly on comm.  They could match
** unrelated user traffic with that tag; passing a communicator obtained
** from MPI_Comm_dup() keeps them separate.
*/
int
gsum(int summand, int root, MPI_Comm comm)
{

int my_rank, nproc;
int i;
int count= 1;
int tag= 111;
int rounds;
int result;
int mask, partner_bit;
int sum;


    /* Work on the caller's communicator, not MPI_COMM_WORLD */
    MPI_Comm_size(comm, &nproc);
    MPI_Comm_rank(comm, &my_rank);

    if ((root < 0) || (root >= nproc))   {
	fprintf(stderr, "Root outside range!\n");
	MPI_Abort(comm, -1);
    }

    /*
    ** A single rank already holds the whole sum and nothing needs to be
    ** exchanged.  Checking this before the power-of-2 test keeps the
    ** one-rank case independent of how pow2() treats n == 0.
    */
    if (nproc <= 1)   {
	return summand;
    }

    /* The pairing scheme below only works on a power of 2 nodes */
    if (pow2(log2i(nproc)) != nproc)   {
	fprintf(stderr, "Need power of 2 nodes!\n");
	MPI_Abort(comm, -1);
    }

    rounds= log2i(nproc);
    mask= 0;		/* low bits that must be zero to take part */
    partner_bit= 1;	/* bit separating sender from receiver this round */
    sum= summand;

    for (i= 0; i < rounds; i++)   {
	if ((my_rank & mask) == 0)   {
	    /* I participate in this round */
	    if (my_rank & partner_bit)   {
		/* I'm the sender; later masks exclude me from further rounds */
		MPI_Send(&sum, count, MPI_INT, my_rank & ~partner_bit, tag, comm);
	    } else   {
		/* I'm the receiver */
		MPI_Recv(&result, count, MPI_INT, my_rank | partner_bit, tag, comm, MPI_STATUS_IGNORE);
		sum= sum + result;
	    }
	}
	mask= (mask << 1) | 1;
	partner_bit= partner_bit << 1;
    }


    /* Rank 0 now holds the total; if it is also root, no message is needed */
    if (root == 0)   {
	return (my_rank == 0) ? sum : 0;
    }

    /*
    ** Forward the total from rank 0 to root.  A blocking send is safe:
    ** root has finished its own sending round before rank 0 can have the
    ** total, so root is free to post the matching receive.
    */
    if (my_rank == 0)   {
	MPI_Send(&sum, count, MPI_INT, root, tag, comm);
    } else if (my_rank == root)   {
	MPI_Recv(&result, count, MPI_INT, 0, tag, comm, MPI_STATUS_IGNORE);
	return result;
    }

    return 0;

}  /* end of gsum() */
