#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Compute dense itemsets

"Frontier search" takes as input the support threshold s and the
number N of sets required. It finds N strongly dense itemsets that
have maximal strong density at support s.

The script requires Python 2.4.
"""

# $Id: dense.py,v 1.39 2006-06-14 12:46:07+03 jkseppan Exp $
__author__ = "Jouni K. Seppänen <jouni.seppanen@tkk.fi>"
__copyright__ = "Copyright (C) 2004-2006 Jouni K. Seppänen"
__license__ = "Boost Software License - Version 1.0 - August 17th, 2003"
__date__ = "2006-06-14"
__version__ = "$Revision: 1.39 $"
__test__ = {}

# To do:
#
# We assume elements are consecutive integers starting from zero;
#   it should be quite easy to accept arbitrary element names
# Optimize scanning in classic levelwise search
# Add documentation

from __future__ import division, generators

import getopt, string, sys, math

# Positive infinity.  HeapFrontierScanner uses it as the starting value of
# its running minimum, so the first score popped always replaces it.
inf = float('inf')

from heapq import heappop, heappush

# Optional acceleration through the third-party psyco module.  The guard is
# hard-coded to False, so this block never executes as written; change the
# condition to True to try psyco when it is installed.
if False:
    try:
        import psyco
        psyco.profile()
        print >> sys.stderr, "* using psyco"
    except:
        # NOTE(review): this bare except also hides any error raised by
        # psyco.profile(), not only a failed import.
        print >> sys.stderr, "* psyco not found, program will be a little slower"


def dbscan(data, set):
    """Count the rows of DATA by how much of SET they contain.

    DATA is an iterable of rows and SET an itemset; both the rows and
    SET must support the & operator (they are sets elsewhere in this
    file).  The returned list H has len(SET)+1 entries, and H[i] is
    the number of rows sharing exactly i elements with SET.
    """

    itemset = set
    counts = [0 for _ in range(len(itemset) + 1)]
    for record in data:
        overlap = len(itemset & record)
        counts[overlap] += 1
    return counts

def dbscan_many(data, sets):
    """Find the histograms of several itemsets in one pass over DATA.

    DATA is an iterable of rows and SETS an iterable of itemsets (rows
    and itemsets must support the & operator).  The result is a list
    indexed first by itemset, in the order given, and second by
    intersection size: RESULT[k][i] is the number of rows that share
    exactly i elements with the k-th itemset.
    """

    # Materialise once so any iterable is accepted and the order is fixed.
    sets = list(sets)
    res = [[0] * (len(itemset) + 1) for itemset in sets]
    # Pair every itemset with its own histogram up front so the inner
    # loop needs no index arithmetic.
    pairs = list(zip(sets, res))
    for row in data:
        for itemset, hist in pairs:
            hist[len(itemset & row)] += 1
    return res

def density(histogram, sigma):
    """Density, at support sigma, of a set that has the given histogram.

    HISTOGRAM[j] is the number of rows containing exactly j items of
    the set (see dbscan), so the set has len(HISTOGRAM)-1 items.  The
    SIGMA rows that contain the most items of the set are selected, and
    the result is the total number of set items in those rows divided
    by SIGMA and by the size of the set.

    Raises ValueError if SIGMA is not positive or if the histogram
    holds fewer than SIGMA rows.  A histogram of length 1 (the empty
    set) leads to ZeroDivisionError.  The result is a true quotient;
    this relies on the module's "from __future__ import division".
    """

    if sigma <= 0:
        raise ValueError("Support threshold sigma must be positive, got %r"
                         % (sigma,))

    nrows = 0
    nitems = 0
    # Walk from the fullest rows towards the emptiest ones.
    for j in range(len(histogram) - 1, -1, -1):
        newrows = histogram[j]
        nrows += newrows
        nitems += j * newrows
        if nrows >= sigma:
            # Only part of this bucket is needed to reach sigma rows;
            # give back the items of the surplus rows.
            nitems -= j * (nrows - sigma)
            break
    else:
        # Ran out of buckets before collecting sigma rows.
        raise ValueError("Histogram has insufficient rows for sigma=%d"
                         % sigma)
    return nitems / sigma / (len(histogram) - 1)

def support(histogram, delta):
    """Support, at density delta, of a set that has the given histogram.

    HISTOGRAM[j] is the number of rows containing exactly j items of
    the set (see dbscan).  Rows are taken greedily, fullest first, and
    the result is the largest number of rows whose average number of
    set items is still at least DELTA times the size of the set.  The
    result is 0 when even the fullest rows fall short.

    Raises ValueError if every histogram entry is zero.  The averages
    are true quotients; this relies on the module's
    "from __future__ import division".
    """

    threshold = (len(histogram) - 1) * delta

    # Skip the empty buckets at the top, so that the first bucket
    # consumed below holds at least one row and the average is defined.
    j = len(histogram) - 1
    while histogram[j] == 0:
        j -= 1
        if j < 0:
            raise ValueError('support: empty histogram?')

    nrows, nitems = 0, 0
    while j >= 0:
        newrows = histogram[j]
        nrows += newrows
        nitems += j * newrows
        if nitems / nrows < threshold:
            break
        j -= 1
    else:
        # Every bucket was consumed without dropping below the threshold.
        return nrows

    # Bucket j pushed the average below the threshold: back it out and
    # bisect for how many of its rows can be kept.  Invariant: taking
    # `lo` rows keeps the average at or above the threshold (or takes
    # nothing), taking `hi` rows does not.
    nrows -= newrows
    nitems -= j * newrows
    lo, hi = 0, newrows
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if (nitems + j * mid) / (nrows + mid) >= threshold:
            lo = mid
        else:
            hi = mid
    return nrows + lo


def readdata(file):
    """Read data from given file handle.

    FILE is iterated line by line (any iterable of strings works).
    Each line holds whitespace-separated integer items and becomes one
    set of ints in the returned list; a blank line becomes an empty
    set.  A token that is not an integer makes int() raise ValueError.
    """

    # str.split() with no argument yields [] for a blank line, so blank
    # lines need no special case.
    return [set([int(word) for word in line.split()]) for line in file]

class PrefixTreeNode(object):
    """One node of the prefix tree that backs SetFamily.

    The path from the root to a node spells out the items of a set in
    ascending order.
    """
    __slots__ = ('dict', 'exists', 'content')

    def __init__(self):
        self.dict = {}          # item -> child PrefixTreeNode
        self.exists = False     # True if the path to this node is a stored set
        self.content = None     # information stored along with that set

    def insert(self, item, what):
        """Make WHAT the child reached through ITEM."""
        self.dict[item] = what

    def child(self, item):
        """Return the child reached through ITEM; KeyError if there is none."""
        return self.dict[item]


class SetFamily(object):
    """Family of itemsets with associated information.

    Sets are stored in a prefix tree keyed by their items in ascending
    order, so the items of one family must be mutually comparable.
    Iterating over the family yields (set, information) pairs.
    """
    __slots__ = ('_SetFamily__root', 'len')

    def __init__(self):
        self.__root = PrefixTreeNode()
        self.len = 0            # number of distinct sets stored

    def __len__(self):
        return self.len

    def __iter__(self):
        return self.traverse()

    def __contains__(self, set):
        return (self.find(set))[1]

    def find(self, set_):
        """Finds the set. Returns the information, and as a second value, a
        Boolean indicating whether the set was found (to tell the difference
        between nil content and set not found)."""
        pos = self.__root
        try:
            for item in sorted(set_):
                pos = pos.child(item)
        except KeyError:
            return None, False
        return pos.content, pos.exists

    def traverse(self):
        """Generator that yields all sets in the family, each as a
        (set, information) pair."""

        return self.__traverse_recursive(self.__root, set())

    def __traverse_recursive(self, pos, set_):
        """Yield every stored set at or below POS; SET_ holds the items
        on the path leading to POS."""
        if pos.exists:
            yield set_, pos.content
        for item in pos.dict:
            for result in \
                self.__traverse_recursive(pos.child(item), set_ | set([item])):
                yield result

    def insert(self, set_, info):
        """Insert set along with the given information.
        If set already exists in family, overwrite content."""
        pos = self.__root
        for item in sorted(set_):
            try:
                pos = pos.child(item)
            except KeyError:
                new = PrefixTreeNode()
                pos.insert(item, new)
                pos = new
        # Overwriting an existing set must not change the size of the family.
        if not pos.exists:
            self.len += 1
        pos.exists = True
        pos.content = info

    def subset_of(self, query):
        """Finds some subset of query. Returns the subset, or None if
        the family holds no subset of query."""
        for subset, _ in self.traverse_subsets(query):
            return subset
        return None

    def traverse_subsets(self, query):
        """Generator that yields all subsets of query, and their associated
        content.
        """
        return self.__subsets_recursive(self.__root, set(), sorted(query))

    def __subsets_recursive(self, position, set_, itemlist, this_done=False):
        """Helper function for traverse_subsets.

        Yields the stored sets at or below POSITION that use only items
        from ITEMLIST (ascending).  THIS_DONE is true when POSITION
        itself has already been reported by an earlier call.
        """
        if (not this_done) and position.exists:
            yield set_, position.content
        if len(itemlist) == 0:
            return
        item, rest = itemlist[0], itemlist[1:]
        # First the subsets that contain ITEM ...
        if item in position.dict:
            for result in self.__subsets_recursive(
                    position.child(item), set_ | set([item]), rest, False):
                yield result
        # ... then those that leave it out.
        for result in self.__subsets_recursive(position, set_, rest, True):
            yield result

    def superset_of(self, query):
        """Finds some superset of query. NB: Does not necessarily
        return a set in the family but some subset of a set in the
        family that is a superset of query.

        Returns None if the family holds no superset of query.  For an
        empty query the result is the empty tuple () when the family is
        non-empty (the result tests false in that case as well, so
        compare against None).
        """
        itemlist = sorted(query)
        if not itemlist:
            # Need to special-case this because the root always exists.
            if self.__root.dict or self.__root.exists:
                return ()
            return None
        return self.__superset_recursive(self.__root, set(), itemlist,
                                         len(itemlist))

    def __superset_recursive(self, position, set_, itemlist, l):
        """Helper function for superset_of.

        ITEMLIST holds the L query items (ascending) still to be
        matched below POSITION; SET_ holds the items on the path to
        POSITION.  Returns the path as a set once every query item has
        been matched, or None if there is no match below POSITION.
        """
        if l == 0:
            return set_
        q0 = itemlist[0]
        if q0 in position.dict:
            superset = self.__superset_recursive(
                position.child(q0), set_ | set([q0]), itemlist[1:], l - 1)
            # A dead end here does not rule out a match through one of
            # the smaller children below, so keep searching.
            if superset is not None:
                return superset
        # Children always carry items larger than anything on the path,
        # so every child smaller than q0 may precede q0 in a superset.
        # Looking at the children themselves (instead of an integer
        # range) also covers the item 0 directly below the root.
        for j in sorted(position.dict):
            if j >= q0:
                break
            superset = self.__superset_recursive(
                position.child(j), set_ | set([j]), itemlist, l)
            if superset is not None:
                return superset
        return None


class FrontierScanner(object):
    """Abstract base class of the frontier-search scanners.

    scan() repeatedly takes the next itemset from the frontier, stores
    it in the result, and turns its one-item extensions into new
    candidates.  A concrete subclass must initialise self.frontier and
    provide moresets(nfound) and nextset(), plus either explore(set_)
    or its own run_candidates().  Use create() to obtain the scanner
    that matches a given combination of parameters.
    """
    __slots__ = ('data', 'frontier', 'candidates', 'result', 'itemrange')

    def __init__(self, data):
        self.data = data                    # list of rows, each a set of items
        self.result = SetFamily()           # itemsets accepted so far
        all_items = set()
        for row in data:
            all_items |= row
        self.itemrange = sorted(all_items)  # every item that occurs in data
        self.candidates = []                # sets waiting for run_candidates()

    def create(data, sigma=None, delta=None, nsets=None):
        """Return the scanner for the two parameters that are given.

        Exactly two of SIGMA (support threshold), DELTA (density
        threshold) and NSETS (number of itemsets) must be specified;
        ValueError is raised otherwise.
        """
        if sigma is not None and delta is not None and nsets is not None:
            raise ValueError("Overspecified")

        if sigma is not None and nsets is not None:
            # Find NSETS sets at support SIGMA
            return DensityFrontierScanner(data=data, sigma=sigma, nsets=nsets)
        elif delta is not None and nsets is not None:
            # Find NSETS sets at density DELTA
            return SupportFrontierScanner(data=data, delta=delta, nsets=nsets)
        elif sigma is not None and delta is not None:
            # Find sets that have support SIGMA at density DELTA
            return ClassicFrontierScanner(data=data, sigma=sigma, delta=delta)
        else:
            raise ValueError(
                "Need to specify two out of the following three:\n" +
                "support threshold, density threshold, number of itemsets")
    create = staticmethod(create)

    def explore(self, set_=None):
        """Evaluate one candidate set; must be overridden.

        The parameter matches the way run_candidates() calls this
        method; it defaults to None so that argument-less calls keep
        working.
        """
        raise NotImplementedError("FrontierScanner is an abstract class")

    def run_candidates(self):
        """Evaluate every pending candidate, then clear the list."""
        for c in self.candidates:
            self.explore(c)
        self.candidates = []

    def scan(self):
        """Run the search and return the SetFamily of accepted itemsets.

        One progress line per accepted itemset is written to stderr.
        """
        # Seed the frontier with all singletons, evaluated as one batch.
        for item in self.itemrange:
            self.candidates.append(set([item]))
        if self.candidates:
            self.run_candidates()

        nfound = 0
        while self.moresets(nfound):
            (set_, value) = self.nextset()
            self.result.insert(set_, value)
            nfound += 1
            sys.stderr.write("* %s [%f] (%d in frontier)\n" %
                             (list(set_), value, len(self.frontier)))

            for item in self.itemrange:
                if item in set_:
                    continue
                candidate = set_ | set([item])

                # The extension becomes a candidate only when every
                # subset obtained by dropping one of the old items is
                # already accepted (dropping `item` gives set_ itself,
                # which was just inserted).
                for removed in set_:
                    if candidate - set([removed]) not in self.result:
                        break
                else:
                    self.candidates.append(candidate)
            if self.candidates:
                self.run_candidates()

        return self.result


class HeapFrontierScanner(FrontierScanner):
    """Frontier scanner that always expands the best-scoring itemset.

    The frontier is a heapq heap of (-score, itemset) pairs, so the
    itemset with the largest score is popped first.  Scores come from
    histfun(), which subclasses must override; subclasses must also set
    self.nsets, the number of itemsets wanted.
    """
    __slots__ = ('minimum','finalizing','otherparamname')

    def __init__(self, data):
        super(HeapFrontierScanner,self).__init__(data)
        self.frontier = []
        self.minimum = inf          # smallest score popped before finalizing
        self.finalizing = False     # True once nsets itemsets have been found
        # Name of the parameter that the search determines; it only
        # appears in the progress message written by moresets().
        self.otherparamname = '*you should not see this*'

    def run_candidates(self):
        """Score all pending candidates with one pass over the data
        and push them onto the frontier."""
        hists = dbscan_many(self.data, self.candidates)
        # Score everything before touching the heap, so that a failing
        # histfun() leaves the frontier unchanged.
        scores = [self.histfun(hist) for hist in hists]
        for score, candidate in zip(scores, self.candidates):
            heappush(self.frontier, (-score, candidate))
        self.candidates = []

    def moresets(self, nfound):
        """Tell scan() whether to take another itemset from the frontier.

        The search continues until NFOUND reaches self.nsets, and after
        that for as long as the best score left in the frontier is not
        below the smallest score accepted so far.
        """
        if not self.frontier:
            return False
        if nfound < self.nsets:
            return True
        if not self.finalizing:
            self.finalizing = True
            sys.stderr.write("*** %s now fixed at %f\n" %
                             (self.otherparamname, self.minimum))
        return -self.frontier[0][0] >= self.minimum

    def nextset(self):
        """Pop the best itemset; return it as an (itemset, score) pair."""
        (res,found) = heappop(self.frontier)
        res = -res
        if not self.finalizing:
            self.minimum = min(self.minimum, res)
        return (found,res)

    def histfun(self, hist):
        """Return the score of an itemset with histogram HIST; must be
        overridden."""
        raise NotImplementedError("HeapFrontierScanner is an abstract class")

class SupportFrontierScanner(HeapFrontierScanner):
    """Heap scanner that scores itemsets by their support at a fixed
    density DELTA and looks for NSETS of them."""
    __slots__ = ('delta', 'nsets')

    def __init__(self, data, delta, nsets):
        super(SupportFrontierScanner, self).__init__(data)
        self.delta = delta
        self.nsets = nsets
        # The support threshold is the parameter left for the search
        # to determine.
        self.otherparamname = 'sigma'

    def histfun(self, hist):
        """Score of a histogram: its support at density self.delta."""
        score = support(hist, self.delta)
        return score

class DensityFrontierScanner(HeapFrontierScanner):
    """Heap scanner that scores itemsets by their density at a fixed
    support SIGMA and looks for NSETS of them."""
    __slots__ = ('sigma', 'nsets')

    def __init__(self, data, sigma, nsets):
        super(DensityFrontierScanner, self).__init__(data)
        self.sigma = sigma
        self.nsets = nsets
        # The density threshold is the parameter left for the search
        # to determine.
        self.otherparamname = 'delta'

    def histfun(self, hist):
        """Score of a histogram: its density at support self.sigma."""
        score = density(hist, self.sigma)
        return score

class ClassicFrontierScanner(FrontierScanner):
    """Frontier scanner that finds the itemsets whose density at
    support SIGMA is at least DELTA.

    The frontier is a plain list used as a stack of (density, itemset)
    pairs; only candidates that pass the density test are pushed.
    """
    __slots__ = ('delta','sigma')

    def __init__(self, data, sigma, delta):
        super(ClassicFrontierScanner,self).__init__(data)
        self.sigma,self.delta = sigma,delta
        self.frontier = [] # stack

    def moresets(self,nfound):
        """Continue for as long as the frontier is non-empty."""
        return bool(self.frontier)

    def explore(self, set_):
        """Scan the data for SET_ and push it if it is dense enough."""
        hist = dbscan(self.data, set_)
        d = density(hist, self.sigma)
        if d >= self.delta:
            # Push the candidate itself (set_), not the builtin type `set`.
            self.frontier.append((d, set_))

    def nextset(self):
        """Pop the most recently pushed entry; return (itemset, density)."""
        (s,found) = self.frontier.pop()
        return (found, s)

class LevelwiseScanner(object):
    """Levelwise search for the itemsets whose density at support SIGMA
    is at least DELTA.

    Level k evaluates candidate sets of k items with one pass over the
    data; the sets that pass are combined into the candidates of level
    k+1 by candgen().
    """
    __slots__ = ('data', 'result', 'itemrange', 'sigma', 'delta')

    def __init__(self, data, sigma, delta):
        self.data = data                    # list of rows, each a set of items
        self.result = SetFamily()           # itemsets that passed, with density
        all_items = set()
        for row in data:
            all_items |= row
        self.itemrange = sorted(all_items)  # every item that occurs in data
        self.sigma = sigma
        self.delta = delta

    def scan(self):
        """Run the search and return the SetFamily of dense itemsets.

        Two progress lines per level are written to stderr.
        """
        candidates = [set([item]) for item in self.itemrange]

        level = 1
        while candidates:
            sys.stderr.write("* level %d: %d candidates\n" %
                             (level, len(candidates)))
            passed = []
            histograms = dbscan_many(self.data, candidates)
            for candidate, hist in zip(candidates, histograms):
                d = density(hist, self.sigma)
                if d >= self.delta:
                    self.result.insert(candidate, d)
                    passed.append(candidate)
            sys.stderr.write("*  %d passed\n" % len(passed))
            level += 1
            candidates = self.candgen(passed, level)

        return self.result

    def candgen(sets, newlevel):
        """Generate the candidates of size NEWLEVEL from SETS.

        SETS is a list of sets that all have NEWLEVEL-1 items.  Two of
        them are joined when they agree on everything except their
        largest item, and a union is kept only if all of its subsets of
        size NEWLEVEL-1 occur in SETS.  Returns a list of sets without
        duplicates; SETS itself is not modified.
        """
        if newlevel == 2:
            return LevelwiseScanner.candgen2(sets)

        # Ordered by their sorted item lists, sets that share all items
        # but the largest one are adjacent.
        sets = sorted(sets, key=lambda s: sorted(s))
        nsets = len(sets)
        # Hashable copies give constant-time membership tests.
        known = set([frozenset(s) for s in sets])
        seen = set()
        candidates = []

        i = 0
        while i < nsets - 1:
            set_i = sets[i]
            prefix = set_i - set([max(set_i)])
            # Advance j to the end of the group: sets[i:j] (note that
            # this is a half-open interval) all consist of prefix plus
            # one larger item.  j may reach nsets when the group runs
            # to the end of the list.
            j = i + 1
            while j < nsets and set_i & sets[j] == prefix:
                j += 1
            # All pairs within the group are possible candidates.
            for k in range(i, j - 1):
                for l in range(k + 1, j):
                    union = sets[k] | sets[l]
                    assert len(union) == newlevel
                    key = frozenset(union)
                    if key in seen:
                        continue
                    seen.add(key)
                    # Keep the union only if each subset with one item
                    # removed is among the input sets.
                    for item in union:
                        if frozenset(union - set([item])) not in known:
                            break
                    else:
                        candidates.append(union)
            i = j

        return candidates
    candgen = staticmethod(candgen)

    def candgen2(sets):
        """Return the unions of all pairs of SETS (a list), in list order."""
        result = []
        for i, first in enumerate(sets):
            for second in sets[i + 1:]:
                result.append(first | second)
        return result
    candgen2 = staticmethod(candgen2)


def main(argv):
    """Command-line entry point.

    ARGV is the argument list without the program name.  With options
    and exactly one data file name, the file is read and searched and
    one line per itemset found is written to stdout; progress messages
    go to stderr.  Without -n the levelwise search is used, otherwise
    the frontier scanner selected by FrontierScanner.create().  With
    any other argument list, or when option parsing fails, the usage
    message is printed.

    Raises ValueError if -n is absent and -s or -d is missing or zero.
    """
    try:
        (optionlist, argv) = getopt.getopt(argv, 'n:s:d:')
    except getopt.GetoptError:
        # Unparsable options: fall through to the usage message.
        optionlist, argv = [], []
    options = dict(optionlist)

    if len(argv) != 1:
        sys.stdout.write("dense.py %s\n" % __version__)
        sys.stdout.write("%s\n" % __copyright__)
        sys.stdout.write("""
This program comes with ABSOLUTELY NO WARRANTY.
This is free software, and you are welcome to redistribute it
under certain conditions. See the file LICENSE for details
on the lack of warranty and your right to redistribute.
        
Usage: dense.py [-n nsets] [-s sigma] [-d delta] datafile
      (specify any two of nsets, sigma, delta)

  Data format: items are 0-based integers,
  a record consists of space-separated items
  and is terminated by newline

  Output format: lines beginning with * are debug info;
  lines beginning with = are actual output
""" + "\n")
        return

    datafile = open(argv[0])
    try:
        data = readdata(datafile)
    finally:
        # Close the file even if a line cannot be parsed.
        datafile.close()
    sys.stderr.write("* %d line%s read\n" %
                     (len(data), (len(data) != 1 and "s") or ""))

    nsets, sigma, delta = None, None, None
    if '-n' in options: nsets = int(options['-n'])
    if '-s' in options: sigma = int(options['-s'])
    if '-d' in options: delta = float(options['-d'])

    if '-n' not in options:
        # The levelwise search needs both thresholds.  A real exception
        # is used because assert statements vanish under "python -O".
        if not sigma or not delta:
            raise ValueError(
                "Without -n, both -s (support threshold) and "
                "-d (density threshold) must be given and be nonzero")
        scanner = LevelwiseScanner(data, sigma, delta)
    else:
        scanner = FrontierScanner.create(
            data=data, nsets=nsets, sigma=sigma, delta=delta)

    result = scanner.scan()

    sys.stderr.write("* %d set%s found\n" %
                     (len(result), (len(result) != 1 and "s") or ""))

    for (set_, value) in result:
        # Second column: the smallest value stored for any subset of
        # this set that is in the result (the set itself included).
        weakest = min([d for s, d in result.traverse_subsets(set_)])
        items = sorted(set_)
        sys.stdout.write("=\t%f\t%f\t%s\n" %
                         (value, weakest, ",".join(map(repr, items))))

# Run the command-line interface only when executed as a script, not when
# the module is imported; the program name is stripped from the arguments.
if __name__ == '__main__':
    main(sys.argv[1:])

