# -*- coding: utf-8 -*-
# vi:ts=4:et
#
# $Date: 2003/08/15 12:35:37 $
# $Revision: 1.4 $
# =====================================================

"""library for partition"""

# Return all the partitions of the number N, largest first part first.
# partition(3) -> (3,), (2, 1), (1, 1, 1)

import sys
import itertools

# young
import mathformat
import iterator
import error

__all__     = ['partition']

# Cache of partitions: maps a number n to the list of its Partition objects.
# Filled bottom-up by SeqOfPartition.set_partition and shared (by reference)
# by every SeqOfPartition instance, since self.seq is the cached list itself.
_cache = {}


class Partition(tuple):
    """A single partition of an integer, stored as a tuple of its parts.

    Partitions produced by SeqOfPartition keep their parts in
    non-increasing order, e.g. Partition((3, 1, 1)).
    """
    def size(self):
        """Return the number being partitioned, i.e. the sum of the parts.

        >>> Partition((3, 1, 1)).size()
        5
        """
        # XXX: do we need this?
        return sum(self)

    def __repr__(self):
        return mathformat.pprint_partition(self)

    __str__ = __repr__

    def plot(self):
        """Print the partition as one row of '*' per part, then a blank line.

        (3, 1, 1) is drawn as::

            ***
            *
            *
        """
        for num in self:
            print("*" * num)
        # Trailing blank line separates consecutive diagrams.
        print("")


class SeqOfPartition(iterator.Iterator):
    """Sequence of all the partitions of a number.

    Partitions are computed the first time a number is requested and stored
    in the module-level ``_cache`` (number -> list of Partition). Later
    instances for the same or a smaller number reuse the cached lists.
    """
    def __init__(self, number):

        self.number = number
        self.seq    = []
        # A falsy number (e.g. 0) keeps the empty sequence.
        if number:
            self.set_partition()

    def __getitem__(self, index):
        return self.seq[index]

    def size(self):
        """Return the number of partitions (Mathematica: Length[]).
        """
        return len(self.seq)

    def get_partition(self):
        """Return the list of partitions.

        NOTE: this is the list stored in ``_cache``, not a copy; mutating it
        affects every instance for the same number.
        """
        return self.seq

    def __iter__(self):
        return iter(self.seq)

    def set_partition(self):
        """Fill ``self.seq`` with the partitions of ``self.number``.

        If the number is not cached yet, every number from 1 up to
        ``self.number`` is generated (in increasing order) and cached.
        When ``self.number`` is below 1 nothing is generated and
        ``self.seq`` is left unchanged.
        """
        if self.number not in _cache:
            # generate(i) reads the cached partitions of smaller numbers,
            # so the cache must be filled bottom-up.
            for i in range(1, self.number + 1):
                if i not in _cache:
                    _cache[i] = [Partition(pat) for pat in self.generate(i)]
        self.seq = _cache.get(self.number, self.seq)

    def plot(self):
        """Plot each partition
        """
        for pat in self.seq:
            pat.plot()

    def generate(self, n):
        """Yield the partitions of n as tuples, largest first part first.

        Apart from the trivial (n,), each partition is (head,) + rest, where
        rest is a cached partition of tail = n - head whose largest part does
        not exceed head, so every yielded tuple is non-increasing.
        Requires the partitions of 2 .. n-1 to be in ``_cache``;
        missing entries are treated as having no partitions.
        """
        # trivial partition
        yield (n,)

        head, tail = n - 1, 1
        while head > 0:
            if tail == 1:
                # The only partition of 1 is (1,), and head >= 1 here.
                yield (head, 1)
            else:
                for candidate in _cache.get(tail, []):
                    if _compare_max_num(head, candidate):
                        yield (head,) + candidate
            head, tail = head - 1, tail + 1

    def show_all_partition(self):
        """Print the cached partitions of every number from 1 to self.number.
        """
        return self.show_partition(True)

    def show_partition(self, all = False):
        """Print the partitions.

        If all is False, print each partition of self.number.
        If all is True, print the cached partitions of every number
        from 1 to self.number, each group headed by its number.
        """
        print(self._show_partition(all))

    def _show_partition(self, all = False):
        """Return the text printed by show_partition."""
        if not all:
            return "\n".join(map(mathformat.pprint_partition, self.seq))

        buff = []
        for i in range(1, self.number + 1):
            buff.append(repr(i))
            for item in _cache[i]:
                buff.append("\t" + mathformat.pprint_partition(item))
        return "\n".join(buff)

    def __repr__(self):
        return self._show_partition()

    __str__ = __repr__


def _compare_max_num(number, part):
    """Tests whether or not the number is equal to or greater than
    any number of the partition."""
    # In short, if the partition is in order.
    # you only need to check if number >= part[0].

    # number = 3
    # part = (2,1)
    # --> OK

    # number = 2
    # part = (2,1,1,1)
    # --> OK

    # number = 3
    # part = (4,2,1)
    # --> NG

    return number >= part[0]


def partition(number):
    """Return a SeqOfPartition holding every partition of ``number``.

    Raises:
        error.ArgumentError: if ``number`` is not an int or is negative.
    """
    # Check the type first so `number < 0` is only evaluated on ints.
    if not isinstance(number, int) or number < 0:
        raise error.ArgumentError("number must be a non-negative integer.", number)

    return SeqOfPartition(number)

