// Copyright 1997 by Patrik Simons
// This software is provided as is, no warranty of any kind is given.
#include <iostream.h>
#include <iomanip.h>
#include <unistd.h>
#include <limits.h>
#include <float.h>
#include <stdlib.h>
#include "print.h"
#include "atomrule.h"
#include "smodels.h"


// Construct an empty solver.  The atom index and the stacks are allocated
// later by init (); here every scalar member gets a defined value.  score,
// level, cnflct and lasti used to be left uninitialized although
// setToBTrue/setToBFalse increment score and setup () calls them before any
// testPos/testNeg has reset it.
Smodels::Smodels ()
  : dcl (this)
{
  atom = 0;
  atomtop = 0;
  fail = false;
  guesses = 0;
  conflict_found = false;
  sloppy_heuristic = false;
  answer_number = 0;
  number_of_choice_points = 0;
  number_of_wrong_choices = 0;
  number_of_backjumps = 0;
  number_of_picked_atoms = 0;
  number_of_forced_atoms = 0;
  number_of_assignments = 0;
  max_models = 0;
  setup_top = 0;
  score = 0;
  level = 0;
  cnflct = 0;
  lasti = 0;
}

// Release the atom index allocated by init ().  atom is null when init ()
// was never called, in which case there is nothing to free.
Smodels::~Smodels ()
{
  if (atom != 0)
    delete[] atom;
}

// Prepare the solver for the program held in `program'.  It initializes the
// program and the dcl propagator and builds the index atom[] of all atoms;
// the first atomtop entries are the atoms still open for choice.  It also
// sizes the assignment and dependency stacks.
//
// Any index left over from an earlier call is released and atomtop is reset
// first.  Without this, a second call would leak the old array and keep
// incrementing atomtop past the end of the new one.
void
Smodels::init ()
{
  program.init ();
  dcl.init ();
  delete[] atom;
  atom = new Atom *[program.number_of_atoms];
  atomtop = 0;
  for (Node *n = program.atoms.head; n; n = n->next)
    atom[atomtop++] = n->atom;
  stack.Init (program.number_of_atoms);
  depends.Init (program.number_of_atoms);
  lasti = 0;
}

inline void
Smodels::removeAtom (long n)
{
  assert (atomtop);
  atomtop--;
  Atom *a = atom[atomtop];
  atom[atomtop] = atom[n];
  atom[n] = a;
}

// Reopen atom `a' after backtracking.  Scans forward from atomtop (removed
// atoms are parked there by removeAtom) until `a' is found, and leaves
// atomtop just past it.  Every atom between the old atomtop and `a' is
// therefore reopened too.
//
// The bound is checked before every read.  The old trailing assertion could
// only fire after atom[] had already been indexed out of bounds.
inline void
Smodels::addAtom (Atom *a)
{
  for (;;)
    {
      assert (atomtop < program.number_of_atoms);
      if (atom[atomtop++] == a)
	break;
    }
}

void
Smodels::resetDependency ()
{
  while (!depends.empty ())
    {
      Atom *a = depends.pop ();
      a->dependsonTrue = false;
      a->dependsonFalse = false;
    }
}

void
Smodels::set_conflict ()
{
  conflict_found = true;
}

// Assign `a' true.  This:
//  - counts it in `score' if isnant is set;
//  - records it in `depends' the first time it gets any dependency mark;
//  - marks it dependsonTrue;
//  - pushes it on the assignment stack.
void
Smodels::setToBTrue (Atom *a)
{
  const bool first_mark = !a->dependsonTrue && !a->dependsonFalse;
  if (a->isnant)
    ++score;
  if (first_mark)
    depends.push (a);
  a->dependsonTrue = true;
  a->setBTrue ();
  stack.push (a);
}

// Assign `a' false.  This:
//  - counts it in `score' if isnant is set;
//  - records it in `depends' the first time it gets any dependency mark;
//  - marks it dependsonFalse;
//  - pushes it on the assignment stack.
void
Smodels::setToBFalse (Atom *a)
{
  const bool first_mark = !a->dependsonTrue && !a->dependsonFalse;
  if (a->isnant)
    ++score;
  if (first_mark)
    depends.push (a);
  a->dependsonFalse = true;
  a->setBFalse ();
  stack.push (a);
}

// Make a choice without lookahead.  The first open atom that is unassigned
// and has isnant set becomes a new guess and is assumed false.  Open atoms
// found to be assigned already are dropped from the open prefix while
// scanning.
void
Smodels::pick ()
{
  number_of_picked_atoms++;
  number_of_choice_points++;
  Atom *picked = 0;
  long pn = 0;
  for (long i = 0; i < atomtop; i++)
    {
      Atom *a = atom[i];
      if (a->Bpos || a->Bneg)
	{
	  // removeAtom swaps a not-yet-seen atom into slot i; examine it next.
	  removeAtom (i);
	  i--;
	}
      else if (a->isnant)
	{
	  picked = a;
	  pn = i;
	  break;
	}
    }
  // pick () is only reached when the stack is not full.  If no candidate
  // exists anyway, fail loudly instead of dereferencing a null pointer.
  assert (picked);
  PRINT_PICK (cout << "Picked " << picked->atom_name () << endl);
  stack.push (picked);
  picked->guess = true;
  guesses++;
  removeAtom (pn);
  picked->setBFalse ();
  cnflct = 0;
}

// Propagate the current assignment with dcl.  Adds the number of literals
// pushed by dcl to number_of_assignments, plus one for the choice that
// triggered the propagation.
inline void
Smodels::expand ()
{
  const long before = stack.top;
  dcl.dcl ();
  const long derived = stack.top - before;
  number_of_assignments += derived + 1;
}

bool
Smodels::conflict ()
{
  if (conflict_found)
    {
      PRINT_CONFLICT (cout << "Conflict" << endl);
      program.queue.reset ();
      while (!program.equeue.empty ())
	{
	  Atom *a = program.equeue.pop ();
	  a->in_etrue_queue = false;
	  a->in_efalse_queue = false;
	}
      conflict_found = false;
      return true;
    }
  for (OptimizeRule *r = program.optimize; r; r = r->next)
    if (r->maximize)
      {
	if (r->maxweight < r->maxoptimum)
	  return true;
	else if (r->maxweight > r->maxoptimum)
	  return false;
      }
    else if (r->minweight > r->minoptimum)
      return true;
    else if (r->minweight < r->minoptimum)
      return false;
  return false;
}

// Prepare for the search.  In order, this:
//  1. runs dcl.setup ();
//  2. computes the initial closure with expand ();
//  3. reduces the program strongly;
//  4. asserts the literals requested through computeFalse/computeTrue;
//  5. propagates again, reduces weakly and calls dcl.improve ();
//  6. records the stack height in setup_top, so that unwind_to_setup () can
//     return to this point later.
//
// A conflict is reported only through conflict_found, set by dcl or by
// set_conflict ().  The function returns early in that case, and callers
// test conflict () afterwards.
void
Smodels::setup ()
{
  dcl.setup ();
  if (conflict_found) // Can't use conflict() as this removes the conflict.
    return;
  expand ();  // Compute well-founded model
  // expand () counts one assignment for a choice; no choice was made here.
  number_of_assignments--;
  dcl.reduce (true);  // Reduce the program strongly
  // Initialize literals chosen to be in the stable model / full set.
  long top = stack.top;
  for (Node *n = program.atoms.head; n; n = n->next)
    {
      Atom *a = n->atom;
      // Negative literal
      if (a->computeFalse && a->Bneg == false)
	{
	  // Requested false but already derived true: contradiction.
	  if (a->Bpos)
	    {
	      set_conflict ();
	      return;
	    }
	  setToBFalse (a);
	}
      // Positive literal
      if (a->computeTrue && a->Bpos == false)
	{
	  // Requested true but already false (possibly just above).
	  if (a->Bneg)
	    {
	      set_conflict ();
	      return;
	    }
	  setToBTrue (a);
	}
    }
  // Count the literals pushed by the loop above.
  number_of_assignments += stack.top - top;
  expand ();
  // Again no choice preceded this expand ().
  number_of_assignments--;
  if (conflict_found)
    return;
  dcl.reduce (false); // Reduce weakly
  dcl.improve ();
  // Drop the dependency marks set by setToBTrue/setToBFalse above so that
  // lookahead starts from a clean state.
  for (Node *n = program.atoms.head; n; n = n->next)
    {
      n->atom->dependsonTrue = false;
      n->atom->dependsonFalse = false;
    }
  depends.reset ();
  // Everything on the stack now belongs to setup; searches unwind to here.
  setup_top = stack.top;
  // NOTE(review): this extra decrement has no matching expand () here.
  // Presumably it pre-compensates for the first expand () of the search
  // loop, which adds one for a choice that was never made; confirm.
  number_of_assignments--;
  level = 0;
  cnflct = 0;
}

// Return the solver to its state before setup ().  The steps are:
//  1. unwind the search back to setup_top;
//  2. undo dcl's improve () and reduce ();
//  3. unwind the remaining setup-time assignments;
//  4. let dcl revert its own state;
//  5. clear the status flags and statistics counters.
void
Smodels::revert ()
{
  unwind_to_setup ();
  dcl.unimprove ();
  dcl.unreduce ();
  unwind_all ();
  dcl.revert ();

  fail = conflict_found = false;
  answer_number = 0;
  number_of_choice_points = number_of_wrong_choices = 0;
  number_of_backjumps = 0;
  number_of_picked_atoms = number_of_forced_atoms = 0;
  number_of_assignments = 0;
}

// True when the assignment stack is full.  init () sizes it to
// program.number_of_atoms.  NOTE(review): this presumably means every atom
// has a truth value, assuming one stack entry per assigned atom; confirm.
inline bool 
Smodels::covered ()
{
  const bool complete = stack.full ();
  return complete;
}

// Pop the assignment stack down to and including the most recent guess.
// Every popped non-guess atom has its truth value undone.  The guess itself
// keeps its value but loses its guess mark (and is no longer counted in
// `guesses'); it is returned so the caller can decide what to do with it.
Atom *
Smodels::unwind ()
{
  for (;;)
    {
      Atom *top = stack.pop ();
      if (top->guess)
        {
          top->guess = false;
          guesses--;
          return top;
        }
      PRINT_STACK (top->backtracked = false; top->forced = false);
      if (top->Bpos)
        top->backtrackFromBTrue ();
      else if (top->Bneg)
        top->backtrackFromBFalse ();
    }
}

void
Smodels::unwind_all ()
{
  Atom *a;
  while (!stack.empty ())
    {
      a = stack.pop ();
      if (a->guess)
	{
	  a->guess = false;
	  guesses--;
	}
      PRINT_STACK (a->backtracked = false; a->forced = false);
      if (a->Bpos)
	a->backtrackFromBTrue ();
      else if (a->Bneg)
	a->backtrackFromBFalse ();
    }
  atomtop = program.number_of_atoms;
  setup_top = 0;
}

void
Smodels::unwind_to_setup ()
{
  Atom *a;
  while (stack.top > setup_top)
    {
      a = stack.pop ();
      if (a->guess)
	{
	  a->guess = false;
	  guesses--;
	}
      PRINT_STACK (a->backtracked = false; a->forced = false);
      if (a->Bpos)
	a->backtrackFromBTrue ();
      else if (a->Bneg)
	a->backtrackFromBFalse ();
    }
  atomtop = program.number_of_atoms;
}

// Chronological backtracking.  Undo everything back to the most recent
// guess, then push that atom again with the opposite value; it is no longer
// marked as a guess.  Sets fail and returns 0 when there is no guess left to
// revise; otherwise returns the flipped atom.
Atom *
Smodels::backtrack ()
{
  if (guesses == 0)
    {
      fail = true;
      return 0;
    }
  Atom *flipped = unwind ();
  PRINT_BACKTRACK (cout << "Backtracking: " << flipped->atom_name () << endl);
  if (!flipped->Bneg)
    {
      flipped->backtrackFromBTrue ();
      flipped->setBFalse ();
    }
  else
    {
      flipped->backtrackFromBFalse ();
      flipped->setBTrue ();
    }
  stack.push (flipped);
  PRINT_STACK (flipped->backtracked = true);
  return flipped;
}

// Conflict-directed backjumping.  Repeatedly undoes the most recent guess.
// A guess is flipped (pushed back with the opposite value, as in
// backtrack ()) only if one of these holds:
//  - cnflct is unset;
//  - the depth has dropped below `level';
//  - dcl.path (cnflct, guess) holds.
// NOTE(review): presumably path () tests whether the guess is connected to
// the conflict; confirm in dcl.
// Skipped guesses are counted in number_of_backjumps.  Sets fail and
// returns 0 when no guess remains.
Atom *
Smodels::backjump ()
{
  while (guesses != 0)
    {
      Atom *a = unwind ();
      PRINT_BACKTRACK (cout << "Backtracking: " << a->atom_name () << endl);
      const bool wasFalse = a->Bneg;
      if (wasFalse)
        a->backtrackFromBFalse ();
      else
        a->backtrackFromBTrue ();
      PRINT_STACK (a->backtracked = false; a->forced = false);

      const bool relevant =
        cnflct == 0 || guesses < level || dcl.path (cnflct, a);
      if (!relevant)
        {
          number_of_backjumps++;
          continue;
        }

      if (guesses < level)
        level = guesses;
      if (wasFalse)
        a->setBTrue ();
      else
        a->setBFalse ();
      stack.push (a);
      cnflct = a;
      PRINT_STACK (a->backtracked = true);
      return a;
    }
  fail = true;
  return 0;
}

// Lookahead probe: tentatively make `a' true as a guess and propagate.
//
// On conflict, backtrack () flips the guess, so `a' stays on the stack as
// false (a forced literal).  cnflct then records `a' and true is returned.
//
// Otherwise the probe is undone and false is returned.  score, reset to 0
// here, then holds whatever setToBTrue/setToBFalse counted during the probe
// (presumably invoked by dcl inside expand ()).
bool
Smodels::testPos (Atom *a)
{
  number_of_picked_atoms++;
  score = 0;
  stack.push (a);
  a->guess = true;
  guesses++;
  a->setBTrue ();
  expand ();
  if (!conflict ())
    {
      // Pop the probe's consequences and the guess itself, then undo `a'.
      unwind ();
      a->backtrackFromBTrue ();
      return false;
    }
  // backtrack () re-pushes `a' with the opposite value.
  number_of_forced_atoms++;
  backtrack ();
  cnflct = a;
  PRINT_STACK (a->forced = true);
  return true;
}

// Lookahead probe: tentatively make `a' false as a guess and propagate.
//
// On conflict, backtrack () flips the guess, so `a' stays on the stack as
// true (a forced literal).  cnflct then records `a' and true is returned.
//
// Otherwise the probe is undone and false is returned.  score, reset to 0
// here, then holds whatever setToBTrue/setToBFalse counted during the probe
// (presumably invoked by dcl inside expand ()).
bool
Smodels::testNeg (Atom *a)
{
  number_of_picked_atoms++;
  score = 0;
  stack.push (a);
  a->guess = true;
  guesses++;
  a->setBFalse ();
  expand ();
  if (!conflict ())
    {
      // Pop the probe's consequences and the guess itself, then undo `a'.
      unwind ();
      a->backtrackFromBFalse ();
      return false;
    }
  // backtrack () re-pushes `a' with the opposite value.
  number_of_forced_atoms++;
  backtrack ();
  cnflct = a;
  PRINT_STACK (a->forced = true);
  return true;
}

// Compare the probe scores (pos, neg) of the atom at index i with the best
// candidate so far.
//
// Keys:
//  - primary key: min (pos, neg), compared against hiscore1;
//  - secondary key: max (pos, neg), compared against hiscore2.
//
// The candidate wins if its primary key is larger, or equal with a
// secondary key at least as large (so later ties replace earlier ones).
// On a win:
//  - the keys are stored;
//  - hii becomes i;
//  - positive is 1 if pos > neg, else 0.
void
Smodels::testScore (long pos, long neg, long i, long &hiscore1,
	      long &hiscore2, long &hii, int &positive)
{
  const bool negIsLower = neg < pos;
  const long lo = negIsLower ? neg : pos;
  const long hi = negIsLower ? pos : neg;
  const bool wins = lo > hiscore1 || (lo == hiscore1 && hi >= hiscore2);
  if (!wins)
    return;
  hiscore1 = lo;
  hiscore2 = hi;
  hii = i;
  positive = negIsLower ? 1 : 0;
}

//
// Lookahead choice.  Every open atom is probed in both polarities with
// testNeg () and testPos ().
//
// If a probe ends in a conflict, that probe has already forced the atom to
// the opposite value.  The atom is removed from the open atoms and the
// function returns.  lasti remembers the position so the next call resumes
// the scan there.
//
// Otherwise the atom with the best scores according to testScore () becomes
// a new guess, in the direction testScore () picked.  The primary key is the
// larger smaller-score.  A probe is skipped (score -1) when the atom already
// carries a dependsonFalse/dependsonTrue mark from an earlier probe in this
// call; the marks are cleared by resetDependency () on entry.
//
void
Smodels::lookahead ()
{
  long hiscore1 = 0;
  long hiscore2 = 0;
  int positive = 0;
  long hii = -1;
  long i;
  Atom *a;
  
  resetDependency ();
  // Phase 1: scan atom[lasti .. atomtop-1] first, then atom[0 .. lasti-1].
  bool firstPass = true;
  for (i = lasti; ;)
    {
      if (firstPass)
	{
	  if (i >= atomtop)
	    {
	      firstPass = false;
	      i = 0;
	      continue;
	    }
	}
      else if (i >= atomtop || i >= lasti)
	break;
      a = atom[i];
      // Already assigned: drop it.  removeAtom swaps in the last open atom.
      if (a->Bpos || a->Bneg)
	{
	  removeAtom (i);
	  if (!firstPass)
	    {
	      // The swapped-in atom may be the current best candidate.
	      if (hii == atomtop)
		hii = i;
	      // If the swapped-in atom came from the first (already scanned)
	      // segment, skip it; otherwise examine slot i again.
	      if (atomtop >= lasti)
		i++;
	    }
	  continue;
	}
      if (a->dependsonFalse == false)
	{
	  if (testNeg (a))
	    {
	      lasti = i;
	      removeAtom (i);
	      return;
	    }
	  else
	    a->negScore = score;
	}
      else
	a->negScore = -1;
      if (a->dependsonTrue == false)
	{
	  if (testPos (a))
	    {
	      lasti = i;
	      removeAtom (i);
	      return;
	    }
	  else
	    a->posScore = score;
	}
      else
	a->posScore = -1;
      if (a->posScore != -1 && a->negScore != -1)
	testScore (a->posScore, a->negScore, i, hiscore1, hiscore2,
		   hii, positive);
      i++;
    }
  // Phase 2: fill in skipped probes (score -1), but only for atoms that can
  // still beat hiscore1.  With sloppy_heuristic a skipped probe is scored 1
  // instead of being run.  The return values of testNeg/testPos are not
  // checked here.
  for (i = 0; i < atomtop; i++)
    {
      a = atom[i];
      if (a->negScore >= 0 && a->negScore <= hiscore1)
      	continue;
      if (a->posScore >= 0 && a->posScore <= hiscore1)
      	continue;
      if (a->negScore == -1)
	if (sloppy_heuristic)
	  a->negScore = 1;
	else
	  {
	    testNeg (a);
	    a->negScore = score;
	  }
      if (a->negScore <= hiscore1)
      	continue;
      if (a->posScore == -1)
	if (sloppy_heuristic)
	  a->posScore = 1;
	else
	  {
	    testPos (a);
	    a->posScore = score;
	  }
      testScore (a->posScore, a->negScore, i, hiscore1, hiscore2, hii,
		 positive);
    }
  // Phase 3: commit to the chosen atom as a new guess.
  assert (hii >= 0);
  a = atom[hii];
  stack.push (a);
  a->guess = true;
  guesses++;
  removeAtom (hii);
  cnflct = 0;
  if (positive)
    a->setBTrue ();
  else
    a->setBFalse ();
  number_of_picked_atoms++;
  number_of_choice_points++;
  PRINT_PICK (if (positive)
	      cout << "Chose " << a->atom_name () << endl;
	      else
	      cout << "Chose not " << a->atom_name () << endl);
}

void
Smodels::backtrack (bool jump)
{
  Atom *a;
  if (jump)
    a = backjump ();
  else
    a = backtrack ();
  if (a)
    addAtom (a);
  number_of_wrong_choices++;
}

// Enumerate and print stable models.
//
// Runs setup (), then alternates propagation (expand ()) with choices:
//  - lookahead () when `look' is set, pick () otherwise.
// Conflicts are resolved with:
//  - backjump () when `jump' is set, chronological backtracking otherwise.
// Every model found is printed and the optimize rules record it as their new
// optimum.
//
// Returns 1 if the search stopped because max_models (when nonzero) answers
// were found.  Returns 0 when the search space is exhausted or setup ()
// already failed.
// NOTE(review): unlike the max_conflicts overload, the stack is not unwound
// to setup_top before returning 1.
int
Smodels::smodels (bool look, bool jump)
{
  setup ();
  if (conflict ())
    return 0;
  while (!fail)
    {
      PRINT_DCL (cout << "Smodels call" << endl;
		 dcl.print ());
      PRINT_BF (print ());
      expand ();
      PRINT_BF (cout << "Expand" << endl);
      PRINT_DCL (dcl.print ());
      PRINT_BF (print ());
      PRINT_PROGRAM(printProgram ());
      PRINT_STACK (printStack ());
      if (conflict ())
	backtrack (jump);
      else if (covered ())
	{
	  answer_number++;
	  // Depth of this model.  backjump () flips any guess below this
	  // depth unconditionally.
	  level = guesses;
	  cout << "Answer: " << answer_number << endl;
	  printAnswer ();
	  for (OptimizeRule *r = program.optimize; r; r = r->next)
	    r->setOptimum ();
	  if (max_models && answer_number >= max_models)
	    return 1;
	  else
	    backtrack (jump);
	}
      else if (look)
	lookahead ();
      else
	pick ();
    }
  // The final backtrack that set fail was counted as a wrong choice.
  number_of_wrong_choices--;
  return 0;
}

// Find the next stable model, one per call.
//
// The first call (answer_number == 0) runs setup ().  Later calls first
// backtrack away from the previously returned model.
//
// On success the model is left assigned for the caller to inspect; it is
// not printed.  answer_number is incremented, the optimize rules record
// their optimum, and 1 is returned.  0 means there is no further model (or
// setup () failed).
int
Smodels::model (bool look, bool jump)
{
  if (answer_number)
    backtrack (jump);
  else
    {
      setup ();
      if (conflict ())
	return 0;
    }
  while (!fail)
    {
      expand ();
      if (conflict ())
	backtrack (jump);
      else if (covered ())
	{
	  answer_number++;
	  level = guesses;
	  for (OptimizeRule *r = program.optimize; r; r = r->next)
	    r->setOptimum ();
	  return 1;
	}
      else if (look)
	lookahead ();
      else
	pick ();
    }
  // The final backtrack that set fail was counted as a wrong choice.
  number_of_wrong_choices--;
  return 0;
}

// Search bounded by the number of conflicts.  Unlike the two-argument
// overload it does not call setup (); presumably the caller has already run
// setup () successfully.
//
// Models are printed as in smodels (look, jump).  Each backtrack counts
// toward max_conflicts, including the one after a model.
//
// Returns:
//  -  1 when max_models answers are reached;
//  - -1 when the search space is exhausted;
//  -  0 when max_conflicts was reached first.
// In every case the stack is unwound back to setup_top before returning.
int
Smodels::smodels (bool look, bool jump, long max_conflicts)
{
  long conflicts = 0;
  while (!fail)
    {
      expand ();
      if (conflict ())
	{
	  backtrack (jump);
	  conflicts++;
	}
      else if (covered ())
	{
	  answer_number++;
	  level = guesses;
	  cout << "Answer: " << answer_number << endl;
	  printAnswer ();
	  for (OptimizeRule *r = program.optimize; r; r = r->next)
	    r->setOptimum ();
	  if (max_models && answer_number >= max_models)
	    {
	      unwind_to_setup ();
	      return 1;
	    }
	  else
	    {
	      backtrack (jump);
	      conflicts++;
	    }
	}
      else if (conflicts >= max_conflicts)
	break;  // Must do expand after backtrack to retain
		// consistency
      else if (look)
	lookahead ();
      else
	pick ();
    }
  unwind_to_setup ();
  if (fail)
    {
      // The final backtrack that set fail was counted as a wrong choice.
      number_of_wrong_choices--;
      return -1;
    }
  return 0;
}

int
Smodels::wellfounded ()
{
  setup ();
  if (conflict ())
    return 0;
  cout << "Well-founded model: " << endl;
  cout << "Positive part: ";
  for (Node *n = program.atoms.head; n; n = n->next)
    if (n->atom->Bpos && n->atom->name)
      cout << n->atom->name << ' ';
  cout << endl << "Negative part: ";
  for (Node *n = program.atoms.head; n; n = n->next)
    if (n->atom->Bneg && n->atom->name)
      cout << n->atom->name << ' ';
  cout << endl;
  return 1;
}


void
Smodels::shuffle ()
{
  Atom *t;
  long i, r;
  for (i = 0; i < program.number_of_atoms; i++) {
    t = atom[i];
    // If the low order bits aren't as random as the high order bits,
    // your random number generator is broken.
    r = rand ()%(program.number_of_atoms-i)+i;
    atom[i] = atom[r];
    atom[r] = t;
  }
}

void
Smodels::print ()
{
  cout << "Body: ";
  for (Node *n = program.atoms.head; n; n = n->next)
    if (n->atom->Bpos)
      cout << n->atom->atom_name () << ' ';
  cout << endl << "NBody: ";
  for (Node *n = program.atoms.head; n; n = n->next)
    if (n->atom->Bneg)
      cout << n->atom->atom_name () << ' ';
  cout << endl << "Pick: ";
  Atom **a;
  for (a = stack.stack; a != stack.stack+stack.top; a++)
    if ((*a)->guess)
      if((*a)->Bneg)
	cout << "not " << (*a)->atom_name () << ' ';
      else
	cout << (*a)->atom_name () << ' ';
  cout << endl;
}

// Print the current model.  The output contains:
//  - the named atoms that are true ("Stable Model");
//  - the named false atoms with isnant set ("Full set");
//  - for each optimize rule, its named literals and its current weight.
//
// Weights are printed with DBL_DIG digits of precision.  The caller's stream
// precision is restored afterwards instead of being changed for all later
// output.
void
Smodels::printAnswer ()
{
  cout << "Stable Model: ";
  for (Node *n = program.atoms.head; n; n = n->next)
    if (n->atom->Bpos && n->atom->name)
      cout << n->atom->name << ' ';
  cout << endl << "Full set: ";
  for (Node *n = program.atoms.head; n; n = n->next)
    if (n->atom->Bneg && n->atom->isnant && n->atom->name)
      cout << n->atom->name << ' ';
  cout << endl;
  long saved_precision = cout.precision (DBL_DIG);
  for (OptimizeRule *r = program.optimize; r; r = r->next)
    {
      cout << "{ ";
      int comma = 0;
      for (Atom **a = r->pbody; a != r->pend; a++)
	if ((*a)->name)
	  {
	    if (comma)
	      cout << ", ";
	    cout << (*a)->name;
	    comma = 1;
	  }
      for (Atom **a = r->nbody; a != r->nend; a++)
	if ((*a)->name)
	  {
	    if (comma)
	      cout << ", ";
	    cout << "not " << (*a)->name;
	    comma = 1;
	  }
      cout << " } ";
      if (r->maximize)
	cout << "max = " << r->maxweight << endl;
      else
	cout << "min = " << r->minweight << endl;
    }
  cout.precision (saved_precision);
}

void
Smodels::printStack ()
{
  long i;
  cout << "\x1b[1;1fStack: ";
  for (i = 0; i < stack.top; i++)
    {
      Atom *a = stack.stack[i];
      if (a->forced)
	cout << "\x1b[31m";
      else if (a->backtracked)
	cout << "\x1b[32m";
      else if (a->guess)
	cout << "\x1b[34m";
      if (a->Bneg)
	cout << "not " << a->atom_name ();
      else
	cout << a->atom_name ();
      cout << "\x1b[0m ";
    }
  cout << "\x1b[0J" << endl;
  //  sleep(1);
}
