/*
** Distribute a lower-triangular matrix row-wise to all nodes
** using MPI_Scatterv()
** Copyright Rolf Riesen 2008
**
*/
#include <stdio.h>
#include <stdlib.h>
#include <mpi.h>



/*
** Allocate nbytes on the heap. If malloc() fails, the whole job in
** MPI_COMM_WORLD is aborted via MPI_Abort() (which is not expected to return).
*/
static void *
alloc_or_abort(size_t nbytes)
{
void *p;

    p= malloc(nbytes);
    if (p == NULL)   {
        MPI_Abort(MPI_COMM_WORLD, -1);
    }
    return p;

}  /* end of alloc_or_abort() */



/*
** The root fills an nproc x nproc matrix with 0, 1, 2, ... in row-major
** order. It then uses MPI_Scatterv() to send the first i+1 elements of row i
** (the lower triangle, including the diagonal) to rank i. Each rank prints
** what it received, one rank at a time.
*/
int
main(int argc, char *argv[])
{

int my_rank, nproc;
int *matrix;
int root;
int i, j;
int value;
int *snd_counts;
int *displs;
int *result;



    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &nproc);
    MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);


    /*
    ** Pick some defaults. Rank 1 is the preferred root, but it only exists
    ** when there are at least two processes. Fall back to rank 0 so the
    ** program also runs with a single process.
    */
    root= (nproc > 1) ? 1 : 0;

    /* Only the root allocates the matrix; the send buffer is unused elsewhere */
    matrix= NULL;


    if (my_rank == root)   {
        /* Allocate memory on the root node for the full nproc x nproc matrix */
        matrix= alloc_or_abort(sizeof *matrix * (size_t)nproc * (size_t)nproc);

        /* Fill the matrix with some data */
        value= 0;
        for (i= 0; i < nproc; i++)   {
            for (j= 0; j < nproc; j++)   {
                matrix[(i * nproc) + j]= value;
                value= value + 1;
            }
        }
    }


    /* Allocate memory for the send counts and the send displacements */
    snd_counts= alloc_or_abort(sizeof *snd_counts * (size_t)nproc);
    displs= alloc_or_abort(sizeof *displs * (size_t)nproc);

    /*
    ** Row i of the lower triangle holds i+1 entries and starts at offset
    ** i * nproc in the row-major matrix.
    */
    for (i= 0; i < nproc; i++)   {
        snd_counts[i]= i + 1;
        displs[i]= i * nproc;
    }

    /* Each rank receives exactly its own share of the triangle */
    result= alloc_or_abort(sizeof *result * (size_t)snd_counts[my_rank]);


    if (my_rank == root)   {
        printf("Before on root node %d\n", root);
        for (i= 0; i < nproc; i++)   {
            for (j= 0; j < nproc; j++)   {
                printf("%3d ", matrix[(i * nproc) + j]);
            }
            printf("\n");
        }
        fflush(stdout);
    }

    MPI_Scatterv(matrix, snd_counts, displs, MPI_INT, result,
            snd_counts[my_rank], MPI_INT, root, MPI_COMM_WORLD);

    if (my_rank == 0)   {
        printf("After\n");
    }

    /* Print one rank per iteration; flushing before the barrier keeps order */
    for (i= 0; i < nproc; i++)   {
        if (i == my_rank)   {
            printf("Node %3d: ", my_rank);
            for (j= 0; j < snd_counts[my_rank]; j++)   {
                printf("%3d ", result[j]);
            }
            printf("\n");
            fflush(stdout);
        }
        MPI_Barrier(MPI_COMM_WORLD);
    }


    /* Release everything this rank allocated (free(NULL) is a no-op) */
    free(result);
    free(displs);
    free(snd_counts);
    free(matrix);

    MPI_Finalize();
    return 0;

}  /* end of main() */
