#include "tile.h"

#include <assert.h>
#include <stdio.h>

#include <algorithm>
#include <limits>
#include <sstream>

#include "mathutil.h"

void
tile::print(FILE *f)
{
	int32_t pid = -1;
	if (parent) pid = parent->id;
	fprintf(f, "%d %d %d %d %d %d %d %d %f\n", id, tstart, tend, istart, iend, ones, size, pid, gain);
}

void
tileset::print(FILE *f)
{
	for (tilelist::iterator it = m_tiles.begin(); it != m_tiles.end(); ++it)
		(*it)->print(f);
}

// Builds a tile set over dataset d. The set starts with a single root tile
// (id 0, no parent) covering rows 0..d.cnt() and columns 0..d.dim(), holding
// all of the dataset's ones. When dis is true, later tiles are searched only
// among cells still owned by a single tile (see the width-based optimize()).
tileset::tileset(const dataset & d, bool dis)
	: m_data(d), m_occupy(d.cnt()*d.dim()), m_scanner(d.cnt()), m_disjoint(dis)
{
	tile *root = new tile;

	// Identity: the root has id 0 and no parent.
	root->parent = 0;
	root->id = 0;
	root->gain = 0;

	// Extent: the whole dataset.
	root->tstart = 0;
	root->tend = d.cnt();
	root->istart = 0;
	root->iend = d.dim();

	// Content: every cell and every one of the dataset.
	root->size = d.cnt() * d.dim();
	root->ones = d.ones();

	m_tiles.push_back(root);
}

// Releases every tile held by the set: the root created by the constructor
// and every tile handed over through add().
// NOTE(review): the set owns raw tile pointers, so copying a tileset would
// double-delete them unless copying is disabled in tile.h — confirm.
tileset::~tileset()
{
	tilelist::iterator cursor = m_tiles.begin();
	while (cursor != m_tiles.end()) {
		delete *cursor;
		++cursor;
	}
}

// Inserts tile t as a child of p and takes ownership of t.
// t receives the next free id (its index in m_tiles). Its ones and size are
// subtracted from p's remaining totals. Every cell inside t's rectangle that
// is still marked as belonging to p is re-labelled as belonging to t.
void
tileset::add(tile *t, tile *p)
{
	// Link into the hierarchy.
	t->parent = p;
	t->id = m_tiles.size();

	// The child's cells are carved out of what remains of the parent.
	p->ones -= t->ones;
	p->size -= t->size;

	m_tiles.push_back(t);

	for (uint32_t col = t->istart; col < t->iend; col++) {
		for (uint32_t row = t->tstart; row < t->tend; row++) {
			// Cells already claimed by an earlier child of p keep their owner.
			if (occupy(row, col) == p->id)
				occupy(row, col) = t->id;
		}
	}
}

// Runs the scanner over rows [start, end) in both orientations and returns
// the lower-scoring result. The second pass swaps p and n and uses
// basesize - baseones as the base count, i.e. it treats zeros as "ones".
// Its ones field is converted back to a count of real ones before comparing.
// On a tie the first (non-swapped) result is returned.
scanner::result
tileset::optimize(const uintvector & p, const uintvector & n, uint32_t start, uint32_t end, uint32_t baseones, uint32_t basesize) const
{
	const scanner::result direct = m_scanner.optimize(p, n, start, end, baseones, basesize);

	scanner::result swapped = m_scanner.optimize(n, p, start, end, basesize - baseones, basesize);
	swapped.ones = swapped.size - swapped.ones;

	if (direct.score < swapped.score)
		return direct;
	return swapped;
}


// Finds the best row range for a candidate spanning `width` columns.
//
// p[i] and n[i] are per-row counts. In find() they count the cells of row i
// (over the current column range) that still belong to the tile being split;
// p counts ones and n counts zeros.
//
// When tiles may overlap (!m_disjoint), every row is eligible and the whole
// range [0, p.size()) is scanned at once.
//
// In disjoint mode, only rows with p[i] + n[i] == width are eligible, since
// there the whole row segment is still available. Each maximal run of such
// rows is scanned separately, and the best result over all runs is returned.
// If no eligible row exists, the result has score
// std::numeric_limits<double>::max() and zeroed positions/counts.
scanner::result
tileset::optimize(const uintvector & p, const uintvector & n, uint32_t width, uint32_t baseones, uint32_t basesize) const
{
	if (!m_disjoint)
		return optimize(p, n, 0, p.size(), baseones, basesize);

	scanner::result best = {0, 0, 0, 0, std::numeric_limits<double>::max()};

	const uint32_t end = p.size();
	uint32_t start = 0;
	while (true) {
		// Skip every row that is not fully available. Using != (rather than <)
		// guarantees the run scan below advances at least one row, so the
		// outer loop always makes progress even if a row ever had
		// p + n > width.
		while (start < end && p[start] + n[start] != width) start++;

		if (start == end) break;

		// Extend over the maximal run of fully available rows.
		uint32_t stop = start;
		while (stop < end && p[stop] + n[stop] == width) stop++;

		scanner::result cand = optimize(p, n, start, stop, baseones, basesize);
		if (cand.score < best.score) best = cand;

		start = stop;
	}

	return best;
}

// Searches cur for the best child tile. It runs a branch & bound over column
// ranges [a, b], indexed relative to cur->istart. For each column range, the
// row range is chosen by the width-based optimize().
//
// Only cells with occupy(...) == cur->id are counted, so cells already
// claimed by an earlier child of cur are excluded.
//
// Returns a newly allocated tile owned by the caller; optimize(tile *) either
// passes it to add() or deletes it. If no candidate scores below
// score(cur->ones, cur->size) - tileoverhead(cur), the returned tile has
// size == 0 and ones == 0, and all its other fields are zeroed.
//
// Progress is printed to stdout as it searches.
tile *
tileset::find(const tile *cur) const
{
	uint32_t K = cur->iend - cur->istart;

	// ones[a]: first data entry of column cur->istart + a whose tid is
	// >= cur->tstart (filled in below).
	std::vector<const dataset::entry *> ones(K);

	printf("%sSearching for tile in (%d - %d, %d - %d)\n", m_indent.c_str(), cur->tstart, cur->tend - 1, cur->istart, cur->iend - 1);

	// Initialize every field so the caller never reads indeterminate values
	// when no improving candidate is found.
	tile *t = new tile;
	t->tstart = 0;
	t->tend = 0;
	t->istart = 0;
	t->iend = 0;
	t->size = 0;
	t->ones = 0;
	t->parent = 0;
	t->id = 0;
	t->gain = 0;
	double best = score(cur->ones, cur->size) - tileoverhead(cur);

	// Position each column's entry cursor at the first entry inside cur.
	// Assumes each column's entry list is ordered by tid — TODO confirm in dataset.
	for (uint32_t i = 0; i < K; i++) {
		const dataset::entry * & e = ones[i];
		for (e = m_data.first(i + cur->istart); e && e->tid() < cur->tstart; e = e->next());
	}

	// cones[a] / csize[a]: number of ones / cells of column a still owned by
	// cur. They are converted into prefix sums over columns 0..a below.
	uintvector cones(K);
	uintvector csize(K);

	for (uint32_t a = 0; a < K; a++) {
		const dataset::entry *e = ones[a];
		for (uint32_t i = 0; i < cur->tend - cur->tstart; i++) {
			if (occupy(cur->tstart + i, cur->istart + a) == cur->id) {
				while (e && e->tid() < cur->tstart + i) e = e->next();
				if (e && e->tid() == cur->tstart + i) {
					cones[a]++;
					e = e->next();
				}
				csize[a]++;
			}
		}
	}

	for (uint32_t a = 1; a < K; a++) {
		cones[a] += cones[a - 1];
		csize[a] += csize[a - 1];
	}

	// bounds[a][k]: pruning value for column ranges starting at a and ending
	// at a + k or later. A suffix minimum is applied after each a is done.
	// NOTE: these values are intended to be optimistic (never above the true
	// score of such a range); pruning relies on that, but it is not checked here.
	std::vector<doublevector> bounds(K);

	// Process start columns right to left, so bounds[i + 1] is already
	// available for every i >= a.
	for (int32_t a = K - 1; a >= 0; a--) {

		// Score of the columns left of a, which are not part of the candidate.
		uint32_t lones = a > 0 ? cones[a - 1] : 0;
		uint32_t lsize = a > 0 ? csize[a - 1] : 0;
		double lcost = score(lones, lsize);

		// p[i] / n[i]: ones / zeros owned by cur in row i over columns a..b.
		// They accumulate as b grows.
		uintvector p(cur->tend - cur->tstart);
		uintvector n(cur->tend - cur->tstart);

		doublevector & bound = bounds[a];
		// estimate[k]: score for columns a..a+k, measured against only those
		// columns' cells.
		doublevector estimate(K);

		for (uint32_t b = a; b < K; b++) {
			// Prune: split [a, b] at each i into [a, i] (estimate) and a range
			// starting at i + 1 that reaches column b (bounds[i + 1]). If that
			// plus the left-hand cost already exceeds best, stop extending b.
			bool bounded = false;
			for (uint32_t i = a; i < b && i < K - 1; i++) {
				double cost = b - i < bounds[i + 1].size() ? bounds[i + 1][b - i - 1] : bounds[i + 1].back();

				if (best < cost + estimate[i - a] + lcost) {
					bounded = true;
					bound.push_back(estimate[i - a] + cost);
					break;
				}
			}
			if (bounded) break;

			// Add column b to the per-row counts p and n.
			const dataset::entry *e = ones[b];

			for (uint32_t i = 0; i < p.size(); i++) {
				if (occupy(cur->tstart + i, cur->istart + b) == cur->id) {
					while (e && e->tid() < cur->tstart + i) e = e->next();
					if (e && e->tid() == cur->tstart + i) {
						p[i]++;
						e = e->next();
					}
					else
						n[i]++;
				}
			}

			// Best candidate for columns [a, b], measured against all of cur.
			scanner::result cand = optimize(p, n, b - a + 1, cur->ones, cur->size);
			if (cand.score < best) {
				t->ones = cand.ones;
				t->size = cand.size;
				t->tstart = cand.start + cur->tstart;
				t->tend = cand.end + cur->tstart;
				t->istart = a + cur->istart;
				t->iend = b + cur->istart + 1;
				best = cand.score;
			}

			// Branch & bound value for (a, b). The candidate is restricted to
			// the cells of columns a..b, plus the cost of the columns right of b.
			uint32_t rones = cones.back() - cones[b];
			uint32_t rsize = csize.back() - csize[b];

			uint32_t pones = cones[b] - lones;
			uint32_t psize = csize[b] - lsize;

			cand = optimize(p, n, b - a + 1, pones, psize);
			estimate[b - a] = cand.score;
			bound.push_back(cand.score + score(rones, rsize));
		}

		// Suffix minimum: bound[k] becomes the minimum over all ends >= a + k.
		// bound is never empty here, because b == a always reaches push_back.
		for (uint32_t i = bound.size() - 1; i > 0; i--)
			bound[i - 1] = std::min(bound[i - 1], bound[i]);

		// Progress line, overwritten in place via '\r'. The size_t expression
		// is cast so it matches %d.
		if (t->size > 0)
			printf("%s%d %d %f   \r", m_indent.c_str(), a, static_cast<int>(a + bound.size()), best);
		fflush(stdout);
	}

	return t;
}

// Greedily refines tile t. It repeatedly asks find() for the best child tile
// and adds it under t for as long as tilecost() beats t's current score. It
// then recurses into each added child.
//
// Returns the summed cost reduction for t and its subtree. Progress is
// printed to stdout, indented by recursion depth.
double
tileset::optimize(tile *t)
{
	m_indent += "    ";
	double gain = 0;

	for (;;) {
		tile *cand = find(t);
		const double base = score(t->ones, t->size);
		const double cost = tilecost(cand, t);

		// No improving child left: discard the candidate and stop.
		if (!(cost < base)) {
			delete cand;
			break;
		}

		printf("%sAdding tile (%d - %d, %d - %d) with cost %f (base %f) %s \n", m_indent.c_str(),
			cand->tstart, cand->tend - 1, cand->istart, cand->iend - 1, cost, base,
			cand->freq() > t->freq() ? "pos" : "neg");

		const double saved = base - cost;
		gain += saved;
		cand->gain = saved;
		add(cand, t);            // the tileset now owns cand
		gain += optimize(cand);
	}

	// Undo the indentation added on entry.
	m_indent.erase(0, 4);

	return gain;
}


// Cost of placing tile t inside parent p. It is the sum of:
//  - score(t->ones, t->size, p->ones, p->size), presumably the data cost of
//    t together with the rest of p (TODO confirm against score's definition);
//  - tileoverhead(p), the fixed cost of describing t's position inside p.
double
tileset::tilecost(const tile *t, const tile *p) const
{
	const double placement = tileoverhead(p);
	return placement + score(t->ones, t->size, p->ones, p->size);
}

// Fixed cost of describing a child tile's rectangle inside tile p:
//   2 + 5 * (lg2(p column count) + lg2(p row count))
// It depends only on p's extent, not on the child.
double
tileset::tileoverhead(const tile *p) const
{
	return 2 + 5*(lg2(p->iend - p->istart) + lg2(p->tend - p->tstart));
}
