#include "graph.h"
#include "border.h"
#include "segment.h"
#include "order.h"

#include <getopt.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/times.h>
#include <unistd.h>

#include <algorithm>
#include <limits>
#include <map>
#include <sstream>

// Comparator over indices: an index sorts before another when its entry in
// `pr` is strictly larger. `pr` must be filled in before the comparator is used.
struct sortcmp {
	doublevector pr;	// value looked up for each compared index

	bool operator () (int lhs, int rhs)
	{
		return pr[rhs] < pr[lhs];
	}
};

// Fills `pr` via g.prwr(pr, alpha, seed) (the page rank computation) and then
// lets the graph derive its weights from `pr` with the scheme selected by `w`:
//   'n' -> set_weights_normalized
//   'm' -> set_weights_min
//   's' -> set_weights_sum
// Any other value of `w` leaves the graph weights untouched.
void
setpr(graph & g, uintvector & seed, double alpha, doublevector & pr, char w)
{
	g.prwr(pr, alpha, seed);

	if (w == 'n')
		g.set_weights_normalized(pr);
	else if (w == 'm')
		g.set_weights_min(pr);
	else if (w == 's')
		g.set_weights_sum(pr);
}

// Relabels the graph with the permutation returned by order(g, seed), refreshes
// the graph's "pre" data via update_pre(), and rewrites every entry of `seed`
// to its new label so it keeps referring to the same vertex.
void
permuteorder(graph & g, uintvector & seed)
{
	uintvector mapping = order(g, seed);

	g.permute(mapping);
	g.update_pre();

	for (uint32_t s = 0; s < seed.size(); s++) {
		const uint32_t old_id = seed[s];
		seed[s] = mapping[old_id];
	}
}

// Relabels the graph so that vertices appear in decreasing order of `pr`.
//
// Side effects: the `pr` entries of all seeds are overwritten with the largest
// finite double, which places the seeds at the front; `seed` is rewritten to
// the new labels; the graph is permuted and update_pre() is called.
void
permuterank(graph & g, uintvector & seed, doublevector & pr)
{
	typedef std::multimap<double, uint32_t> rankmap;

	// Seeds must outrank every other vertex.
	const double top = std::numeric_limits<double>::max();
	for (uint32_t s = 0; s < seed.size(); s++)
		pr[seed[s]] = top;

	// Index the vertices by score; the multimap keeps them sorted ascending.
	rankmap byscore;
	for (uint32_t v = 0; v < g.size(); v++)
		byscore.insert(rankmap::value_type(pr[v], v));

	// Walk from the highest score downwards and hand out new labels 0, 1, 2, ...
	// Equal scores are visited in reverse insertion order, as in a plain
	// reverse traversal of the multimap.
	uintvector invperm(g.size());
	uint32_t next = 0;
	rankmap::reverse_iterator it = byscore.rbegin();
	while (it != byscore.rend()) {
		invperm[it->second] = next;
		++next;
		++it;
	}

	// Apply the relabelling to the graph and to the seed list.
	g.permute(invperm);
	g.update_pre();
	for (uint32_t s = 0; s < seed.size(); s++) {
		const uint32_t old_id = seed[s];
		seed[s] = invperm[old_id];
	}
}

// Relabels the graph so that vertices appear in decreasing order of g.deg(),
// with the seeds forced to the front. `seed` is rewritten to the new labels,
// the graph is permuted and update_pre() is called.
void
permutedegree(graph & g, uintvector & seed)
{
	typedef std::multimap<double, uint32_t> degreemap;

	// Score of a vertex is its degree ...
	doublevector score(g.size());
	for (uint32_t v = 0; v < g.size(); v++)
		score[v] = g.deg(v);

	// ... except for seeds, which get the largest finite double.
	for (uint32_t s = 0; s < seed.size(); s++)
		score[seed[s]] = std::numeric_limits<double>::max();

	// Sort vertices by score (ascending inside the multimap).
	degreemap bydegree;
	for (uint32_t v = 0; v < g.size(); v++)
		bydegree.insert(degreemap::value_type(score[v], v));

	// Traverse from the largest score down, assigning new labels 0, 1, 2, ...
	uintvector invperm(g.size());
	uint32_t next = 0;
	degreemap::reverse_iterator it = bydegree.rbegin();
	while (it != bydegree.rend()) {
		invperm[it->second] = next;
		++next;
		++it;
	}

	// Apply the relabelling to the graph and to the seed list.
	g.permute(invperm);
	g.update_pre();
	for (uint32_t s = 0; s < seed.size(); s++) {
		const uint32_t old_id = seed[s];
		seed[s] = invperm[old_id];
	}
}

/*
void
permute(graph & g, uintvector & seed, doublevector & pr)
{
	// compute the permutation
	for (uint32_t i = 0; i < seed.size(); i++) pr[seed[i]] = std::numeric_limits<double>::max();
	sortcmp cmp;
	cmp.pr.swap(pr);

	uintvector perm(g.size());
	for (uint32_t i = 0; i < g.size(); i++) perm[i] = i;
	printf("Sort\n");
	std::make_heap(perm.begin(), perm.end(), cmp);
	std::sort_heap(perm.begin(), perm.end(), cmp);
	printf("Sort\n");
	uintvector invperm(g.size());
	for (uint32_t i = 0; i < g.size(); i++) invperm[perm[i]] = i;

	// apply permutation
	g.permute(invperm);
	g.update_pre();
	for (uint32_t i = 0; i < seed.size(); i++) seed[i] = invperm[seed[i]];

	//for (uint32_t i = 0; i < perm.size(); i++) printf("%d ", invperm[i]);
	//for (uint32_t i = 0; i < perm.size(); i++) printf("%.2f ", g.preweight(i));
	//printf("\n");
}
*/




// Collects, for every vertex position from `seed_cnt` up to g.size() - 1, the
// graph's preweight, the position itself and the graph's prevar, and passes the
// three arrays to border(), which fills `borders`.
//
// NOTE(review): seed_cnt is expected not to exceed g.size(); the subtraction
// below is unsigned - confirm callers guarantee this.
void
find_borders(const graph & g, uint32_t seed_cnt, blockvector & borders)
{
	doublevector w(g.size() - seed_cnt);
	doublevector v(w.size());
	uintvector sizes(w.size());

	for (uint32_t k = 0; k < w.size(); k++) {
		const uint32_t node = k + seed_cnt;	// skip the leading seed positions

		w[k] = g.preweight(node);
		sizes[k] = node;
		v[k] = g.prevar(node);
	}

	border(w, sizes, v, borders);
}

// Resizes `rank` to g.size() and sets rank[i] to
//   k + preweight(i) / max(predeg(i), 1.0)
// where k starts at boundaries.size() - 1 and is decremented each time
// position i equals g.size() - boundaries[k].
//
// NOTE(review): `boundaries` is indexed without a bounds check, so it must be
// non-empty and k must never be decremented past 0 - confirm against the
// producer of `boundaries`.
void
preweightrank(graph & g, const uintvector & boundaries, doublevector & rank)
{
	rank.resize(g.size());

	uint32_t seg = boundaries.size() - 1;
	for (uint32_t i = 0; i < g.size(); i++) {
		// Reaching the current boundary moves on to the next index.
		if (i == g.size() - boundaries[seg])
			seg--;

		// The divisor is clamped from below at 1.0.
		const double denom = std::max(g.predeg(i), 1.0);
		rank[i] = seg + g.preweight(i) / denom;
	}
}

// Writes the vertex labels to `out` in position order, each followed by a
// space. A newline is emitted before position i whenever i equals
// g.size() - boundaries[k], with k running from boundaries.size() - 1
// downwards; a final newline ends the output.
//
// NOTE(review): `boundaries` is indexed without a bounds check, so it must be
// non-empty and k must never be decremented past 0 - confirm against the
// producer of `boundaries`.
void
print_segmentation(const graph & g, const uintvector & boundaries, FILE *out)
{
	uint32_t seg = boundaries.size() - 1;

	for (uint32_t i = 0; i < g.size(); i++) {
		const bool at_boundary = (i == g.size() - boundaries[seg]);
		if (at_boundary) {
			fprintf(out, "\n");
			seg--;
		}
		fprintf(out, "%s ", g.label(i).c_str());
	}

	fprintf(out, "\n");
}



// Reads whitespace-separated unsigned integers from `str` until extraction
// fails and returns them with `minid` subtracted from each.
// `str` must not be null. The subtraction is unsigned, so an id below `minid`
// wraps around.
uintvector
parseseed(const char *str, uint32_t minid) {
	uintvector ids;
	std::stringstream in(str, std::stringstream::in);

	for (uint32_t raw; in >> raw; )
		ids.push_back(raw - minid);

	return ids;
}

// Writes one line of run statistics to `out`:
//   <name> <vertex count> <adjacency total / 2> <user CPU seconds> <score(0)>
//   <border count> <segment count> [score(i) / score(0) for i = 1 .. segcnt-1]
//
// out      destination stream
// s        segmentation whose scores and segment count are reported
// g        graph the segmentation was computed on
// name     label printed first (the caller passes the input file name)
// tmstart  times() sample taken before the work
// tmend    times() sample taken after the work
// b        borders handed to the segmentation; only its size is reported
//
// The adjacency total is the sum of nbhdsize() over all vertices, halved
// (presumably because each undirected edge is seen from both endpoints -
// confirm against graph). It is accumulated in 64 bits so large graphs do not
// wrap, and all counts are printed with unsigned conversions.
void
print_stats(FILE *out, const segment & s, const graph & g, const char *name, const tms & tmstart, const tms & tmend, const blockvector & b)
{
	unsigned long long adj = 0;
	for (uint32_t i = 0; i < g.size(); i++)
		adj += g.nbhdsize(i);

	// tms_utime is measured in clock ticks; _SC_CLK_TCK is ticks per second.
	const double secs = double(tmend.tms_utime - tmstart.tms_utime) / sysconf(_SC_CLK_TCK);

	// Cast explicitly: the return type of segcnt() is not visible here and the
	// value is passed through varargs.
	const uint32_t segs = uint32_t(s.segcnt());

	fprintf(out, "%s %u %llu %lf %lf %u %u", name, uint32_t(g.size()), adj / 2, secs, s.score(0), uint32_t(b.size()), segs);

	for (uint32_t i = 1; i < segs; i++)
		fprintf(out, " %f", s.score(i) / s.score(0));

	fprintf(out, "\n");
}



int
main(int argc, char **argv)
{
	static struct option longopts[] = {
		{"seed",            required_argument,  NULL, 's'},
		{"bound",           required_argument,  NULL, 'b'},
		{"out",             required_argument,  NULL, 'o'},
		{"in",              required_argument,  NULL, 'i'},
		{"k",               required_argument,  NULL, 'k'},
		{"help",            no_argument,        NULL, 'h'},
		{ NULL,             0,                  NULL,  0 }
	};

	char *inname = NULL;
	char *outname = NULL;
	uint32_t segcnt = 2;
	uintvector seed;
	float alpha = 0.1;
	float bound = 0;
	char *seedstr;
	bool weighted = false;
	bool maxseed = false;

	char rankmethod = 'o';
	char weightmethod = 'n';



	int ch;
	while ((ch = getopt_long(argc, argv, "ho:i:k:s:a:b:wdr:v:", longopts, NULL)) != -1) {
		switch (ch) {
			case 'h':
				printf("Usage: %s -i <input file> -o <output file> [-k segnum] [-dw] [-s seeds] [-v weight method] [-r rankmethod] [-a <alpha>] [options]\n", argv[0]);
				printf("  -h    print this help\n");
				printf("  -i    input file\n");
				printf("  -o    output file\n");
				printf("  -s    seeds (space separated)\n");
				printf("  -a    restart for pagerank\n");
				printf("  -w    graph contains weights\n");
				printf("  -d    pick seed to be the node with the largest degree\n");
				printf("  -v    weight method (s = sum, n = normalized (default), m = minimum)\n");
				printf("  -r    ranking method (o = order (default), r = pagerank, d = hop)\n");
				printf("  -k    number of segments (default: 2)\n");

				return 0;
				break;
			case 'd':
				maxseed = true;
				break;
			case 'w':
				weighted = true;
				break;
			case 'a':
				alpha = atof(optarg);
				break;
			case 'b':
				bound = atof(optarg);
				break;
			case 's':
				seedstr = optarg;
				break;
			case 'k':
				segcnt = atoi(optarg);
				break;
			case 'i':
				inname = optarg;
				break;
			case 'v':
				weightmethod = optarg[0];
				break;
			case 'r':
				rankmethod = optarg[0];
				break;
			case 'o':
				outname = optarg;
				break;
		}
	}

	if (inname == NULL) { 
		printf("Missing input file\n");
		return 1;
	}

	if (seedstr == NULL && !maxseed) { 
		printf("Missing seed\n");
		return 1;
	}

	tms tmstart, tmend;

	times(&tmstart);


	FILE *f = fopen(inname, "r");
	graph g;
	if (weighted)
		g.readweighted(f);
	else
		g.read(f);
	fclose(f);

	if (maxseed) {
		seed.resize(1);
		seed[0] = 0;
		for (uint32_t i = 0; i < g.size(); i++)
			if (g.nbhdsize(seed[0]) < g.nbhdsize(i))
				seed[0] = i;
	}
	else
		seed = parseseed(seedstr, g.minid());

	if (seed.size() == 0) { 
		printf("Invalid seed\n");
		return 1;
	}


	doublevector pr;
	setpr(g, seed, alpha, pr, weightmethod);
	fprintf(stderr, "Permuting\n");


	switch(rankmethod) {
		case 'r':
			permuterank(g, seed, pr);
			break;
		case 'o':
			permuteorder(g, seed);
			break;
		case 'd':
			permutedegree(g, seed);
			break;
	}

	uintvector boundaries;

	blockvector borders;
	find_borders(g, seed.size(), borders);

	fprintf(stderr, "Segmenting %d\n", uint32_t(borders.size()));

	segment seg(borders, segcnt, pow(1 + bound, 1.0 / segcnt) - 1);
	seg.run();
	seg.extract(boundaries);


	FILE *out = stdout;
	if (outname != NULL) out = fopen(outname, "w");

	print_segmentation(g, boundaries, out);

	if (outname != NULL) fclose(out);

	times(&tmend);

	print_stats(stdout, seg, g, inname, tmstart, tmend, borders);

	return 0;
}

