/*
** Pthreads Matrix Multiply
** Copyright Rolf Riesen 2008
**
*/
#include <stdio.h>	/* For printf(), fprintf(), perror() */
#include <stdlib.h>	/* For strtol(), exit(), malloc(), calloc(), free() */
#include <stdint.h>	/* For SIZE_MAX */
#include <limits.h>	/* For INT_MAX */
#include <unistd.h>	/* For getopt() */
#include <string.h>	/* For memset(), strerror() */
#include <errno.h>	/* For errno */
#include <math.h>	/* For sqrt() */
#include <pthread.h>	/* For pthread_*() */


/* Constants */
#define DEFAULT_NUM_THREADS	(4)	/* Default for -p; main() requires a perfect square */
#define DEFAULT_N		(720)	/* Default for -n: matrices are DEFAULT_N x DEFAULT_N */
#define FALSE			(0)
#define TRUE			(1)


/*
** Local functions (defined after main()):
**   alloc_mat    - allocate an n x n float matrix as a single heap block
**   fill_mat     - fill a matrix with a deterministic test sequence
**   print_mat    - print a matrix to stdout (skipped when n > 16)
**   calc         - compute one tile of C = A * B
**   wrapper_func - pthread entry point; arg is a sub_matrix_t *
**   usage        - print command-line help to stderr
*/
static float **alloc_mat(int n);
static void fill_mat(float **A, int n);
static void print_mat(float **A, int n, char *name);
static void calc(int x1, int x2, int y1, int y2, float **A, float **B, float **C, int n);
static void *wrapper_func(void *arg);
static void usage(char *pname);


/*
** Work description passed to each thread through pthread_create()'s
** void * argument. It names one rectangular tile of C to compute and
** the matrices involved. Both coordinate ranges are inclusive.
*/
typedef struct sub_matrix_t   {
    int x1;	/* first column of the tile in C */
    int x2;	/* last column of the tile in C (inclusive) */
    int y1;	/* first row of the tile in C */
    int y2;	/* last row of the tile in C (inclusive) */
    float **A;	/* left operand, n x n; only read */
    float **B;	/* right operand, n x n; only read */
    float **C; 	/* result, n x n; only the tile's elements are written */
    int n;	/* dimension of A, B and C */
} sub_matrix_t;



int
main(int argc, char *argv[])
{

extern char *optarg;
int ch, error;

int nproc, nproc2;
int i, j, k, rc;
int rows;
int n;
sub_matrix_t *args;
pthread_t *thread;
float **A, **B, **C;


 
    /* Initialize some variables */
    nproc= DEFAULT_NUM_THREADS;		/* Default number of threads */
    n= DEFAULT_N;                       /* Default matrix size */
    opterr= 0;	/* Don't let getopt print an error msg */
    error= FALSE;


    while ((ch= getopt(argc, argv, "p:n:")) != EOF)   {
	switch (ch)   {
	    case 'p':
		nproc= strtol(optarg, (char **)NULL, 0);
		if (nproc < 1)   {
		    error= TRUE;
		}
		break;
	    case 'n':
		n= strtol(optarg, (char **)NULL, 0);
		if (n < 1)   {
		    error= TRUE;
		}
		break;
	    default:
		error= TRUE;
	}
    }

    if (error)   {
	usage(argv[0]);
	exit(-1);
    }


    /*
    ** To keep things simple, we assume nproc is n^2
    ** and that n is evenly divisible by sqrt(nproc)
    */
    nproc2= (int)sqrt((double)nproc);
    if ((nproc2 * nproc2) != nproc)   {
        fprintf(stderr, "Number of threads must be a square number. Sorry!\n");
        exit(-1);
    }

    rows= n / nproc2;
    if ((rows * nproc2) != n)   {
        fprintf(stderr, "Number of rows and cols for each thread must be equal.\n");
        fprintf(stderr, "    Attempted %d rows for %d threads.\n", n, nproc2);
        exit(-1);
    }

    printf("Assigning %d * %d threads to calculate %d * %d matrix\n",
        nproc2, nproc2, n, n);


    /* Allocate storage for the thread handles */
    thread= (pthread_t *)malloc(nproc * sizeof(pthread_t));
    if (thread == NULL)   {
	fprintf(stderr, "Out of memory!\n");
	exit(-1);
    }

    /* Allocate storage for the thread arguments */
    args= (sub_matrix_t *)malloc(nproc * sizeof(sub_matrix_t));
    if (args == NULL)   {
	fprintf(stderr, "Out of memory!\n");
	exit(-1);
    }

    printf("main() This program is running with %d threads\n", nproc);

    A= alloc_mat(n);
    B= alloc_mat(n);
    C= alloc_mat(n);

    fill_mat(A, n);
    fill_mat(B, n);

    print_mat(A, n, "A");
    print_mat(B, n, "B");


    printf("Matricies allocated and initialized\n");
    k= 0;
    for (i= 0; i < nproc2; i++)   {
        for (j= 0; j < nproc2; j++)   {
            /* Set up the command struct for the next thread */
            args[k].x1= i * rows;
            args[k].x2= ((i + 1) * rows) - 1;
            args[k].y1= j * rows;
            args[k].y2= ((j + 1) * rows) - 1;
            args[k].A= A;
            args[k].B= B;
            args[k].C= C;
            args[k].n= n;

            rc= pthread_create(&thread[k], NULL, wrapper_func,
                    (void *)&args[k]);
            if (rc != 0)   {
                perror("pthread_create() failed");
                exit(-1); 
            }
            k++;
        }
    }

    for (k= 0; k < nproc; k++)   {
	pthread_join(thread[k], NULL);
    }


    printf("main() Done\n");

    print_mat(C, n, "C");

    return 0; 

}  /* end of main() */



/*
** Thread entry point used with pthread_create().
** arg points to a sub_matrix_t describing the tile of C this thread
** computes; it must remain valid until the thread finishes. Logs the
** parameters, runs calc() on the tile, and always returns NULL.
*/
static void *
wrapper_func(void *arg)
{

const sub_matrix_t *args;


    args= (const sub_matrix_t *)arg;

    /* %p requires a void * argument; passing float ** directly is undefined */
    printf("calc(%d, %d, %d, %d, %p, %p, %p, %d)\n",
            args->x1, args->x2, args->y1, args->y2, (void *)args->A,
            (void *)args->B, (void *)args->C, args->n);

    calc(args->x1, args->x2, args->y1, args->y2, args->A,
            args->B, args->C, args->n);
    return NULL;

}  /* end of wrapper_func() */



/*
** Compute one rectangular tile of C = A * B. All three matrices are n x n
** and indexed as M[row][col]. The tile spans rows y1..y2 and columns
** x1..x2, both inclusive; no other element of C is touched. Each element
** is the dot product of a row of A with a column of B, accumulated in a
** float.
*/
static void
calc(int x1, int x2, int y1, int y2, float **A, float **B, float **C, int n)
{

int row, col, idx;
float acc;
float *a_row, *c_row;


    for (row= y1; row <= y2; row++)   {
        a_row= A[row];
        c_row= C[row];

        for (col= x1; col <= x2; col++)   {
            acc= 0.0;
            idx= 0;
            while (idx < n)   {
                acc += a_row[idx] * B[idx][col];
                idx++;
            }
            c_row[col]= acc;
        }
    }

}  /* end of calc() */




/*
** Allocate a zero-filled n x n matrix of floats.
**
** The row-pointer table and the n * n float elements share ONE heap block.
** Slots 0..n-1 hold float * pointers, and row i's pointer refers to the n
** floats starting i * n elements into the data area that follows the table.
** The whole matrix is therefore released with a single free().
**
** Sizes are computed in size_t with explicit overflow checks. A negative n,
** a size that does not fit in size_t, or an allocation failure prints an
** error and terminates the process with exit(-1).
**
** NOTE(review): placing the floats directly after the pointer table assumes
** sizeof(float *) is a multiple of alignof(float) (true on common ABIs).
*/
static float **
alloc_mat(int n)
{

size_t dim, data_size, ptr_size;
float **array;
float *row;
int i;


    array= NULL;
    if (n >= 0)   {
        dim= (size_t)n;

        /* Both n * n * sizeof(float) and n * sizeof(float *) must fit in size_t */
        if (((dim == 0) || (dim <= (SIZE_MAX / sizeof(float)) / dim)) &&
                (dim <= SIZE_MAX / sizeof(float *)))   {
            data_size= dim * dim * sizeof(float);	/* n^2 floats */
            ptr_size= dim * sizeof(float *);	/* n pointers to float */

            /* ...and so must their sum */
            if (data_size <= SIZE_MAX - ptr_size)   {
                /* calloc() returns zero-filled memory, so no memset() is needed */
                array= calloc(1, data_size + ptr_size);
            }
        }
    }

    if (array == NULL)   {
        fprintf(stderr, "Memory allocation failed!\n");
        exit(-1);
    }

    /* Fill in the pointer table: row i starts i * n floats into the data area */
    row= (float *)(array + n);
    for (i= 0; i < n; i++)   {
        array[i]= row;
        row= row + n;
    }

    return array;

}  /* end of alloc_mat() */



/*
** Fill the n x n matrix A with the sequence 0, 1, 2, ... laid out in
** column-major order, i.e. A[i][j] = j * n + i.
**
** Each element is computed from its position in double precision and
** rounded to float once. A running float counter stops growing at 2^24
** (16777216.0f + 1.0 rounds back to 16777216.0f), which would give every
** later element the same value once n > 4096. For n <= 4096 all values are
** exact integers. Traversal is row by row to follow the row-pointer layout.
*/
static void
fill_mat(float **A, int n)
{

int i, j;


    for (i= 0; i < n; i++)   {
        for (j= 0; j < n; j++)   {
            A[i][j]= (float)((double)j * n + i);
        }
    }

}  /* end of fill_mat() */



/*
** Print the n x n matrix A to stdout under the heading "Matrix <name>".
** Output line i is row A[i][0] .. A[i][n-1], each element formatted as
** "%5.1f ". This is the same [row][col] convention calc() uses for
** C = A * B, so the printed C is the product of the printed A and B.
** Matrices larger than 16 x 16 are skipped without output.
*/
static void
print_mat(float **A, int n, char *name)
{

int row, col;


    if (n > 16)   {
        /* Too big to print */
        return;
    }

    printf("\nMatrix %s\n", name);
    for (row= 0; row < n; row++)   {
        for (col= 0; col < n; col++)   {
            printf("%5.1f ", A[row][col]);
        }
        printf("\n");
    }

}  /* end of print_mat() */



/*
** Write the command-line synopsis, with the built-in defaults for
** -p and -n, to stderr.
*/
static void
usage(char *pname)
{

    fprintf(stderr,
            "Usage: %s [-p num] [-n num]\n"
            "           -p num   Create num threads\n"
            "                    Default is %d.\n"
            "           -n num   Size of matricies\n"
            "                    Default is %d * %d.\n",
            pname, DEFAULT_NUM_THREADS, DEFAULT_N, DEFAULT_N);

}  /* end of usage() */
