# coding: utf-8
from term import Terminal, event, connect, rule, Rule, Cell
from tkext import Entry, Button, LatchedButton , MultipleEntries
from inst import Logic, Panel, Embed
from pylafii.utils.tk import TableBuilder
from numpy import *
import utils.mpl_tk
import Tkinter
import weakref
import uuid
import re

class SimplePlot(Logic):
    '''
    Minimal plot logic: the ``through`` rule republishes ``sigin`` unchanged,
    and the nested ``Plot`` widget listens to ``through``.
    '''
    sigin = Cell()
    @rule
    def through(self,sigin):
        # Pass-through, so the widget observes a rule output rather than the raw cell.
        return sigin
    class Plot(utils.mpl_tk.BasePlot,object):
        '''Plot widget that redraws whenever its ``sigin`` terminal changes.'''
        @event()
        def sigin(self):
            # Nothing to draw until a value arrives; otherwise unpack it into plot() arguments.
            if self.sigin is not None:
                plot_args,plot_kw = utils.parse_port_args(self.sigin)
                self.plot(*plot_args,**plot_kw)
        def connect(self,logic):
            # Wire logic.through -> self.sigin.
            # (In mplext.BasePlot self comes first, which felt awkward when used together with chef.)
            connect((logic,'through'),(self,'sigin'))
        def plot(self,*args,**kwargs):
            '''Clear the axes and draw ``ax.plot(*args, **kwargs)``; a ValueError from matplotlib aborts silently.'''
            self.clear()
            try:
                self.ax.plot(*args,**kwargs)
            except ValueError:
                return
            else:
                self.canvas.show()
                # Use update_idletasks() so Buttons keep working under threaded processing;
                # update() can make Buttons lock up.
                self.update_idletasks()

class Format(Logic):
    '''
    Base class for the FormatN logics: parses a comma-separated format string
    into callables (for expressions) and literal strings (for quoted items).
    '''
    @classmethod
    def _split_format(cls,format):
        '''
        Split ``format`` on top-level commas only.

        Commas nested inside (), [] or {} or inside a quoted string do not split,
        so ``$a[0,1]``, ``$a[0]`` followed by more items, ``max($a,$b)`` and
        ``"x, y"`` each stay a single expression.
        '''
        exprs = []
        current = []
        depth = 0
        quote = None
        for ch in format:
            if quote is not None:
                if ch == quote:
                    quote = None
            elif ch in '"\'':
                quote = ch
            elif ch in '([{':
                depth += 1
            elif ch in ')]}':
                depth = max(depth - 1,0) # tolerate a stray closing bracket
            elif ch == ',' and depth == 0:
                exprs.append(''.join(current))
                current = []
                continue
            current.append(ch)
        exprs.append(''.join(current))
        return exprs
    def eval_format(self,args,format):
        '''
        Parse ``format`` into a list of plot-argument producers.

        args   -- comma-separated lambda parameter names, e.g. 'a,b'.
        format -- comma-separated items, e.g. '$a,$b * 2,"g-"'.

        Returns a list whose items are either ``(args.split(','), func)`` for
        expression items (``$name`` is replaced by ``name`` and the item becomes
        the body of ``lambda <args>: ...``), or a plain string for items
        containing quotes (with the quotes removed).  Empty items are skipped.
        '''
        rst = []
        for expr in self._split_format(format):
            if not expr.strip(): # e.g. a trailing comma; previously a SyntaxError in eval
                continue
            if expr.find('"') == -1 and expr.find('\'') == -1:
                # NOTE(review): eval() compiles the format string; only use trusted formats.
                func = eval('lambda %s:%s' % (args,expr.replace('$','')))
                rst.append((args.split(','),func))
            else:
                rst.append(expr.replace('"','').replace('\'',''))
        return rst

class Format1(Format):
    '''Format logic with one input cell ``a``; ``sigout`` is the tuple described by ``format``.'''
    a = Cell()
    format = Cell('$a')
    @rule(ignore='format')
    def sigout(self,a):
        # eval_format yields (arg_names, func) for expressions but plain strings for
        # quoted items (e.g. "g-"); pass the strings through instead of calling them.
        return tuple([item[1](a) if isinstance(item,tuple) else item
                      for item in self.eval_format('a',self.format)])

class Format2(Format):
    '''Format logic with input cells ``a`` and ``b``; ``sigout`` is the tuple described by ``format``.'''
    a = Cell()
    b = Cell()
    format = Cell('$a,$b')
    @rule(ignore='format')
    def sigout(self,a,b):
        # Quoted format items come back as plain strings and must not be called.
        return tuple([item[1](self.a,self.b) if isinstance(item,tuple) else item
                      for item in self.eval_format('a,b',self.format)])

class Format3(Format):
    '''Format logic with input cells ``a``, ``b``, ``c``; ``sigout`` is the tuple described by ``format``.'''
    a = Cell()
    b = Cell()
    c = Cell()
    format = Cell('$a,$b,$c')
    @rule(ignore='format')
    def sigout(self,a,b,c):
        # Quoted format items come back as plain strings and must not be called.
        return tuple([item[1](self.a,self.b,self.c) if isinstance(item,tuple) else item
                      for item in self.eval_format('a,b,c',self.format)])

class Format4(Format):
    '''Format logic with input cells ``a``..``d``; ``sigout`` is the tuple described by ``format``.'''
    a = Cell()
    b = Cell()
    c = Cell()
    d = Cell()
    format = Cell('$a,$b,$c,$d')
    @rule(ignore='format')
    def sigout(self,a,b,c,d):
        parsed = self.eval_format('a,b,c,d',self.format)
        # Expression items are (arg_names, func) tuples; literal items are passed through as-is.
        return tuple([entry[1](self.a,self.b,self.c,self.d) if isinstance(entry,tuple) else entry
                      for entry in parsed])

class FormatXY(Format):
    '''
    Format logic with input cells ``x`` and ``y``; ``sigout`` is the tuple
    described by ``format``.

    Derives from Format (previously Logic, which has no ``eval_format``), and
    the rule parameters now name the real cells ``x``/``y``, like the other
    FormatN classes do with ``a``..``d``.
    '''
    x = Cell()
    y = Cell()
    format = Cell('$x,$y')
    @rule(ignore='format')
    def sigout(self,x,y):
        # Quoted format items come back as plain strings and must not be called.
        return tuple([item[1](self.x,self.y) if isinstance(item,tuple) else item
                      for item in self.eval_format('x,y',self.format)])

class BasePlot(Logic):
    '''
    Logic side of a matplotlib plot.

    ``value`` holds the data to plot and is republished by the ``result`` rule;
    ``xlim``/``ylim`` are connected to the Plot widget's axis limits; after
    ``Plot.connect`` runs, ``base_plot`` holds a weakref to the Plot widget.
    The nested ``Plot``, ``Control`` and ``Config.Decolation`` classes are the
    GUI parts (drawing area, axis-limit controls, decoration settings).
    '''
    value = Terminal()
    xlim = Terminal()
    ylim = Terminal()
    base_plot = Terminal() # weakref.ref to the Plot widget, assigned in Plot.connect
    @rule('value')
    def result(self): return self.value
    @property
    def plot(self): return self.base_plot()
    class Plot(utils.mpl_tk.BasePlot,object):
        '''Tk/matplotlib widget that draws ``value`` and publishes its axis limits as xlim/ylim.'''
        xlim = Terminal()
        ylim = Terminal()
        @event()
        def value(self):
            if self.value is None: return
            args,kw = utils.parse_port_args(self.value)
            self._plot(*args,**kw)
        def connect(self,logic):
            for name in ['xlim','ylim']: connect((self,name),(logic,name))
            for src,dest in [('value','result')]: connect((self,src),(logic,dest))
            # Only a weak reference, so the logic does not keep the widget alive.
            logic.base_plot = weakref.ref(self)
        def __init__(self,master=None,figsize=(5,4),dpi=75,align=111,projection='rectilinear',cnf={},**kw):
            utils.mpl_tk.BasePlot.__init__(self,master,figsize,dpi,align,projection,cnf,**kw)
            # Publish the initial limits, then start listening for limit changes.
            self._update_xlim(None)
            self._update_ylim(None)
            self._connect()
        def _connect(self):
            #
            # Register the axes callbacks that push xlim/ylim changes to the terminals.
            #
            self._callbacks = cb = {}
            cb['cid_xlim']  = self.ax.callbacks.connect('xlim_changed',self._update_xlim)
            cb['cid_ylim']  = self.ax.callbacks.connect('ylim_changed',self._update_ylim)
        def _update_xlim(self,e): # when the axes xlim changes, update the control GUI
            # While this flag exists, Control._changed_xlim ignores the change,
            # so the update does not flow back into the axes.
            self._update_xlim_called = True
            try:
                self.xlim = self.ax.get_xlim()
            finally:
                # Always clear the guard, even if a downstream handler raises;
                # otherwise Control would silently ignore every later xlim edit.
                del self._update_xlim_called
        def _update_ylim(self,e): # when the axes ylim changes, update the control GUI
            self._update_ylim_called = True
            try:
                self.ylim = self.ax.get_ylim()
            finally:
                del self._update_ylim_called
        def _disconnect(self):
            self.ax.callbacks.disconnect(self._callbacks['cid_xlim'])
            self.ax.callbacks.disconnect(self._callbacks['cid_ylim'])
        def destroy(self):
            self._disconnect()
            utils.mpl_tk.BasePlot.destroy(self)
        def _plot(self,*args,**kwargs):
            '''Clear and redraw with ``ax.plot(*args, **kwargs)``; a ValueError aborts the redraw silently.'''
            try:
                self.clear()
                self.ax.plot(*args,**kwargs)
            except ValueError:
                return
            self.canvas.show()
            self.update_idletasks() # with threaded processing use update_idletasks(); update() can lock up Buttons
        def resize(self,dpi=None,figsize=None):
            '''
            Resize the figure to ``dpi``/``figsize`` (no-op unless both are given),
            keeping the subplot margins, title and axis labels, then replot.
            '''
            if (dpi is None or figsize is None): return
            subplotpars = self.ax.figure.subplotpars
            left   = subplotpars.left
            right  = subplotpars.right
            bottom = subplotpars.bottom
            top    = subplotpars.top
            title  = self.ax.get_title()
            xlabel = self.ax.get_xlabel()
            ylabel = self.ax.get_ylabel()
            # Re-register the limit callbacks around the base-class resize
            # (NOTE(review): presumably because it may replace self.ax).
            self._disconnect()
            utils.mpl_tk.BasePlot.resize(self,dpi=dpi,figsize=figsize)
            self._connect()
            self.ax.set_ylabel(ylabel)
            self.ax.set_xlabel(xlabel)
            self.ax.set_title(title)
            self.ax.figure.subplots_adjust(left = left, right = right, bottom = bottom, top = top)
            # Re-assign value so the value event presumably fires and the data is redrawn.
            self.value = self.value
    class Control(Panel):
        '''Axis-limit editor: X/Y limit entries plus "auto" latch buttons for autoscaling.'''
        base_plot = Terminal()
        @event([0.,1.])
        def xlim(self): self._changed_xlim()
        @event([0.,1.])
        def ylim(self): self._changed_ylim()
        @event(True)
        def autoscalex_on(self): self._changed_autoscalex()
        @event(True)
        def autoscaley_on(self): self._changed_autoscaley()
        def connect(self,logic):
            plot = logic.base_plot() # the Plot pane must be connected before this Control
            autoscalex_on = plot.ax.get_autoscalex_on() # connecting xlim changes the autoscale state, so back up the button states first
            autoscaley_on = plot.ax.get_autoscaley_on()
            for name in ['base_plot','xlim','ylim']: connect((logic,name),(self,name))
            self.autoscalex_on = autoscalex_on
            self.autoscaley_on = autoscaley_on
        def __init__(self,master=None,cnf={},**kw):
            Panel.__init__(self,master,cnf,**kw)
            # Row 1: X axis limits and autoscale latch
            Tkinter.Label(self,text='X Axis').grid(row=1,column=0)
            MultipleEntries(self,column=2,name='xlim').grid(row=1,column=1)
            LatchedButton(self,text='auto',name='autoscalex_on',positive=False).grid(row=1,column=3)
            # Row 2: Y axis limits and autoscale latch
            Tkinter.Label(self,text='Y Axis').grid(row=2,column=0)
            MultipleEntries(self,column=2,name='ylim').grid(row=2,column=1)
            LatchedButton(self,text='auto',name='autoscaley_on',positive=False).grid(row=2,column=3)
            for name in ['xlim','ylim','autoscalex_on','autoscaley_on']:
                connect((self,name),(self.children[name],'value'))
        def _changed_xlim(self):
            '''
            When the xlim entry changes: apply it to the axes (matplotlib's
            set_xlim turns autoscalex off), sync the autoscalex button with the
            axes state, and redraw the canvas.
            '''
            if self.base_plot is None: return # ignore until a Plot widget is connected
            plot = self.base_plot()
            if plot is None: return # the Plot widget has already been garbage-collected
            if hasattr(plot,'_update_xlim_called'): return # change came from the plot's own callback; don't echo it back
            plot.ax.set_xlim(*self.xlim) # update xlim
            self.autoscalex_on = plot.ax.get_autoscalex_on() # update the button state
            plot.canvas.show() # redraw the canvas
        def _changed_ylim(self):
            '''Same as _changed_xlim, for the Y axis.'''
            if self.base_plot is None: return
            plot = self.base_plot()
            if plot is None: return
            if hasattr(plot,'_update_ylim_called'): return
            plot.ax.set_ylim(*self.ylim)
            self.autoscaley_on = plot.ax.get_autoscaley_on()
            plot.canvas.show()
        def _changed_autoscalex(self):
            '''
            When the autoscalex button is toggled: apply the setting to the
            axes and replot.
            '''
            if self.base_plot is None: return
            plot = self.base_plot()
            if plot is None: return
            plot.ax.set_autoscalex_on(self.autoscalex_on)
            BasePlot.Plot.value.rule(plot) # re-run Plot's value handler so the data is replotted
        def _changed_autoscaley(self):
            '''Same as _changed_autoscalex, for the Y axis.'''
            if self.base_plot is None: return
            plot = self.base_plot()
            if plot is None: return
            plot.ax.set_autoscaley_on(self.autoscaley_on)
            BasePlot.Plot.value.rule(plot)
    class Config:
        class Decolation(Panel):
            '''
            Figure decoration settings: subplot margins (PADL/PADR/PADB/PADT),
            title and axis labels, DPI and figure size, with "reset" buttons
            that restore the values captured in connect().
            '''
            base_plot = Terminal()
            LAYOUT = [
                      (Tkinter.Label,{'text':'PADL'}),(Entry,{'name':'padl'}),
                      (Tkinter.Label,{'text':'PADR'}),(Entry,{'name':'padr'}),
                      (Tkinter.Label,{'text':'PADB'}),(Entry,{'name':'padb'}),
                      (Tkinter.Label,{'text':'PADT'}),(Entry,{'name':'padt'}),
                      (Button,{'name':'reset_pad','text':'reset'},{'columnspan':2}),
                      (Tkinter.Label,{'text':'TITLE'}),(Entry,{'name':'title'}),
                      (Tkinter.Label,{'text':'XLABEL'}),(Entry,{'name':'xlabel'}),
                      (Tkinter.Label,{'text':'YLABEL'}),(Entry,{'name':'ylabel'}),
                      (Button,{'name':'reset_label','text':'reset'},{'columnspan':2}),
                      (Tkinter.Label,{'text':'DPI'}),(Entry,{'name':'dpi'}),
                      (Tkinter.Label,{'text':'FIGSIZE'}),(MultipleEntries,{'name':'figsize'}),
                      ]
            def __init__(self,master=None,cnf={},**kw):
                Panel.__init__(self,master,cnf,**kw)
                # Handlers ignore events until connect() has filled in the widgets.
                self.__initializing = True
            def connect(self,logic):
                connect((logic,'base_plot'),(self,'base_plot'))
                if self.base_plot is None: return
                plot = self.base_plot()
                subplotpars = plot.ax.figure.subplotpars
                # Snapshot of the current decoration, used by the reset buttons.
                self._config = {
                                'left'   : subplotpars.left,
                                'right'  : subplotpars.right,
                                'bottom' : subplotpars.bottom,
                                'top'    : subplotpars.top,
                                'title'  : plot.ax.get_title(),
                                'xlabel' : plot.ax.get_xlabel(),
                                'ylabel' : plot.ax.get_ylabel(),
                                'figsize' : plot.figure.get_size_inches(),
                                'dpi'     : plot.figure.get_dpi(),
                                }
                self.children['padl'].value = subplotpars.left
                self.children['padr'].value = subplotpars.right
                self.children['padb'].value = subplotpars.bottom
                self.children['padt'].value = subplotpars.top
                self.children['title'].value = plot.ax.get_title()
                self.children['xlabel'].value = plot.ax.get_xlabel()
                self.children['ylabel'].value = plot.ax.get_ylabel()
                self.children['figsize'].value = self._config['figsize']
                self.children['dpi'].value = self._config['dpi']
                for name in ['padl','padr','padb','padt','reset_pad','reset_label','title','xlabel','ylabel','figsize','dpi']:
                    connect((self.children[name],'value'),(self,name))
                self.__initializing = False
            @event()
            def padl(self): self._update_adjust()
            @event()
            def padr(self): self._update_adjust()
            @event()
            def padb(self): self._update_adjust()
            @event()
            def padt(self): self._update_adjust()
            @event()
            def reset_pad(self): self._reset_pad()
            @event()
            def reset_label(self): self._reset_label()
            @event()
            def figsize(self): self._resize()
            @event()
            def dpi(self): self._resize()
            @event()
            def title(self):
                if self.__initializing: return
                plot = self.base_plot()
                plot.ax.set_title(self.title)
                plot.canvas.show()
            @event()
            def xlabel(self):
                if self.__initializing: return
                plot = self.base_plot()
                plot.ax.set_xlabel(self.xlabel)
                plot.canvas.show()
            @event()
            def ylabel(self):
                if self.__initializing: return
                plot = self.base_plot()
                plot.ax.set_ylabel(self.ylabel)
                plot.canvas.show()
            def _resize(self):
                # Plot.resize is a no-op unless both dpi and figsize are set.
                if self.__initializing: return
                figsize = self.figsize
                dpi     = self.dpi
                self.base_plot().resize(dpi,figsize)
            def _update_adjust(self):
                # Apply the four margin entries to the figure and redraw.
                if self.__initializing: return
                plot = self.base_plot()
                plot.ax.figure.subplots_adjust(left   = self.padl,
                                               right  = self.padr,
                                               bottom = self.padb,
                                               top    = self.padt)
                plot.canvas.show()
            def _reset_pad(self):
                # Restore the margins captured in connect() (each assignment presumably fires its pad event).
                if self.__initializing: return
                self.padl = self._config['left']
                self.padr = self._config['right']
                self.padb = self._config['bottom']
                self.padt = self._config['top']
            def _reset_label(self):
                # Restore the title and axis labels captured in connect().
                if self.__initializing: return
                self.title = self._config['title']
                self.xlabel = self._config['xlabel']
                self.ylabel = self._config['ylabel']
                
class BasePlotPanel(Panel):
    '''Panel that embeds a single BasePlot under the name 'result'.'''
    def layout(self):
        plot_embed = Embed(self,BasePlot,name='result')
        plot_embed.pack()
        
class ScalarTable(Logic):
    '''
    Turn a 2-D ``value`` table into plot arguments: column 0 as x, the
    remaining columns as y, followed by the ``format`` string (default '-').
    '''
    value = Terminal()
    format = Terminal('-')
    @rule('value')
    def result(self):
        if self.value is not None:
            return (self.value[:,0],self.value[:,1:],self.format)
        # No table yet: return a dummy x=[0, 1], y=[0, 1] pair.
        return (array([0.,1.]),array([0.,1.]))
        
class MatrixPlot(BasePlot):
    '''
    BasePlot fed from a 2-D table: ``value`` and ``result`` are both connected
    to an internal ScalarTable child named 'scalartable'.
    '''
    result = Terminal()
    def __init__(self,master=None,name=None):
        BasePlot.__init__(self,master,name)
        table = ScalarTable(self,name='scalartable')
        connect((self,'value'),(table,'value'))
        connect((self,'result'),(table,'result'))
        
class MatrixPlotPanel(Panel):
    '''Panel that embeds a single MatrixPlot under the name 'result'.'''
    def layout(self):
        matrix_embed = Embed(self,MatrixPlot,name='result')
        matrix_embed.pack()
        
class PlotPanel(Panel):
    '''Panel that embeds one plot logic, of class ``Plot`` (overridable in subclasses), as 'result'.'''
    Plot = BasePlot
    def layout(self):
        embedded = Embed(self,self.Plot,name='result')
        embedded.pack()

class DefaultPlotPanel(Panel):
    '''Placeholder plot panel; adds nothing to Panel.'''
#pylafii_config['equipment_option'][DefaultPlotPanel] = {'panel_klass':BasePlot}

#                               'sequence': ('azimuth','farfield','theory'),
#                               'geometry': 'pack',
#                               'geometry_option': {'default': {'side': Tkinter.LEFT},
#                                                   'theory': {'side':Tkinter.LEFT}},
#                               'klass': {'theory':BasePlot},
#                               'format': {'theory':'$theory_ang / pi * 180,20 * log10(abs($theory)),"g-"'},
#                               'set': {'theory': {
#                                                  'xlabel':'Azimuth Angle[deg]',
#                                                  'ylabel':'Relative Receiving Power [dB]',
#                                                  'xlim':[0.,180,],
#                                                  'ylim':[-30.,0.],
#                                                  },},

class PlotPanelFactory(Panel):
    '''
    Panel that embeds one plot logic per name in ``sequence``.

    Options:
      sequence         -- a single name, or a tuple/list of names, of the plots to embed.
      geometry_manager -- 'pack' or 'grid'; used to place each embedded plot.
      geometry_option  -- {name: geometry kwargs}.  The optional 'default' entry is
                          used for names without their own entry; without it the
                          default is side=LEFT for 'pack' and no options for 'grid'.
      klass            -- {name: logic class to embed}; unlisted names use BasePlot.
      format           -- {name: MPL_Parser format string}, e.g. '$x,$y,"g-"'; a
                          '__parser__' child then feeds the plot logic's ``value``.
      set              -- {name: {key: value}}, applied as ax.set_<key>(value).

    Raises ValueError if ``geometry_manager`` is not 'pack' or 'grid'.
    '''
    def __init__(self,
                 master=None,
                 sequence=('result',),
                 geometry_manager='pack',
                 geometry_option=None,
                 klass=None,
                 format=None,
                 set=None,
                 cnf={},
                 **kw
                 ):
        # Validate before any widget is created (previously this printed
        # 'ERROR!!' and terminated the whole process with exit()).
        if geometry_manager not in ('pack','grid'):
            raise ValueError("geometry_manager must be 'pack' or 'grid', not %r" % (geometry_manager,))
        Panel.__init__(self,master,cnf,**kw)
        #
        # Option defaults.  Dicts are copied so that neither the caller's
        # arguments nor values shared between instances are mutated.
        #
        if geometry_option is None:
            geometry_option = {'default':{'side':Tkinter.LEFT}}
        geometry_option = dict(geometry_option)
        klass = dict(klass or {})
        format = format or {}
        set = set or {}
        if isinstance(sequence,(tuple,list)):
            sequence = tuple(sequence)
        else: # a single name was given
            sequence = (sequence,)
        for name in sequence:
            klass.setdefault(name,BasePlot)
        if 'default' in geometry_option:
            default = geometry_option.pop('default')
        elif geometry_manager == 'pack':
            default = {'side':Tkinter.LEFT}
        else: # 'grid'
            default = {}
        for name in sequence:
            geometry_option.setdefault(name,default)
        #
        # Create the embedded plot instances.
        #
        for name in sequence:
            o = Embed(self,klass[name],name=name)
            getattr(o,geometry_manager)(**geometry_option[name])
        #
        # 'format' option: attach a generated parser whose result drives the plot's value.
        #
        for name in sequence:
            if format.get(name): # skip missing or empty format strings
                emb = self.children[name]
                parser_klass = MPL_Parser.factory(format[name])
                o = parser_klass(emb.logic,name='__parser__')
                connect((o,'_mpl_parser_result'),(emb.logic,'value'))
        #
        # 'set' option: call ax.set_<key>(value) on each embedded plot's axes.
        #
        for name in sequence:
            if name in set:
                for k,v in set[name].iteritems():
                    getattr(self.children[name].panel['plot'].ax,'set_%s' % k)(v)
    def connect(self,logic):
        '''
        Feed ``logic``'s terminals into the embedded plots.

        For a plot with a format parser, every ``$name`` in its format is
        connected from ``logic.name``.  Otherwise ``logic.<plot name>`` is
        connected to the plot logic's ``value``, provided that attribute is
        defined directly on logic's class.
        '''
        for k,emb in self.children.iteritems():
            if '__parser__' in emb.logic.children:
                prsr = emb.logic.children['__parser__']
                for tname in prsr.extract_tnames(prsr._format):
                    connect((logic,tname),(prsr,tname))
            elif k in logic.__class__.__dict__:
                # NOTE(review): only logic's own class dict is checked; inherited terminals are ignored.
                connect((logic,k),(emb.logic,'value'))
    @rule()
    def trig(self):
        # Read every parser result; NOTE(review): presumably this forces pending parsers to evaluate.
        for k,emb in self.children.iteritems():
            if '__parser__' in emb.logic.children:
                emb.logic.children['__parser__']._mpl_parser_result
    def _polling(self):
        # Same reads as trig(), rescheduled every 100 ms via Tk's after().
        for k,emb in self.children.iteritems():
            if '__parser__' in emb.logic.children:
                emb.logic.children['__parser__']._mpl_parser_result
        self.id = self.after(100,self._polling)
        
class MPL_Parser(Logic):
    '''
    Base class for per-format parser logics created by ``factory``.

    A format string is a comma-separated list of plot arguments, e.g.
    ``'$ang / pi * 180,20 * log10(abs($gain)),"g-"'``.  Unquoted items are
    Python expressions in which ``$name`` refers to the terminal ``name``;
    items containing quotes become literal strings (quotes removed).
    ``_mpl_parser_result`` evaluates to the resulting tuple.

    NOTE(review): items are compiled with eval(); only use trusted format strings.
    '''
    @classmethod
    def extract_tnames(cls,format):
        '''Return the distinct terminal names referenced as ``$name`` in ``format``, in order of first appearance.'''
        # '$' followed by a Python identifier (the old pattern also accepted '-',
        # which turned e.g. '$a-$b' into the unusable name 'a-').
        variable = re.compile(r'\$([a-zA-Z_][a-zA-Z0-9_]*)')
        names = []
        for name in variable.findall(format):
            if name not in names:
                names.append(name)
        return names
    @classmethod
    def factory(cls,format):
        '''Create a uniquely named subclass with one Terminal per referenced name and a rule for the result.'''
        dct = {}
        for name in cls.extract_tnames(format):
            dct[name] = Terminal(None)
        dct['_format'] = format
        dct['_mpl_parser_result'] = rule(cls._mpl_parser_result)
        return type('MPL_Parser_%s' % uuid.uuid4(),(cls,),dct)
    def __init__(self,master=None,name=None):
        Logic.__init__(self,master,name)
        self._initializing = None # makes the first evaluation below return early
        self._mpl_parser_result # create the terminal
        dep = []
        for name in self.__class__.extract_tnames(self._format):
            dep.append(name) # prepare the dependency update
            self.__dict__[name].node.getlog[self.__dict__['_mpl_parser_result']] = None # adjust the update flag
        self.__dict__['_mpl_parser_result'].term.dependency = tuple(dep) # update the dependencies
    #@rule
    def _mpl_parser_result(self):
        '''
        Evaluate the format against the current terminal values and return the
        tuple of plot arguments.

        Designed to run after the related rules: if any rule observing one of
        the referenced terminals is still being evaluated, parsing is aborted
        and None is returned.
        '''
        # The first evaluation (triggered from __init__) only creates the terminal.
        try: self._initializing
        except AttributeError: pass
        else:
            del self._initializing
            return
        # For every terminal referenced in the format string ...
        for item in self.__class__.extract_tnames(self._format):
            for attr in self.__dict__[item].node.observers_: # ... each attribute observing that terminal's node ...
                if isinstance(attr.term,Rule): # ... that is a rule ...
                    if attr in attr.evaluating_rule: # ... and whose rule is currently running: do nothing and bail out
                        return
        res = []
        for expr in self.eval_format(self._format):
            if isinstance(expr,tuple):
                args = [getattr(self,name) for name in expr[0]]
                res.append(expr[1](*args))
            else:
                res.append(expr)
        return tuple(res)
    @classmethod
    def _split_format(cls,format):
        '''
        Split ``format`` on top-level commas only.

        Commas nested inside (), [] or {} or inside a quoted string do not split,
        so ``$a[0,1]``, ``$a[0]`` followed by more items, ``max($a,$b)`` and
        ``"x, y"`` each stay a single item.
        '''
        exprs = []
        current = []
        depth = 0
        quote = None
        for ch in format:
            if quote is not None:
                if ch == quote:
                    quote = None
            elif ch in '"\'':
                quote = ch
            elif ch in '([{':
                depth += 1
            elif ch in ')]}':
                depth = max(depth - 1,0) # tolerate a stray closing bracket
            elif ch == ',' and depth == 0:
                exprs.append(''.join(current))
                current = []
                continue
            current.append(ch)
        exprs.append(''.join(current))
        return exprs
    def eval_format(self,format):
        '''
        Parse ``format`` into a list whose items are either
        ``(arg_names, func)`` for expressions (call ``func`` with the values of
        ``arg_names``) or a literal string for quoted items (quotes removed).
        Empty items are skipped.
        '''
        variable = re.compile(r'\$([a-zA-Z_][a-zA-Z0-9_]*)')
        rst = []
        for expr in self._split_format(format):
            if not expr.strip(): # e.g. a trailing comma; previously a SyntaxError in eval
                continue
            if expr.find('"') == -1 and expr.find('\'') == -1:
                names = []
                for name in variable.findall(expr):
                    if name not in names: # '$x*$x' must not produce 'lambda x,x:' (SyntaxError)
                        names.append(name)
                # With no variables, names is [] and the lambda takes no arguments
                # (previously [''] was produced and getattr(self,'') failed).
                func = eval('lambda %s:%s' % (','.join(names),expr.replace('$','')))
                rst.append((names,func))
            else:
                rst.append(expr.replace('"','').replace('\'',''))
        return rst
    