# -*- coding: utf-8 -*-
# vi:ts=4:et
#
# $Date: 2003/08/15 12:57:39 $
# $Revision: 1.5 $
# =====================================================

"""library for Young tableaux"""

import sys
import copy
import sets
import operator
import itertools

# young
import partition
import combination
import mathformat
import iterator
import error


__all__ = [ 'young', ]

# Module-wide cache of generated tableaux, shared by every YoungTableaux:
# {key : value} = {partition : YoungGenerator holding that partition's tableaux}
_cache = {}

class Young(list):
    """A single Young tableau, stored as a list of rows.

    repr() and str() both render the tableau via mathformat.pprint_tableaux.
    """

    def __str__(self):
        return mathformat.pprint_tableaux(self)

    __repr__ = __str__

class YoungGenerator(list):
    """
    List of all standard Young tableaux of one shape.

    The shape is a partition of N, e.g. (3, 1, 1) for N = 5.  After
    set_tableaux() has been called the instance holds one Young object per
    tableau produced by generate().
    """
    # Build the Young tableaux from a partition of the natural number N.

    __slots__ = (
            'shape',     # the partition (row lengths) this generator works on
            'size',      # number of cells, i.e. the largest number placed
            'tableaux',  # empty grid built by tableaux_initializer()
            )

    def __init__(self, shape, size=0):
        """
        shape -- partition given as a sliceable sequence of row lengths
        size  -- number of cells; when 0 (the default) it is sum(shape)
        """
        self.shape = shape[:]  # copy so later changes to the caller's object do not leak in

        if not size:
            # (3,2,1,1) -> 3 + 2 + 1 + 1 = 7
            self.size = sum(shape)
        else:
            self.size = size

        self.tableaux = self.tableaux_initializer()

    def tableaux_initializer(self):
        """Return an empty grid for self.shape: one row of None per part.
        """
        # (3,1,1) -> [[None, None, None],[None], [None]]
        return [[None] * num for num in self.shape]

    def get_empty_space(self, tableaux):
        """Yield each (row, col) cell where the next number may be put.

        Only the leftmost empty cell of each row is considered.  It is
        yielded when its column is strictly smaller than prev_pos, the
        column of the most recently yielded cell (initially self.size).

        This generator reads *tableaux* lazily; generate() resets every cell
        it fills before resuming it.
        """
        prev_pos = self.size

        for i, row in enumerate(tableaux):
            for j, pos in enumerate(row):
                if pos is not None:
                    # ++   filled cell: keep looking to the right
                    continue
                if j < prev_pos:
                    # ■■
                    # ■□   cell above is filled: candidate
                    prev_pos = j
                    yield (i, j)
                # ■□
                # ■□   otherwise the cell above is still empty: skip it.
                # Either way, only the leftmost empty cell of a row matters,
                # so go down to the next row.
                break

    def __repr__(self):
        return "\n".join([mathformat.pprint_tableaux(diagram) for diagram in self])

    __str__ = __repr__

    def set_tableaux(self):
        """Fill this list with one Young object per generated tableau.

        Previously stored tableaux are discarded first, so calling this more
        than once does not duplicate entries.
        """
        del self[:]
        for diagram in self.generate():
            self.append(Young(diagram))

    def generate(self, tableaux=None, index=1):
        """Yield every tableau of self.shape by backtracking.

        tableaux -- partially filled grid to continue from; when None or
                    empty, a fresh grid from tableaux_initializer() is used
        index    -- the number to place next (1 for a fresh grid)

        Each completed tableau is yielded as an independent deep copy.
        """
        # None replaces the former mutable default ``[]``; an explicit empty
        # list is still accepted and handled the same way.
        if not tableaux:
            tableaux = self.tableaux_initializer()

        for row, col in self.get_empty_space(tableaux):
            # place the number, explore further, then undo it
            tableaux[row][col] = index

            if index == self.size:
                yield copy.deepcopy(tableaux)
            else:
                # The grid is shared with the recursive call (the old
                # shallow copy shared the rows anyway); every cell the call
                # fills is reset before it finishes.
                for diagram in self.generate(tableaux, index + 1):
                    yield diagram

            tableaux[row][col] = None



class YoungTableaux(dict, iterator.Iterator):
    """Standard Young tableaux for a sequence of partitions.

    Maps each partition in seq_of_partition to the YoungGenerator (a list of
    Young tableaux) built for it.  Iterating over an instance yields the
    tableaux themselves, not the partitions.
    """
    __slots__ = (
            'total',             # cached number of tableaux (0 = not cached)
            'square',            # sum of squared per-partition counts, set by report()
            'seq_of_partition',  # partitions whose tableaux are held
            )

    def __init__(self, seq_of_partition):
        self.seq_of_partition = seq_of_partition[:]
        self.total = 0
        self.square = 0

    def __iter__(self):
        # NOTE
        # __iter__ returns the dictionary's values without sorting,
        # so the order is not guaranteed to be the same every time.
        return itertools.chain(*self.values())

    def set_tableaux(self):
        """(Re)build the mapping from seq_of_partition.

        Tableaux come from the module-level _cache when available; otherwise
        they are generated and stored in the cache.
        """
        # The cached statistics describe the previous contents; without this
        # reset, len() kept returning the old total after up()/down().
        self.total = 0
        self.square = 0
        self.clear()

        for part in self.seq_of_partition:
            try:
                generator = _cache[part]
            except KeyError:
                generator = YoungGenerator(part)
                generator.set_tableaux()
                _cache[part] = generator
            self[part] = generator

    def size(self):
        """Return the total number of tableaux over all partitions.

        The result is cached in self.total; a zero total is recomputed.
        """
        # TODO: give this method a better name.
        if not self.total:
            self.total = sum([len(tableaux) for tableaux in self.values()])
        return self.total

    __len__ = size

    def up(self):
        """Up operation: replace the partitions by every shape obtained by
        adding one cell, then rebuild the tableaux."""
        self._replace_partitions(_up)

    def down(self):
        """Down operation: replace the partitions by every shape obtained by
        removing one cell, then rebuild the tableaux."""
        self._replace_partitions(_down)

    def _replace_partitions(self, step):
        """Apply *step* (_up or _down) to every partition, keep the distinct
        results (sorted, then reversed) and renew the tableaux."""
        # dict keys deduplicate the shapes without the deprecated sets module
        uniq = dict.fromkeys([tuple(dia) for dia in func_chain(step, self.seq_of_partition)])

        shapes = [partition.Partition(dia) for dia in uniq.keys()]
        shapes.sort()
        shapes.reverse()
        self.seq_of_partition = shapes

        # Since seq_of_partition changed, the tableaux must be renewed.
        self.set_tableaux()

    def report(self):
        """Print the number of tableaux per partition and its square.

        The grand totals are also stored in self.total and self.square.

        Sample output (partitions of 4):

        partition             number  square
        ----------------------------------------
        (4)                        1       1
        (3,1)                      3       9
        (2,2)                      2       4
        (2,1,1)                    3       9
        (1,1,1,1)                  1       1
        ----------------------------------------
        total                     10      24
        """
        seq_of_part = list(self.keys())
        seq_of_part.sort()
        seq_of_part.reverse()

        print('%-20s%8s%8s' % ('partition', 'number', 'square'))
        print("-" * 40)

        # initialize
        self.square = self.total = 0
        total = squares = 0

        for part in seq_of_part:
            num = len(self[part])
            square = num * num
            total += num
            squares += square
            print('%-20s%8d%8d' % (mathformat.pprint_partition(part).ljust(20), num, square))

        self.square = squares
        self.total = total
        print("-" * 40)
        print('%-20s%8d%8d' % ('total', self.total, self.square))

    def get_partition(self):
        """Return all partitions of the tableaux
        """
        return self.seq_of_partition

    def display(self):
        """Print all the tableaux
        """
        print(self)

    def __repr__(self):
        return self._formatter()

    __str__ = __repr__

    def _formatter(self):
        """Render every partition followed by its tableaux, one per line group.
        """
        parts = list(self.keys())
        parts.sort()
        parts.reverse()

        buff = []
        for part in parts:
            buff.append(repr(part))
            for tableaux in self[part]:
                buff.append(mathformat.pprint_tableaux(tableaux))

        return "\n".join(buff)

def _up(shape):
    """Up operation of diagram"""
    # (2,1,1)
    # -->
    # (3,1,1)
    # (2,2,1)
    # (2,1,1,1)
    shape = list(shape)
    size_of_prev_line = shape[0] + 1

    for index, size_of_cur_line in enumerate(shape):
        if size_of_cur_line < size_of_prev_line:
            new_shape = shape[:]
            new_shape[index] = shape[index] + 1
            yield new_shape
        size_of_prev_line = size_of_cur_line
    else:
        shape.append(1)
        yield shape

def _down(shape):
    """Down operation of diagram"""
    # (2,1,1)
    # -->
    # (1,1,1)
    # (2,1)
    shape = list(shape)
    size_of_prev_line = 0
    shape.reverse()

    for index, size_of_cur_line in enumerate(shape):
        #if size_of_cur_line < size_of_prev_line:
        if size_of_prev_line < size_of_cur_line :
            new_shape = shape[:]
            num = shape[index] -1
            if num == 0:
                new_shape.pop(index)
            else:
                new_shape[index] = num
            #new[index-1] = shape[index-1] - 1
            new_shape.reverse()
            yield new_shape
        size_of_prev_line = size_of_cur_line

def func_chain(func, iterable):
    """Yield, in order, every element of func(item) for each item.

    Like itertools.chain applied to the results of func, but lazy: func is
    called on the next item only after the previous result is exhausted.
    """
    for item in iterable:
        produced = func(item)
        for value in produced:
            yield value

def ispartition(iterable):
    """Test whether *iterable* is a valid partition; return True if it is.

    A valid partition is a sequence of non-negative integers in
    non-increasing order, e.g. (3, 2, 2, 1).  The empty sequence is valid.

    Raises error.PartitionFormatError when an element is not an integer,
    when an element is negative, or when the elements are not in order.
    """
    parts = list(iterable)

    # Check every element's type first, so the order check below never
    # compares an int with a non-int.
    for part in parts:
        if not isinstance(part, int):
            raise error.PartitionFormatError("each element must be integer", iterable)
        if part < 0:
            # a negative part gives a wrong cell count and no tableaux at all
            raise error.PartitionFormatError("each element must be non-negative", iterable)

    for bigger, smaller in zip(parts, parts[1:]):
        if bigger < smaller:
            raise error.PartitionFormatError("partition is not in order", iterable)

    return True

def young_of_sequence(shape):
    """Return a YoungTableaux holding all the tableaux of one shape.

    shape -- a partition given as a list or tuple, e.g. (2, 1, 1)

    Raises error.PartitionFormatError (from ispartition) for an invalid shape.
    """
    # partition based

    # check if the shape is a valid partition.
    ispartition(shape)

    # The shape is used as a dict key (in YoungTableaux and the module
    # cache), so a list such as [2, 1, 1] must become a hashable tuple.
    if isinstance(shape, list):
        shape = tuple(shape)

    yt = YoungTableaux([shape])
    yt.set_tableaux()
    return yt

def young_of_number(number):
    """Return a YoungTableaux holding the tableaux of every partition of
    *number*, as listed by partition.partition(number).
    """
    tableaux = YoungTableaux(partition.partition(number))
    tableaux.set_tableaux()
    return tableaux

def young(*arg):
    """Return all the standard Young tableaux.

    Three calling conventions are accepted:

        young(3)          # every partition of the integer 3
        young((2, 1, 1))  # one partition, given as a tuple or list
        young(2, 1, 1)    # the same partition, parts as separate arguments

    The third form is the most convenient one to type.

    Raises TypeError for any other kind of argument.
    """
    # A single argument is unwrapped; with several arguments the argument
    # tuple itself is treated as the partition.
    target = arg
    if len(arg) == 1:
        target = arg[0]

    if isinstance(target, int):
        return young_of_number(target)
    if isinstance(target, (list, tuple)):
        return young_of_sequence(target)
    raise TypeError("argument must be integer or partition")

