#include <pthread.h>
#include <iostream>
#include <sstream>

/**
 * Minimal wrapper around a POSIX mutex.
 *
 * The underlying pthread mutex is created (with default attributes) when the
 * object is constructed and destroyed when the object is destroyed.  Creating
 * it in the constructor, rather than on the first lock(), means that threads
 * calling lock() concurrently never race to initialize the same mutex.
 *
 * The constructor itself never throws: a failed initialization is remembered
 * and reported by lock()/unlock() as a thrown `char const *` message.
 */
class Mutex {
public:
    Mutex(): _initialized(pthread_mutex_init(&_mutex, 0) == 0) { }
    ~Mutex() {
        // Only destroy what was successfully created.
        if (_initialized)
            pthread_mutex_destroy(&_mutex);
    }

    void lock();
    void unlock();

private:
    // A pthread mutex must not be copied; declared but intentionally not defined.
    Mutex(const Mutex &);
    Mutex &operator=(const Mutex &);

    pthread_mutex_t _mutex;
    bool _initialized;      // true when pthread_mutex_init() succeeded
};

/**
 * Acquires the mutex, blocking until it is available.
 *
 * Throws "Failed to initialize mutex" if the mutex could not be created, or
 * "Failed to lock mutex" if pthread_mutex_lock() reports an error.
 */
void Mutex::lock() {
    if (!_initialized)
        throw "Failed to initialize mutex";
    if (pthread_mutex_lock(&_mutex))
        throw "Failed to lock mutex";
}

/**
 * Releases the mutex.
 *
 * Throws "Failed to unlock mutex" if the mutex was never successfully created
 * or if pthread_mutex_unlock() reports an error.
 */
void Mutex::unlock() {
    if (!_initialized || pthread_mutex_unlock(&_mutex))
        throw "Failed to unlock mutex";
}

/**
 * One node of the binary tree that maps an object address to its own mutex.
 *
 * A node stores the object pointer it represents, a mutex guarding that
 * object, and pointers to its two child nodes.  The node owns its children:
 * the destructor deletes both subtrees.
 */
class MutexNode {
public:
    MutexNode(void *o);
    ~MutexNode();

    // Object address this node stands for.
    void obj(void *o)                   { _obj = o; }
    void *obj() const                   { return _obj; }

    // Left child (owned by this node).
    void left(MutexNode *l)             { _left = l; }
    MutexNode *left() const             { return _left; }

    // Right child (owned by this node).
    void right(MutexNode *l)            { _right = l; }
    MutexNode *right() const            { return _right; }

    // Acquire / release the mutex associated with obj().
    void lock()                         { _mutex.lock(); }
    void unlock()                       { _mutex.unlock(); }

private:
    // Copying would make two nodes delete the same children, so it is
    // disabled: declared but intentionally not defined.
    MutexNode(const MutexNode &);
    MutexNode &operator=(const MutexNode &);

    Mutex _mutex;
    void *_obj;
    MutexNode *_left, *_right;
};

// Creates a leaf node for object `o`: no children, mutex not held.
MutexNode::MutexNode(void *o)
    : _mutex(),
      _obj(o),
      _left(0),
      _right(0)
{
}

// Destroys this node together with both of its subtrees.
MutexNode::~MutexNode() {
    // Each delete runs the child's own destructor, which in turn releases
    // that child's subtrees; deleting a null child is a no-op.
    MutexNode *const children[2] = { _left, _right };
    for (unsigned i = 0; i < 2; i++)
        delete children[i];
}


/**
 * Registry of per-object mutexes, keyed by object address.
 *
 * lock(o) / unlock(o) acquire and release a mutex dedicated to the address
 * `o`.  Nodes are kept in a binary tree rooted at _root; the tree owns its
 * nodes and deletes them in the destructor.  _treelock is the mutex used by
 * locktree()/unlocktree() to guard changes to the tree structure.
 */
class MutexTree {
public:
    MutexTree();
    ~MutexTree();

    void lock(void *o);
    void unlock(void *o);

private:
    // Copying would make two trees delete the same nodes, so it is
    // disabled: declared but intentionally not defined.
    MutexTree(const MutexTree &);
    MutexTree &operator=(const MutexTree &);

    void locktree()             { _treelock.lock(); }
    void unlocktree()           { _treelock.unlock(); }

    // Recursive helpers working on the subtree rooted at `start`.
    MutexNode *nodelock(void *o, MutexNode *start);
    void nodeunlock(void *o, MutexNode *start);

    MutexNode *_root;
    Mutex _treelock;
};

// Creates an empty registry: no nodes yet.
MutexTree::MutexTree()
    : _root(0),
      _treelock()
{
}

// Releases every node in the tree.
MutexTree::~MutexTree() {
    // Deleting the root cascades through MutexNode's destructor, which
    // deletes both children of each node; an empty tree (null root) is fine.
    MutexNode *const top = _root;
    delete top;
}

void MutexTree::lock(void *o) {
    locktree();
    _root = nodelock(o, _root);
    unlocktree();
}

void MutexTree::unlock(void *o) {
    nodeunlock(o, _root);
}

/**
 * Locks the node for `o` within the subtree rooted at `start`, creating the
 * node if the subtree does not contain one.
 *
 * Ordering: when a node's key compares less than `o`, the search continues
 * in that node's left subtree, otherwise in its right subtree.
 *
 * @param o      object address whose mutex is to be acquired
 * @param start  root of the subtree to search; may be null
 * @return       the root of the subtree after the operation: `start` itself,
 *               or the newly created node when `start` was null
 *
 * A newly created node whose mutex cannot be locked is deleted before the
 * exception propagates, since it has not been linked into the tree and
 * would otherwise leak.
 */
MutexNode *MutexTree::nodelock(void *o, MutexNode *start) {
    if (!start) {
        MutexNode *fresh = new MutexNode(o);
        try {
            fresh->lock();
        } catch (...) {
            delete fresh;
            throw;
        }
        return fresh;
    }

    if (start->obj() == o)
        start->lock();
    else if (start->obj() < o)
        start->left(nodelock(o, start->left()));
    else
        start->right(nodelock(o, start->right()));

    return start;
}

void MutexTree::nodeunlock(void *o, MutexNode *start) {
    if (!start)
	return;
    
    if (start->obj() == o)
	start->unlock();
    else if (start->obj() < o)
	nodeunlock(o, start->left());
    else
	nodeunlock(o, start->right());
}

// File-local registry shared by mutex_lock() and mutex_unlock().
namespace {
MutexTree mt;
}

void mutex_lock(void *obj) {
    mt.lock(obj);
}

void mutex_unlock(void *obj) {
    mt.unlock(obj);
}


/**
 * Thread entry point for the demo: ten iterations, writing to std::cout on
 * odd iterations and to std::cerr on even ones, each write bracketed by
 * mutex_lock()/mutex_unlock() on the address of the stream being used.
 *
 * @param data  unused
 * @return      null on success; a non-null pointer to an error message if
 *              an exception was raised
 *
 * Exceptions are caught here because an exception leaving a thread's entry
 * function terminates the whole process; a try/catch in the thread that
 * called pthread_create() cannot intercept it.
 */
void *test(void *data) {
    (void) data;

    try {
        for (int i = 1; i <= 10; i++) {
            const bool odd = (i & 1) != 0;
            std::ostream &out = odd ? std::cout : std::cerr;
            const char *label = odd ? "cout" : "cerr";

            mutex_lock(&out);
            out << pthread_self() << " testing " << label << ", loop " << i << '\n';
            mutex_unlock(&out);
        }
    } catch (char const *s) {
        std::cerr << s << "\n";
        // Hand the message to whoever joins this thread.
        return const_cast<char *>(s);
    } catch (...) {
        static char const unexpected[] = "Unexpected exception in test thread";
        std::cerr << unexpected << "\n";
        return const_cast<char *>(unexpected);
    }
    return 0;
}

/**
 * Starts ten threads running test(), waits for all of them, and prints
 * "Done!" when everything succeeded.
 *
 * Only threads that were actually created are joined: if pthread_create()
 * fails, no further threads are started and the handle it was given is not
 * used.  A failed create, a failed join, or a non-null thread result makes
 * the program print an error message and exit with status 1; this happens
 * only after every started thread has been joined.
 *
 * @return 0 on success, 1 on failure
 */
int main() {
    try {
        pthread_t threads[10];
        const unsigned wanted = sizeof(threads) / sizeof(threads[0]);
        const char *failure = 0;

        unsigned started = 0;
        while (started < wanted) {
            if (pthread_create(&threads[started], 0, test, 0)) {
                failure = "Failed to create thread";
                break;
            }
            started++;
        }

        for (unsigned i = 0; i < started; i++) {
            void *result = 0;
            if (pthread_join(threads[i], &result)) {
                if (!failure)
                    failure = "Failed to join thread";
            } else if (result && !failure) {
                failure = "Test thread reported a failure";
            }
        }

        if (failure)
            throw failure;

        std::cout << "Done!\n";
        return 0;
    } catch (char const *s) {
        std::cerr << s << "\n";
        return 1;
    }
}

