#include "order.h"
#include <algorithm>

// Reorders res in place so that itemsets appear in descending order of
// support(). The sort key is a std::pair, so equal supports are ordered by
// comparing the itemset pointers.
void
order_support(itemsetvector & res)
{
	typedef std::vector<std::pair<uint32_t, itemset *> > keyvec;

	keyvec keyed(res.size());
	for (uint32_t k = 0; k < res.size(); ++k)
		keyed[k] = std::make_pair(res[k]->support(), res[k]);

	std::sort(keyed.begin(), keyed.end());

	// keyed is ascending; walk it backwards to fill res largest-first.
	keyvec::reverse_iterator from = keyed.rbegin();
	for (uint32_t k = 0; k < res.size(); ++k, ++from)
		res[k] = from->second;
}


// Reorders res in place so that itemsets appear in descending order of
// free(). Equal scores fall back to std::pair's comparison of the pointers.
void
order_robust_free(itemsetvector & res)
{
	typedef std::pair<long double, itemset *> scored;

	std::vector<scored> keyed;
	keyed.reserve(res.size());
	for (uint32_t k = 0; k < res.size(); ++k)
		keyed.push_back(scored(res[k]->free(), res[k]));

	// Ascending sort, then flip so index 0 holds the largest key.
	std::sort(keyed.begin(), keyed.end());
	std::reverse(keyed.begin(), keyed.end());

	for (uint32_t k = 0; k < res.size(); ++k)
		res[k] = keyed[k].second;
}

// Reorders res in place so that itemsets appear in descending order of
// ts(). Equal scores fall back to std::pair's comparison of the pointers.
void
order_robust_ts(itemsetvector & res)
{
	typedef std::pair<long double, itemset *> scored;
	typedef std::vector<scored> scoredvec;

	scoredvec keyed;
	keyed.reserve(res.size());
	for (uint32_t k = 0; k < res.size(); ++k)
		keyed.push_back(scored(res[k]->ts(), res[k]));

	std::sort(keyed.begin(), keyed.end());

	// Copy back from the largest key downwards.
	uint32_t k = 0;
	for (scoredvec::reverse_iterator it = keyed.rbegin(); it != keyed.rend(); ++it)
		res[k++] = it->second;
}

// Reorders res in place so that itemsets appear in descending order of
// ndi(). Equal scores fall back to std::pair's comparison of the pointers.
void
order_robust_ndi(itemsetvector & res)
{
	std::vector<std::pair<long double, itemset *> > keyed(res.size());
	for (uint32_t k = 0; k < keyed.size(); ++k) {
		keyed[k].first = res[k]->ndi();
		keyed[k].second = res[k];
	}

	std::sort(keyed.begin(), keyed.end());

	// keyed is ascending: the last element goes to res[0], and so on.
	for (uint32_t k = 0; k < res.size(); ++k)
		res[k] = keyed[keyed.size() - 1 - k].second;
}

typedef std::pair<uintvector, itemset *> litemset;

struct lexicograph {
	bool
	operator () (const litemset & a, const litemset & b)
	{
		for (uint32_t i = 0; i < a.first.size() && i < b.first.size(); i++)
			if (a.first[i] < b.first[i])
				return true;
			else if (a.first[i] > b.first[i])
				return false;
		return a.first.size() > b.first.size();
	}
};


// Reorders res in place, descending under lexicograph, by a per-itemset key
// built from its contingency table. For each power of two `bit` below the
// table size, the key takes cont[cont.size() - 1 - bit]; the key values are
// then sorted ascending. Assuming the table has 2^k cells indexed by item
// bitmask, these are the cells with every item present except one
// (NOTE(review): confirm the indexing convention of contingency()).
void
order_mv_free(itemsetvector & res)
{
	std::vector<litemset> ord(res.size());
	for (uint32_t i = 0; i < res.size(); i++) {
		const uintvector cont = res[i]->contingency();
		uintvector & key = ord[i].first;
		key.reserve(res[i]->items().size());

		for (uint32_t bit = 1; bit < cont.size(); bit <<= 1)
			key.push_back(cont[cont.size() - 1 - bit]);
		std::sort(key.begin(), key.end());

		ord[i].second = res[i];
	}

	std::sort(ord.begin(), ord.end(), lexicograph());

	// ord is ascending; emit it back-to-front.
	std::vector<litemset>::reverse_iterator src = ord.rbegin();
	for (uint32_t i = 0; i < res.size(); i++, ++src)
		res[i] = src->second;
}

// Reorders res in place, descending under lexicograph, using each itemset's
// full contingency table (with its cell values sorted ascending) as the key.
void order_mv_ts(itemsetvector & res)
{
	std::vector<litemset> ord;
	ord.reserve(res.size());
	for (uint32_t i = 0; i < res.size(); i++) {
		ord.push_back(litemset(res[i]->contingency(), res[i]));
		uintvector & key = ord.back().first;
		std::sort(key.begin(), key.end());
	}

	std::sort(ord.begin(), ord.end(), lexicograph());

	// ord is ascending; emit it back-to-front.
	std::vector<litemset>::reverse_iterator src = ord.rbegin();
	for (uint32_t i = 0; i < res.size(); i++, ++src)
		res[i] = src->second;
}

// Sparse integer vector as (position, value) entries. The operators below
// assume entries are sorted by strictly increasing position, and they produce
// output in that same order. expandndi() uses it as a polynomial
// (position = exponent, value = coefficient).
typedef std::vector<std::pair<uint32_t, int32_t> > sparsevector;

// Entry-wise sum. A position present in only one operand is copied through.
// Positions present in both are added, and dropped when the values cancel.
sparsevector
operator + (const sparsevector & a, const sparsevector & b)
{
	sparsevector sum;
	sum.reserve(a.size() + b.size());

	sparsevector::const_iterator x = a.begin();
	sparsevector::const_iterator y = b.begin();
	while (x != a.end() && y != b.end()) {
		if (y->first < x->first) {
			sum.push_back(*y);
			++y;
		} else if (x->first < y->first) {
			sum.push_back(*x);
			++x;
		} else {
			if (x->second != -y->second)
				sum.push_back(std::make_pair(x->first, x->second + y->second));
			++x;
			++y;
		}
	}
	// At most one of these tails is non-empty.
	sum.insert(sum.end(), x, a.end());
	sum.insert(sum.end(), y, b.end());
	return sum;
}

// Entry-wise difference a - b. Positions present only in b are copied with
// their values negated. Positions present in both are subtracted, and dropped
// when the values are equal.
sparsevector
operator - (const sparsevector & a, const sparsevector & b)
{
	sparsevector diff;
	diff.reserve(a.size() + b.size());

	sparsevector::const_iterator x = a.begin();
	sparsevector::const_iterator y = b.begin();
	while (x != a.end() || y != b.end()) {
		if (y == b.end() || (x != a.end() && x->first < y->first)) {
			diff.push_back(*x);
			++x;
		} else if (x == a.end() || y->first < x->first) {
			diff.push_back(std::make_pair(y->first, -y->second));
			++y;
		} else {
			if (x->second != y->second)
				diff.push_back(std::make_pair(x->first, x->second - y->second));
			++x;
			++y;
		}
	}
	return diff;
}

// Adds s to every position (for a polynomial: multiplication by x^s).
// Positions use unsigned arithmetic and wrap on overflow.
sparsevector
operator << (const sparsevector & a, uint32_t s)
{
	sparsevector shifted;
	shifted.reserve(a.size());
	for (sparsevector::const_iterator it = a.begin(); it != a.end(); ++it)
		shifted.push_back(std::make_pair(it->first + s, it->second));
	return shifted;
}

void
print(const sparsevector & p)
{
	for (uint32_t j = 0; j < p.size(); j++)
		printf("(%i %u) ", p[j].second, p[j].first);
	printf("\n");
}


// Number of set bits in i, using branch-free SWAR (SIMD-within-a-register)
// population count.
static uint32_t
bitcount(uint32_t i)
{
	const uint32_t pairs   = 0x55555555u;  // 01 repeated
	const uint32_t nibbles = 0x33333333u;  // 0011 repeated
	const uint32_t bytes   = 0x0F0F0F0Fu;  // 00001111 repeated
	const uint32_t ones    = 0x01010101u;  // one 1 per byte

	uint32_t v = i;
	v -= (v >> 1) & pairs;                  // count per 2-bit field
	v = (v & nibbles) + ((v >> 2) & nibbles); // count per 4-bit field
	v = (v + (v >> 4)) & bytes;             // count per byte
	return (v * ones) >> 24;                // sum of bytes lands in top byte
}


// Builds three polynomial products over the contingency cells:
//   odd  = prod over cells whose index has an odd number of set bits,
//   even = prod over cells whose index has an even number of set bits,
//   all  = prod over every cell,
// each factor being (1 - x^cont[cell]). Since p << n multiplies by x^n,
// p - (p << n) is p * (1 - x^n). Returns odd + even - all.
// The result is heap-allocated; the caller owns it and must delete it.
sparsevector *
expandndi(const uintvector & cont)
{
	// Every product starts at the constant polynomial 1.
	sparsevector odd(1, std::make_pair(uint32_t(0), int32_t(1)));
	sparsevector even(odd);
	sparsevector all(odd);

	for (uint32_t cell = 0; cell < cont.size(); cell++) {
		if (bitcount(cell) % 2 == 1)
			odd = odd - (odd << cont[cell]);
		else
			even = even - (even << cont[cell]);
		all = all - (all << cont[cell]);
	}

	return new sparsevector(odd + even - all);
}

// Sort record: a (heap-allocated) sparse key paired with its itemset.
typedef std::pair<sparsevector *, itemset *> sitemset;

// Strict weak ordering on the sparse keys (the itemset pointer is ignored).
// Each sparse vector is treated as a dense sequence of values indexed by
// ascending position, with absent positions counting as 0. The sequences are
// then compared lexicographically.
// Assumes: entries sorted by increasing position and no zero values.
struct splexicograph {
	// const so the comparator is usable wherever a const-invocable Compare
	// is required.
	bool
	operator () (const sitemset & a, const sitemset & b) const
	{
		const sparsevector & p = *a.first;
		const sparsevector & q = *b.first;
		uint32_t i = 0;

		for (; i < p.size() && i < q.size(); i++) {
			// p has a non-zero where q has 0: p is smaller iff it is negative.
			if (p[i].first < q[i].first)
				return p[i].second < 0;
			// q has a non-zero where p has 0: p is smaller iff q's is positive.
			if (p[i].first > q[i].first)
				return q[i].second > 0;
			// Same position: compare the values directly.
			if (p[i].second != q[i].second)
				return p[i].second < q[i].second;
		}

		// One vector has extra trailing entries; the other is 0 there.
		if (i < p.size()) return p[i].second < 0;
		if (i < q.size()) return q[i].second > 0;

		return false;
	}
};

// Reorders res in place, descending under splexicograph, keyed by
// expandndi() of each itemset's contingency table. The keys are
// heap-allocated by expandndi(). They are always released before returning,
// including when an exception (e.g. std::bad_alloc from a large expansion)
// escapes while keys are being built or sorted. The original code leaked
// every key allocated so far in that case.
void order_mv_ndi(itemsetvector & res)
{
	// Value-initialised pairs: every .first starts as a null pointer, so the
	// cleanup below is safe even if only some keys were built.
	std::vector<sitemset> ord(res.size());
	try {
		for (uint32_t i = 0; i < res.size(); i++) {
			ord[i].second = res[i];
			ord[i].first = expandndi(res[i]->contingency());
		}
		std::sort(ord.begin(), ord.end(), splexicograph());
	} catch (...) {
		for (uint32_t i = 0; i < ord.size(); i++)
			delete ord[i].first;
		throw;
	}

	// ord is ascending; copy it back reversed so res ends up descending.
	for (uint32_t i = 0; i < res.size(); i++)
		res[i] = ord[res.size() - i - 1].second;
	for (uint32_t i = 0; i < ord.size(); i++)
		delete ord[i].first;
}
