/* **************************************************************************
 * threads1.cc - Demo of a C++ Thread Class / [KK 2008-08-26] 
 * Compile using 	c++ -Wall -o threads1 threads1.cc -lpthread
 * then run with	./threads1
 * **************************************************************************/

/* Part 1: Thread class */
#include <errno.h>
#include <pthread.h>
#include <string.h>
#include <unistd.h>

#include <iostream>
#include <map>
#include <string>

using namespace std;

/* Abstract base class for objects whose act() is executed on a POSIX thread
 * of its own. Besides starting the thread it offers mutexes that are keyed
 * by an arbitrary address, see lock() and unlock().
 */
class Thread {
public:
    // Virtual, so a derived object can be destroyed through a Thread *.
    virtual ~Thread();

    // The work the thread performs; each concrete subclass must supply it.
    virtual void act() = 0;

    // Starts a new thread for this object and returns the thread's id.
    // The thread is created joinable unless `joinable` is false, in which
    // case it is created detached.
    pthread_t run(bool joinable = true);

    // Acquire / release the mutex that belongs to the address `target`.
    void lock(void *target);
    void unlock(void *target);

private:
    // Iterator type for the mutex table below.
    typedef std::map<void *, pthread_mutex_t>::iterator mapIterator;

    // Thread entry point that run() passes to pthread_create().
    static void *start(Thread *obj);

    // Helper behind lock(): locks the mutex stored for `target`, creating
    // that mutex first when the table has no entry for it yet.
    void plock(void *target);

    // The table of mutexes, keyed by target address.
    static std::map<void *, pthread_mutex_t> s_lock;

    // "Mutex initialized yet?" flag. Not referenced by any member function
    // in the visible code.
    static bool initialized;
};

// Out-of-class definitions of Thread's static data members.
// `initialized` starts out false; `s_lock` starts out as an empty table.
bool Thread::initialized = false;
std::map<void *,  pthread_mutex_t> Thread::s_lock;

// Thread has no per-object data members, so there is nothing to clean up.
// This is the out-of-line definition of the virtual destructor.
Thread::~Thread() {}

namespace {

/* Serializes every access to the Thread::s_lock table. It is initialized
 * statically, so it needs no lazy (and therefore racy) set-up of its own.
 */
pthread_mutex_t s_lock_guard = PTHREAD_MUTEX_INITIALIZER;

/* Holds s_lock_guard for as long as the object lives, so that the guard is
 * released on every exit path, including exceptions.
 * Throws a string if the guard cannot be locked.
 */
class TableGuard {
public:
    TableGuard() {
        if (int res = pthread_mutex_lock(&s_lock_guard))
            throw static_cast<string>("Failed to obtain mutex lock: ") +
                strerror(res);
    }
    ~TableGuard() {
        pthread_mutex_unlock(&s_lock_guard);
    }

private:
    TableGuard(const TableGuard &);             // not copyable
    TableGuard &operator=(const TableGuard &);
};

}

/* Acquires the mutex that belongs to the address `target`, creating the
 * mutex on first use. Blocks until the mutex is available.
 * Throws a string describing the failure if a pthread call fails.
 */
void Thread::lock(void *target) {
    plock(target);
}

/* Looks up (or creates) the mutex for `target` and locks it.
 * Throws a string describing the failure if a pthread call fails.
 */
void Thread::plock(void *target) {
    pthread_mutex_t *mutex = 0;

    {
        // The table is only touched while the guard is held.
        TableGuard guard;

        mapIterator iter = s_lock.find(target);
        if (iter == s_lock.end()) {
            // No such lock yet, create the mutex in place: a pthread mutex
            // must not be copied after it has been initialized.
            pthread_mutex_t &slot = s_lock[target];
            if (int res = pthread_mutex_init(&slot, 0)) {
                // Do not leave an unusable entry behind.
                s_lock.erase(target);
                throw static_cast<string>(
                    "Failed to initialize static mutex: ") + strerror(res);
            }
            mutex = &slot;
        } else {
            mutex = &iter->second;
        }
    }

    // Block on the target mutex only after the guard has been released;
    // otherwise unlock() could never get at the table to release it.
    // The pointer stays valid: std::map does not move its elements and
    // successfully initialized entries are never erased.
    if (int res = pthread_mutex_lock(mutex))
        throw static_cast<string>("Failed to obtain mutex lock: ") +
            strerror(res);
}

/* Releases the mutex that belongs to the address `target`.
 * Throws a string if no mutex was ever created for `target` (reported as
 * EINVAL) or if pthread_mutex_unlock() fails.
 */
void Thread::unlock(void *target) {
    pthread_mutex_t *mutex = 0;

    {
        TableGuard guard;

        // find() instead of operator[]: an unknown target must not add an
        // uninitialized mutex to the table.
        mapIterator iter = s_lock.find(target);
        if (iter != s_lock.end())
            mutex = &iter->second;
    }

    if (!mutex)
        throw static_cast<string>("Failed to release mutex lock: ") +
            strerror(EINVAL);

    if (int res = pthread_mutex_unlock(mutex))
        throw static_cast<string>("Failed to release mutex lock: ") +
            strerror(res);
}

/* Starts a new thread that executes start(this), which in turn calls act().
 *
 * joinable: true creates the thread joinable, false creates it detached.
 * Returns the id of the newly created thread.
 *
 * pthread_create() is attempted up to 3 times when it reports EAGAIN, with
 * a one second pause after each such failure.
 *
 * Throws a const char * message if the thread attributes cannot be set up
 * or if all 3 attempts fail with EAGAIN, and a string if pthread_create()
 * fails with any other error. The attribute object is destroyed on every
 * exit path once it has been initialized.
 */
pthread_t Thread::run(bool joinable) {
    pthread_attr_t attr;
    pthread_t th;

    if (pthread_attr_init(&attr))
        throw ("Cannot initialize thread attributes");

    const int state = joinable ? PTHREAD_CREATE_JOINABLE
                               : PTHREAD_CREATE_DETACHED;
    if (pthread_attr_setdetachstate(&attr, state)) {
        pthread_attr_destroy(&attr);
        if (joinable)
            throw ("Cannot set thread state as joinable");
        throw ("Cannot set thread state as detached");
    }

    for (int attempt = 0; attempt < 3; attempt++) {
        // start() takes a Thread *, so it is cast to the signature that
        // pthread_create() expects.
        // NOTE(review): calling a function through a pointer of a different
        // type is formally undefined; confirm this is acceptable for the
        // targeted platforms.
        const int res = pthread_create(&th, &attr,
                                       reinterpret_cast<void *(*)(void *)>(start),
                                       this);
        if (!res) {
            pthread_attr_destroy(&attr);
            return (th);
        }
        if (res != EAGAIN) {
            pthread_attr_destroy(&attr);
            throw (static_cast<string>("Failed to start thread: ") +
                   strerror(res));
        }

        // EAGAIN: wait a moment and try again.
        cout << "Failed to start thread: " << strerror(res) <<
            ", retrying\n";
        sleep(1);
    }

    pthread_attr_destroy(&attr);
    throw ("Failed to start thread: "
           "Resources unavailable after 3 tries, giving up");
}

/* Thread entry point used by run(): executes t->act() and then deletes t.
 *
 * No exception is allowed to leave the thread function. Messages thrown as
 * string or as const char * (both forms are thrown by this file's Thread
 * members) are written to cerr; anything else is reported generically.
 * Always returns 0 as the thread's exit value.
 */
void *Thread::start(Thread *t) {
    try {
        t->act();
    } catch (const string &s) {         // by reference: no copy of the message
        cerr << s << "\n";
    } catch (const char *s) {
        cerr << s << "\n";
    } catch (...) {
        cerr << "Thread threw an exception\n";
    }

    // The thread owns its object: it is destroyed once act() is done.
    delete t;
    return 0;
}


/* Part 2: Example of a Thread-derived class */
// Sample thread type: only supplies the act() that Thread requires.
class MyThread: public Thread {
public:
    virtual void act();     // overrides Thread::act()
};

void MyThread::act() {
    for (int i = 0; i < 10; i++) {
	lock(&cout);
	cout << "Hello World from thread " << hex << pthread_self()
	     << ", counter is " << i << "\n";
	unlock(&cout);
	sleep (1);
    }
}

/* Part 3: A tester. Function test1() demonstrates detached (non-joinable)
 * threads, function test2() demonstrates joinable ones, and main() runs both.
 */
void test1() {
    cout << "Starting two threads,\n"
	 << "and sleeping for 15 seconds to allow threads to finish.\n";
    
    // Instantiate two thread objects and fire them up.
    (new MyThread())->run(false);
    (new MyThread())->run(false);

    // Wait for the threads to finish.
    sleep (15);
}

void test2() {
    cout << "Starting two threads,\n"
	 << "and waiting for them to finish.\n";
    
    // Instantiate two thread objects.
    MyThread
	*a = new MyThread(),
	*b = new MyThread();

    // Start 'em up.
    pthread_t id_a = a->run();
    pthread_t id_b = b->run();

    // Wait for the threads to finish.
    pthread_join (id_a, 0);
    pthread_join (id_b, 0);
}

// Runs the detached-thread demo followed by the joinable-thread demo.
int main() {
    test1();
    test2();

    return 0;
}

