// Copyright 2004,2006 Jouni K. Seppnen          -*- coding: iso-8859-1 -*-
// Distributed under the Boost Software License, Version 1.0.
// See accompanying file LICENSE.

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>

#pragma implementation "model.h"
#include "model.h"
#include "matrices.h"

namespace Tilings {

  Model::Model(const Matrices::Matrix<double>& input)
    : data(input)
  {
    // Suffix-sum table: sums[i][j] is the total of data over rows i.. and
    // columns j.. (the lower-right submatrix anchored at (i,j)).  The table
    // has one extra all-zero row and column so that the recurrence below
    // needs no boundary cases; that padding is also useful in ones_within().
    // (Equivalently sums = triu_ones(dim1) * data * tril_ones(dim2), but the
    // explicit loop is faster and yields the padding for free.)
    sums = Matrices::Matrix<double>(data.dim1()+1, data.dim2()+1);
    sums = 0.0;
    const int last_row = data.dim1() - 1;
    const int last_col = data.dim2() - 1;
    for (int col = last_col; col >= 0; --col)
      for (int row = last_row; row >= 0; --row)
        sums[row][col] = data[row][col]
          + sums[row+1][col] + sums[row][col+1]   // inclusion-exclusion
          - sums[row+1][col+1];                   // of two submatrices

    // The root tile spans the whole matrix; its count is the grand total.
    root = new Node(Tile(0, data.dim1()-1, 0, data.dim2()-1),
                    static_cast<long int>(sums[0][0]));
    tiles.push_back(static_cast<Tile*>(root));
  }

  // Copy constructor: deep-copies the node tree (Node's copy constructor
  // clones every subtree) and rebuilds `tiles` so that it points into the
  // new tree rather than into other's.  No self-check is needed: an object
  // under construction cannot already be `other`.
  Model::Model(const Model& other)
    : data(other.data), sums(other.sums), root(new Node(*other.root)),
      tiles()
  {
    root->into_tiles(tiles);
  }

  // Copy assignment: replaces this model with a deep copy of `other`.
  // The tree previously owned by *this is released (the earlier version
  // overwrote `root` without deleting it and leaked the whole old tree).
  Model& Model::operator=(const Model& other)
  {
    if (this == &other)
      return *this;

    // Clone other's tree and collect its tiles before touching *this, so a
    // failed allocation while cloning leaves the current model in place.
    Node* fresh_root = new Node(*other.root);
    std::vector<Tile*> fresh_tiles;
    fresh_root->into_tiles(fresh_tiles);

    data = other.data;
    sums = other.sums;

    delete root;            // frees the entire old subtree via ~Node
    root = fresh_root;
    tiles.swap(fresh_tiles);  // old tile pointers referred to the freed tree
    return *this;
  }

  Model::~Model()
  {
    // The node tree owns every tile; `tiles` only holds pointers into that
    // tree, so deleting the root (which recursively deletes its children)
    // releases everything.
    Node* const doomed = root;
    delete doomed;
  }

  void Model::Node::into_tiles(std::vector<Tile*>& a)
  {
    // Pre-order traversal: append this node, then each child's subtree.
    Tile* self = this;
    a.push_back(self);
    for (std::vector<Node*>::iterator child = children.begin();
         child != children.end(); ++child)
      (*child)->into_tiles(a);
  }

  // Counter handing out node ids; each newly constructed node takes the next
  // value (copies made by the copy constructor keep the original's id).
  long int Model::Node::next_id(0);

  Model::Node::Node(Tile tile, long int ones_)
    : Tile(tile), id(next_id++), ones(ones_), my_size(size()), my_ones(ones_),
      children()
  {
    // A fresh node initially owns its whole area; Model::add_tile then
    // subtracts whatever its adopted children cover from my_size/my_ones.
  }

  Model::Node::Node(const Model::Node& other)
    : Tile(other), id(other.id), ones(other.ones), 
      my_size(other.my_size), my_ones(other.my_ones),
      children()
  {
    // Deep copy: each child subtree is cloned recursively, so the new node
    // owns an independent tree (released again by ~Node).
    for (std::vector<Node*>::const_iterator src = other.children.begin();
         src != other.children.end(); ++src)
      children.push_back(new Node(**src));
  }

  // Copy assignment: makes this node a deep copy of `other` (same id, edges,
  // counts, and an independent clone of its child subtrees).  The subtree
  // this node previously owned is freed; the old version leaked it.
  Model::Node&
  Model::Node::operator=(const Model::Node& other)
  {
    if (this == &other)
      return *this;

    // Clone other's children first: `other` may itself live inside our own
    // subtree, which is deleted below.
    std::vector<Node*> copied;
    for (unsigned i = 0; i < other.children.size(); i++)
      copied.push_back(new Node(*other.children[i]));

    // Copy the scalar state, also before anything is deleted.
    id = other.id;
    std::copy(other.edge, other.edge + 4, edge);  // was memcpy (no <cstring>)
    ones = other.ones;
    my_size = other.my_size;
    my_ones = other.my_ones;

    // Release the subtree we owned, then adopt the clones.
    for (unsigned i = 0; i < children.size(); i++)
      delete children[i];
    children.swap(copied);
    return *this;
  }

  Model::Node::~Node()
  {
    // Each node owns its children, so destroying a node frees its subtree.
    for (std::vector<Node*>::iterator child = children.begin();
         child != children.end(); ++child)
      delete *child;
  }

  void 
  Model::Node::print(std::ostream& s, long int parent) const
  {
    s << id << ' '
      << edge[TOP] << ' '
      << edge[BOTTOM] << ' '
      << edge[LEFT] << ' '
      << edge[RIGHT] << ' '
      << parent << ' '
      << size() << ' '
      << ones << ' '
      << my_size << ' '
      << my_ones << std::endl;
    for (unsigned i = 0; i < children.size(); i++)
      children[i]->print(s, id);
  }

  // Writes a '%'-prefixed column header followed by the node tree as
  // produced by Node::print, with the root's parent id written as 0.
  std::ostream& operator<<(std::ostream& s, const Model& m)
  {
    s << "% number xmin xmax ymin ymax parent size ones own-size own-ones"
      << std::endl;
    m.root->print(s, 0);
    return s;
  }


  // Finds where `tile` would go in the subtree rooted at `under` (which must
  // subsume it).  Scanning under's children in order:
  //  - a child that `tile` subsumes is appended to `children` (it would be
  //    adopted by the new tile);
  //  - otherwise, a child that subsumes `tile` becomes the new starting
  //    point (recursive descent);
  //  - otherwise, a child that merely overlaps `tile` makes placement fail.
  // On success `parent` is set to the node that would receive the tile and
  // true is returned.  On failure false is returned, `parent` is left
  // untouched and `children` may already hold some entries.
  bool Model::place_tile(const Tile& tile, Node*& parent, 
                         std::vector<Node*>& children,
                         Node* const& under) const
  {
    assert(under->subsumes(tile));
    const std::vector<Node*>& kids = under->children;
    bool adopted_any = false;
    for (std::vector<Node*>::size_type k = 0; k < kids.size(); ++k)
      {
        Node* const kid = kids[k];
        if (tile.subsumes(*kid))
          {
            adopted_any = true;
            children.push_back(kid);
            continue;
          }
        if (kid->subsumes(tile))
          {
            assert(!adopted_any);
            return place_tile(tile, parent, children, kid);
          }
        if (tile.overlaps(*kid))
          return false;
      }
    parent = under;
    return true;
  }

  void Model::add_tile(const Tile& t)
  {
    using namespace std;
    cout << "Adding tile " << t << endl;
    Node *nn = new Node(t, ones_within(t)), *parent;

    bool ok = place_tile(t, parent, nn->children, root);
    assert(ok);
    // Remove my children's coverage from myself
    for (vector<Node*>::iterator 
	   i = nn->children.begin();
	 i != nn->children.end();
	 i++)
      {
	nn->my_ones -= (*i)->ones;
	nn->my_size -= (*i)->size();
      }
    // Remove my coverage from my parent
    parent->my_ones -= nn->my_ones;
    parent->my_size -= nn->my_size;
    // Remove my children from my parent
    for (vector<Node*>::iterator
	   i = parent->children.begin();
	 i != parent->children.end();
	 /* increased in loop */)
      {
	if (find(nn->children.begin(), nn->children.end(), *i)
	    != nn->children.end())
	  i = parent->children.erase(i);
	else
	  i++;
      }
    // Add me as my parent's child
    parent->children.push_back(nn);

    tiles.push_back((Tile*)nn);
  }

  // Sums my_size * negent(my_ones / my_size) over `node` and all of its
  // descendants, i.e. each node is scored on its own region (its area minus
  // the areas of its children).
  double
  Model::log_likelihood(const Node* node) const
  {
    double result = 0.0;
    // An empty own region (node entirely covered by its children) adds
    // nothing.  Without this guard the ratio would be 0.0/0.0 = NaN, which
    // would typically propagate through negent and poison the whole sum.
    if (node->my_size != 0)
      result = node->my_size * negent(1.0*node->my_ones/node->my_size);
    for (unsigned int i = 0; i < node->children.size(); i++)
      result += log_likelihood(node->children[i]);
    return result;
  }

  double Model::log_likelihood() const
  {
    return log_likelihood(root);
  }

  // Emits a gnuplot inline-data "plot" command tracing the outline of
  // `node` as a closed polyline through its four corners (row edge first,
  // column edge second), then recurses into every child.
  void Model::gnuplot(std::ostream* output, const Model::Node& node) const
  {
    static const int corner[5][2] = {
      { Tile::TOP,    Tile::LEFT  },
      { Tile::TOP,    Tile::RIGHT },
      { Tile::BOTTOM, Tile::RIGHT },
      { Tile::BOTTOM, Tile::LEFT  },
      { Tile::TOP,    Tile::LEFT  }   // back to the start to close the box
    };
    std::ostream& out = *output;
    out << "plot '-' notitle with lines lt 1 lw 2" << std::endl;
    for (int k = 0; k < 5; ++k)
      out << node.edge[corner[k][0]] << ' '
          << node.edge[corner[k][1]] << std::endl;
    out << 'e' << std::endl;
    for (unsigned int c = 0; c < node.children.size(); ++c)
      gnuplot(output, *node.children[c]);
  }

  // Writes a gnuplot script that renders the model to "<name>.eps": first a
  // point for every nonzero data cell (x = row index, y = column index),
  // then the outlines of all tiles, issued as successive plots in multiplot
  // mode.
  void Model::gnuplot(std::ostream* output, const char* name) const
  {
    using namespace std;
    ostream& out = *output;

    out << "set term post eps" << endl;
    out << "set out '" << name << ".eps'" << endl;
    out << "set size 1,1" << endl;
    out << "set origin 0,0" << endl;
    out << "set multiplot" << endl;
    out << "set xrange [-0.5:" << data.dim1() - 0.5 << "]" << endl;
    out << "set yrange [-0.5:" << data.dim2() - 0.5 << "]" << endl;

    out << "plot '-' notitle with points pt 7 ps 0.5" << endl;
    for (size_t row = 0; row < data.dim1(); ++row)
      for (size_t col = 0; col < data.dim2(); ++col)
        if (data[row][col])
          out << row << ' ' << col << endl;
    out << 'e' << endl;

    gnuplot(output, *root);
  }

  // How much would adding this tile increase the llh?
  double Model::llh_increase(const Tile& t) const
  {
    Node* parent;
    std::vector<Node*> children;
    bool ok = place_tile(t, parent, children, root);
    if (!ok)
      return -INFINITY;

//     std::cerr << '{' << parent->my_size << ' '
// 	      << parent->my_ones;
    double old = parent->my_size * negent(1.0*parent->my_ones/parent->my_size);
//     std::cerr << " -> " << old << ", ";
    long int my_ones = ones_within(t), my_size = t.size();
    for (std::vector<Node*>::const_iterator
	   i = children.begin();
	 i != children.end();
	 i++)
      {
	my_ones -= ones_within(**i);
	my_size -= (*i)->size();
      }
//     std::cerr << my_size << ' ' << my_ones << ", ";
    long int 
      parent_ones = parent->my_ones - my_ones, 
      parent_size = parent->my_size - my_size;
    if (my_size == 0 || parent_size == 0)
      return -INFINITY;
    double nw = 
      my_size * negent(1.0*my_ones/my_size)
      + parent_size * negent(1.0*parent_ones/parent_size);
//     std::cerr << parent_size << ' ' << parent_ones
// 	      << " -> " << nw 
// 	      << " => " << nw-old << '}' << std::endl;

    return nw-old;
  }


}
