/*
** Implement a barrier using a condition variable
** Copyright Rolf Riesen 2008
**
*/
#include <stdio.h>
#include <stdlib.h>	/* For random(), srandom(), strtol(), exit(), malloc() */
#include <unistd.h>	/* For getopt() */
#include <sys/time.h>	/* For gettimeofday() */
#include <string.h>	/* For strerror() */
#include <errno.h>	/* For errno */
#include <limits.h>	/* For INT_MAX */
#include <pthread.h>	/* For pthread_*() */


/* Constants */
#define DEFAULT_NUM_THREADS	(4)	/* Thread count used when -p is not given */
#define FALSE			(0)
#define TRUE			(1)


/*
** (Shared) Globals
**
** barrier_lock protects the barrier state below and barrier_cond_var is
** used to wake threads waiting in barrier(). Two arrival counters are kept
** so that consecutive barrier episodes alternate between them:
** barrier_seq (0 or 1) selects the counter the current episode decrements,
** and barrier_count[i] holds how many threads have yet to arrive at an
** episode using counter i. main() sets both counters to nproc before any
** worker thread is created; in this file they are afterwards only read or
** written inside barrier() while barrier_lock is held.
*/
pthread_mutex_t barrier_lock= PTHREAD_MUTEX_INITIALIZER;
pthread_cond_t barrier_cond_var= PTHREAD_COND_INITIALIZER;
int barrier_seq;	/* Index (0 or 1) of the counter the current episode uses */
int barrier_count[2];	/* Threads still expected, per counter */
int nproc;		/* Number of worker threads, i.e. barrier participants */



/* Local functions */
static void *do_work(void *arg);
static void barrier(void);
static void usage(char *pname);
static double dclock(void);



/*
** Parse the command line (-p num selects the number of threads), start
** nproc worker threads running do_work(), wait for all of them to finish,
** and release the per-thread storage.
** Calls exit(-1) on a bad option, an allocation failure, or when a thread
** cannot be created.
*/
int
main(int argc, char *argv[])
{

extern char *optarg;
int ch, error;
long val;
char *end;

int i, rc;
int *args;
pthread_t *thread;


    /* Initialize some variables */
    nproc= DEFAULT_NUM_THREADS;		/* Default number of threads */
    opterr= 0;	/* Don't let getopt print an error msg */
    error= FALSE;


    /* POSIX getopt() reports the end of the options with -1 */
    while ((ch= getopt(argc, argv, "p:")) != -1)   {
	switch (ch)   {
	    case 'p':
		/*
		** The whole argument must be an integer (base detected as
		** strtol() does with base 0, e.g. 8, 010, 0x8) in the range
		** [1, INT_MAX]. Reject trailing junk such as "4x" and values
		** that would not fit in an int.
		*/
		errno= 0;
		val= strtol(optarg, &end, 0);
		if (end == optarg || *end != '\0' || errno == ERANGE ||
			val < 1 || val > INT_MAX)   {
		    error= TRUE;
		} else   {
		    nproc= (int)val;
		}
		break;
	    default:
		error= TRUE;
		break;
	}
    }

    if (error)   {
	usage(argv[0]);
	exit(-1);
    }


    printf("Using %d threads.\n", nproc);

    /* Allocate storage for the thread handles (calloc checks n * size for overflow) */
    thread= calloc((size_t)nproc, sizeof(*thread));
    if (thread == NULL)   {
	fprintf(stderr, "Out of memory!\n");
	exit(-1);
    }

    /* Allocate storage for the thread arguments */
    args= calloc((size_t)nproc, sizeof(*args));
    if (args == NULL)   {
	fprintf(stderr, "Out of memory!\n");
	free(thread);
	exit(-1);
    }

    /* Initialize the barrier before any thread can reach it */
    barrier_seq= 0;
    barrier_count[0]= nproc;
    barrier_count[1]= nproc;


    for (i= 0; i < nproc; i++)   {
	args[i]= i;
	rc= pthread_create(&thread[i], NULL, do_work,
		(void *)&args[i]);
	if (rc != 0)   {
	    /* pthread_*() return an error number; they do not set errno */
	    fprintf(stderr, "pthread_create() failed: %s\n", strerror(rc));
	    exit(-1);
	}
    }

    for (i= 0; i < nproc; i++)   {
	rc= pthread_join(thread[i], NULL);
	if (rc != 0)   {
	    fprintf(stderr, "pthread_join() failed: %s\n", strerror(rc));
	}
    }

    /* All workers are done with args[] now */
    free(args);
    free(thread);

    return 0;

}  /* end of main() */



/*
** Advance a thread-private 32-bit linear congruential generator and
** return its upper 24 bits. Each worker owns its own state. By contrast,
** random()/srandom() share a single hidden state across every thread in
** the process, and POSIX does not require them to be thread-safe.
*/
static unsigned long
next_random(unsigned long *state)
{

    *state= (*state * 1664525UL + 1013904223UL) & 0xFFFFFFFFUL;
    return *state >> 8;

}  /* end of next_random() */



/*
** Worker thread. It performs these steps in order:
**   1. a random amount of busy work;
**   2. wait at a barrier;
**   3. more busy work;
**   4. wait at a second barrier;
**   5. print the total time spent inside the two barrier() calls.
**
** arg points to this thread's rank (an int). The rank also seeds the
** thread's private random number generator.
*/
static void *
do_work(void *arg)
{

int my_rank;
int i, work;
unsigned long seed;
volatile unsigned long sink;	/* Forces the busy loops to actually execute */
double t1, t2, t3, t4;
double delay;


    my_rank= *((int *)arg);
    seed= (unsigned long)my_rank;

    /* Do a random amount of work */
    work= (int)(next_random(&seed) % 100000);
    for (i= 0; i < work; i++)   {
	/* Just waste some time */
	(void)next_random(&seed);
    }
    sink= seed;

    /* First barrier */
    t1= dclock();
    barrier();
    t2= dclock();

    work= (int)(next_random(&seed) % 100000);
    for (i= 0; i < work; i++)   {
	/* Just waste some time */
	(void)next_random(&seed);
    }
    sink= seed;

    /* Second barrier */
    t3= dclock();
    barrier();
    t4= dclock();

    /* dclock() returns microseconds; report seconds */
    delay= ((t2 - t1) + (t4 - t3)) / 1000000.0;
    printf("[%3d] spent %9.6fs in barrier.\n", my_rank, delay);

    return NULL;

}  /* end of do_work() */



/*
** Block the calling thread until all nproc threads have called barrier().
**
** Two arrival counters alternate between successive barrier episodes,
** selected by barrier_seq. Each arriving thread decrements the counter of
** the current episode. The thread that brings that counter to zero does
** three things: it advances barrier_seq, re-arms the other counter with
** nproc, and broadcasts.
**
** Every other thread sleeps until its episode's counter reads zero. The
** loop around pthread_cond_wait() also tolerates spurious wakeups.
**
** A finished episode's counter is only re-armed when the following episode
** completes. That cannot happen before every sleeper has left, so a slow
** sleeper still sees zero when it finally runs.
*/
static void
barrier(void)
{

int my_phase;


    pthread_mutex_lock(&barrier_lock);
    my_phase= barrier_seq;

    if (--barrier_count[my_phase] != 0)   {
	/* Not the last arrival: wait for this episode to complete */
	while (barrier_count[my_phase] != 0)   {
	    pthread_cond_wait(&barrier_cond_var, &barrier_lock);
	}
    } else   {
	/* Last arrival: switch counters, re-arm the next one, release everyone */
	printf("Waking everybody up!\n");
	barrier_seq= (my_phase + 1) % 2;
	barrier_count[barrier_seq]= nproc;
	pthread_cond_broadcast(&barrier_cond_var);
    }

    pthread_mutex_unlock(&barrier_lock);

}  /* end of barrier() */



/*
** Print a short description of the command-line options to stderr.
** pname is the program name to show (main() passes argv[0]).
*/
static void
usage(char *pname)
{

    fprintf(stderr,
	    "Usage: %s [-p num]\n"
	    "           -p num   Create num threads\n"
	    "                    Default is %d.\n",
	    pname, DEFAULT_NUM_THREADS);

}  /* end of usage() */



/*
** Call gettimeofday() and convert the result into a double holding the
** wall-clock time in microseconds since the Epoch (usec resolution).
** If gettimeofday() fails, print a message to stderr and return 0.0.
*/
static double
dclock(void)
{

int rc;
struct timeval tv;


    rc= gettimeofday(&tv, NULL);	/* POSIX: the timezone argument should be NULL */
    if (rc != 0)   {
        fprintf(stderr, "ERROR gettimeofday() failed: %s\n", strerror(errno));
        return 0.0;
    }

    /*
    ** Do the arithmetic in double. Computing tv_sec * 1000000 in integer
    ** arithmetic overflows (undefined behavior) wherever time_t/long are
    ** 32 bits. A double represents present-day microsecond timestamps
    ** (well below 2^53) exactly.
    */
    return (double)tv.tv_sec * 1000000.0 + (double)tv.tv_usec;

}  /* end of dclock() */
