/*
** A reduce_scatter example
** Copyright Rolf Riesen 2008
**
*/
#include <stdio.h>
#include <stdlib.h>
#include <mpi.h>



/*
** Allocate an array of `count` ints.
** If the allocation fails, MPI_Abort() is called on MPI_COMM_WORLD with
** error code -1, the same reaction the example has always had to an
** out-of-memory condition.
*/
static int *
alloc_ints(int count)
{

int *buf;


    buf= malloc(sizeof *buf * (size_t)count);
    if (buf == NULL)   {
        MPI_Abort(MPI_COMM_WORLD, -1);
    }
    return buf;

}  /* end of alloc_ints() */



/*
** Demonstrate MPI_Reduce_scatter().
**
** Every rank fills an input vector of `total` ints with 0 .. total-1,
** where total = 1 + 2 + ... + nproc. The vectors are summed element-wise
** (MPI_SUM) and the result is scattered so that rank r receives r + 1
** consecutive elements. Each rank prints its input before the call and
** its share of the result after it.
**
** Returns 0 on success; allocation failures abort the job via MPI_Abort().
*/
int
main(int argc, char *argv[])
{

int my_rank, nproc;
int *rcv_counts;
int total;
int *my_value;
int *result;
int i, j;



    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &nproc);
    MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);

    /* Allocate memory for the receive counts */
    rcv_counts= alloc_ints(nproc);

    /* Rank i receives i + 1 elements; total is the full vector length */
    total= 0;
    for (i= 0; i < nproc; i++)   {
        rcv_counts[i]= i + 1;
        total= total + i + 1;
    }

    if (my_rank == 0)   {
        printf("Vector length is %d elements\n", total);
    }

    /* Allocate memory for the input vector */
    my_value= alloc_ints(total);

    /* Each node allocates only enough memory for what it needs. */
    result= alloc_ints(rcv_counts[my_rank]);

    /* Fill in our contribution */
    for (i= 0; i < total; i++)   {
        my_value[i]= i;
    }

    if (my_rank == 0)   {
        printf("Before\n");
    }
    /*
    ** Take turns printing, one rank per iteration. Flushing before the
    ** barrier pushes our line out before the next rank starts; the final
    ** ordering still depends on how the MPI launcher forwards stdout, so
    ** this is best effort only.
    */
    for (i= 0; i < nproc; i++)   {
        if (i == my_rank)   {
            printf("Node %3d: ", my_rank);
            for (j= 0; j < total; j++)   {
                printf("%3d ", my_value[j]);
            }
            printf("\n");
            fflush(stdout);
        }
        MPI_Barrier(MPI_COMM_WORLD);
    }

    /* Sum the vectors of all ranks, then hand each rank its segment */
    MPI_Reduce_scatter(my_value, result, rcv_counts, MPI_INT,
            MPI_SUM, MPI_COMM_WORLD);

    if (my_rank == 0)   {
        printf("After\n");
    }
    /* Same turn-taking scheme as above, now for the received segment */
    for (i= 0; i < nproc; i++)   {
        if (i == my_rank)   {
            printf("Node %3d: ", my_rank);
            for (j= 0; j < rcv_counts[my_rank]; j++)   {
                printf("%3d ", result[j]);
            }
            printf("\n");
            fflush(stdout);
        }
        MPI_Barrier(MPI_COMM_WORLD);
    }

    /* Release everything we allocated before shutting MPI down */
    free(result);
    free(my_value);
    free(rcv_counts);

    MPI_Finalize();
    return 0;

}  /* end of main() */
