# coding: utf-8
from numpy import concatenate, real, imag
from matplotlib import rcParams
from matplotlib.axes import Axes
from matplotlib.path import Path
import matplotlib.axis as maxis
from matplotlib.transforms import Affine2D, Transform, BboxTransformTo
from matplotlib.projections import register_projection
import matplotlib.spines as mspines

from matplotlib.artist import allow_rasterization
from matplotlib.patches import Circle
import matplotlib.ticker as mticker
import matplotlib.transforms as mtransforms
# Presumably the number of interpolation steps for curved grid lines.
# NOTE(review): not referenced anywhere in the code visible in this module;
# it appears to mirror the constant of the same name in matplotlib.axis.
# Confirm whether it is still needed before relying on it.
GRIDLINE_INTERPOLATION_STEPS = 180

class MyXTick(maxis.XTick):
    """X tick whose grid line covers only part of the y (reactance) axis.

    The x-axis grid transform built in ``SmithAxes._set_lim_and_transforms``
    stretches the grid line's y values via ``y -> 20*y - 10``.  This tick
    applies the inverse, ``(y + 10) / 20``, so that the grid line runs from
    ``rng[0]`` to ``rng[1]`` in data units.
    """

    def update_position(self, loc, rng=(0., 1.)):
        """Move the tick to *loc* and clip its grid line to the y span *rng*."""
        maxis.XTick.update_position(self, loc)
        if not self.gridOn:
            return
        start, stop = rng[0], rng[1]
        self.gridline.set_ydata(((start + 10) / 20., (stop + 10) / 20.))

class MyYTick(maxis.YTick):
    """Y tick whose grid line covers only part of the x (resistance) axis.

    The y-axis grid transform built in ``SmithAxes._set_lim_and_transforms``
    stretches the grid line's x values via ``x -> 10*x``.  This tick applies
    the inverse, ``x / 10``, so that the grid line runs from ``rng[0]`` to
    ``rng[1]`` in data units.
    """

    def update_position(self, loc, rng=(0., 1.)):
        """Move the tick to *loc* and clip its grid line to the x span *rng*."""
        maxis.YTick.update_position(self, loc)
        if not self.gridOn:
            return
        start, stop = rng[0], rng[1]
        self.gridline.set_xdata((start / 10., stop / 10.))

class SmithAxis:
    """Mixin for matplotlib axes whose grid lines span per-tick ranges.

    Each tick is given as a triple ``(loc, start, stop)``:

    - ``loc`` is the tick location on this axis.
    - ``(start, stop)`` is the interval, in data units of the *other* axis,
      that the tick's grid line should cover.

    The ranges are stored in ``majorRngs`` / ``minorRngs``, parallel to the
    locator's locations, and passed to ``tick.update_position(loc, rng=...)``
    in :meth:`draw`.

    It is listed before a matplotlib Axis class in the bases (see MyXAxis /
    MyYAxis). It relies on that class for ``convert_units``,
    ``get_major_ticks``, ``set_major_locator``, etc.
    """

    def __init__(self):
        # Grid-line ranges, parallel to the major / minor tick locations.
        self.majorRngs = []
        self.minorRngs = []

    def set_ticks(self, ticks, minor=False):
        """
        Set the tick locations and the span of each tick's grid line.

        ACCEPTS: sequence of (loc, start, stop) float triples

        Parameters
        ----------
        ticks : iterable of 3-sequences
            ``(loc, start, stop)`` for each tick. ``(start, stop)`` is the
            range along the other data axis covered by that tick's grid line.
        minor : bool, optional
            If True, set the minor ticks instead of the major ones.

        Returns
        -------
        list
            The major (or minor) tick objects, one per location.

        Raises
        ------
        ValueError
            If an entry of *ticks* does not unpack into exactly three values.
        """
        ### XXX if the user changes units, the information will be lost here
        # Materialize once so that generators work with the two passes below.
        triples = list(ticks)
        locs = [loc for loc, _, _ in triples]
        rngs = [(start, stop) for _, start, stop in triples]
        locs = self.convert_units(locs)
        if len(locs) > 1:
            # Update the view interval from the tick extent while keeping the
            # current axis orientation (normal or inverted).
            lo, hi = self.get_view_interval()
            if hi > lo:
                self.set_view_interval(min(locs), max(locs))
            else:
                self.set_view_interval(max(locs), min(locs))
        if minor:
            self.set_minor_locator(mticker.FixedLocator(locs))
            self.minorRngs = rngs
            return self.get_minor_ticks(len(locs))
        self.set_major_locator(mticker.FixedLocator(locs))
        self.majorRngs = rngs
        return self.get_major_ticks(len(locs))

    def _smith_tick_group(self, ticker, get_ticks, rngs):
        """Return ``(ticks, locs, rngs, labels)`` for one major/minor ticker."""
        locs = ticker.locator()
        ticks = get_ticks(len(locs))
        ticker.formatter.set_locs(locs)
        labels = [ticker.formatter(val, i) for i, val in enumerate(locs)]
        return ticks, locs, rngs, labels

    def iter_ticks(self):
        """
        Iterate through all of the major and minor ticks.

        Yields ``(tick, loc, rng, label)`` tuples, major ticks first. Each
        group is zipped, so a group is truncated to its shortest member (for
        example, if the range list is shorter than the locator's output).
        """
        groups = [
            self._smith_tick_group(self.major, self.get_major_ticks,
                                   self.majorRngs),
            self._smith_tick_group(self.minor, self.get_minor_ticks,
                                   self.minorRngs),
        ]
        for group in groups:
            for item in zip(*group):
                yield item

    @allow_rasterization
    def draw(self, renderer, *args, **kwargs):
        """Draw the axis lines, grid lines, tick lines and labels.

        Each tick's grid-line range from :meth:`iter_ticks` is forwarded to
        ``tick.update_position`` so that the grid line can be clipped to it.
        """
        ticklabelBoxes = []
        ticklabelBoxes2 = []

        if not self.get_visible():
            return
        renderer.open_group(__name__)
        interval = self.get_view_interval()
        for tick, loc, rng, label in self.iter_ticks():
            if tick is None:
                continue
            # Skip ticks outside the current view interval.
            if not mtransforms.interval_contains(interval, loc):
                continue
            tick.update_position(loc, rng=rng)
            tick.set_label1(label)
            tick.set_label2(label)
            tick.draw(renderer)
            if tick.label1On and tick.label1.get_visible():
                extent = tick.label1.get_window_extent(renderer)
                ticklabelBoxes.append(extent)
            if tick.label2On and tick.label2.get_visible():
                extent = tick.label2.get_window_extent(renderer)
                ticklabelBoxes2.append(extent)

        # Position the axis label and the offset text using the tick-label
        # extents collected above.
        self._update_label_position(ticklabelBoxes, ticklabelBoxes2)

        self.label.draw(renderer)

        self._update_offset_text_position(ticklabelBoxes, ticklabelBoxes2)
        self.offsetText.set_text(self.major.formatter.get_offset())
        self.offsetText.draw(renderer)

        renderer.close_group(__name__)

class MyXAxis(SmithAxis, maxis.XAxis):
    """X axis that builds MyXTick ticks and tracks per-tick grid ranges."""

    def __init__(self, axes, pickradius=15):
        # Initialize the matplotlib XAxis first, then create the SmithAxis
        # range lists (majorRngs / minorRngs).
        maxis.XAxis.__init__(self, axes, pickradius=pickradius)
        SmithAxis.__init__(self)

    def _get_tick(self, major):
        """Create a MyXTick at location 0 with an empty label."""
        initial_loc, initial_label = 0, ''
        return MyXTick(self.axes, initial_loc, initial_label, major=major)

class MyYAxis(SmithAxis, maxis.YAxis):
    """Y axis that builds MyYTick ticks and tracks per-tick grid ranges."""

    def __init__(self, axes, pickradius=15):
        # Initialize the matplotlib YAxis first, then create the SmithAxis
        # range lists (majorRngs / minorRngs).
        maxis.YAxis.__init__(self, axes, pickradius=pickradius)
        SmithAxis.__init__(self)

    def _get_tick(self, major):
        """Create a MyYTick at location 0 with an empty label."""
        initial_loc, initial_label = 0, ''
        return MyYTick(self.axes, initial_loc, initial_label, major=major)

class SmithAxes(Axes):
    """Axes drawing data on a Smith chart (projection name ``'smith'``).

    Data coordinates ``(x, y)`` are read as the complex value ``z = x + 1j*y``
    (normalized resistance x, reactance y). They are mapped onto the unit
    disc with ``g = (z - 1) / (z + 1)``.
    """
    name = 'smith'

    # Constant-resistance grid: (resistance, reactance_start, reactance_end).
    # Each circle is drawn only over its reactance interval.
    _RESISTANCE_GRID = (
        (.1, -1., 1.), (.2, -2., 2.), (.3, -1., 1.),
        (.4, -2., 2.), (.5, -1., 1.), (.6, -2., 2.), (.7, -1., 1.),
        (.8, -2., 2.), (.9, -1., 1.), (1., -5., 5.), (1.2, -2., 2.),
        (1.4, -2., 2.), (1.6, -2., 2.), (1.8, -2., 2.), (2., -5., 5.),
        (3., -5., 5.), (4., -5., 5.), (5., -10., 10.),
        # The last entry used to be a duplicate of r=10. r=50 over +/-100
        # mirrors the (+/-50, 0, 100) entries of the reactance grid.
        (10., -20., 20.), (20., -50., 50.), (50., -100., 100.),
    )

    # Constant-reactance grid: (reactance, resistance_start, resistance_end).
    _REACTANCE_GRID = (
        (-50., 0., 100.), (-20., 0., 50.), (-10., 0., 20.),
        (-5., 0., 10.), (-4., 0., 5.), (-3., 0., 5.), (-2., 0., 5.),
        (-1.8, 0., 2.), (-1.6, 0., 2.), (-1.4, 0., 2.), (-1.2, 0., 2.),
        (-1., 0., 5.), (-.9, 0., 1.), (-.8, 0., 2.), (-.7, 0., 1.),
        (-.6, 0., 2.), (-.5, 0., 1.), (-.4, 0., 2.), (-.3, 0., 1.),
        (-.2, 0., 2.), (-0.15, 0., .2), (-.1, 0., 1.), (-.05, 0., .2),
        (0., 0., 100.),
        (.05, 0., .2), (.1, 0., 1.), (0.15, 0., .2), (.2, 0., 2.),
        (.3, 0., 1.), (.4, 0., 2.), (.5, 0., 1.), (.6, 0., 2.),
        (.7, 0., 1.), (.8, 0., 2.), (.9, 0., 1.), (1., 0., 5.),
        (1.2, 0., 2.), (1.4, 0., 2.), (1.6, 0., 2.), (1.8, 0., 2.),
        (2., 0., 5.), (3., 0., 5.), (4., 0., 5.), (5., 0., 10.),
        (10., 0., 20.), (20., 0., 50.), (50., 0., 100.),
    )

    def __init__(self, *args, **kwargs):
        Axes.__init__(self, *args, **kwargs)
        # The chart is a circle, so keep a 1:1 aspect centred in the box.
        self.set_aspect('equal', adjustable='box', anchor='C')

    def cla(self):
        """Clear the axes and install the Smith-chart grid.

        Grid visibility follows ``rcParams['polaraxes.grid']``. Tick labels
        are emptied and tick marks are turned off, so only grid lines remain.
        """
        Axes.cla(self)
        self.grid(rcParams['polaraxes.grid'])
        self.set_xticks(self._RESISTANCE_GRID)
        self.set_yticks(self._REACTANCE_GRID)
        self.xaxis.set_ticklabels([])
        self.yaxis.set_ticklabels([])
        self.xaxis.set_ticks_position('none')
        self.yaxis.set_ticks_position('none')

    def _init_axis(self):
        # Use the range-aware axes so that grid lines are clipped per tick.
        self.xaxis = MyXAxis(self)
        self.yaxis = MyYAxis(self)

    def _set_lim_and_transforms(self):
        """
        Set up all the transforms for data, text and grid lines.

        Called once when the plot is created.
        """
        # Three coordinate spaces are involved:
        #   1. data space:    the data itself (z = x + jy)
        #   2. axes space:    the unit square (0,0)-(1,1) covering the plot
        #   3. display space: coordinates of the output image (pixels/dots)
        # The first two transforms take data space to axes space. The
        # non-affine part is kept separate from the affine part because it
        # does not need recomputing on window resize or dpi change.

        # Smith mapping: z-plane -> unit disc in (-1,-1)-(1,1).
        self.transProjection = self.SmithTransform()
        # Linear map of (-1,-1)-(1,1) onto axes space (0,0)-(1,1).
        self.transAffine = Affine2D().scale(.5, .5).translate(.5, .5)
        # Axes space -> display space.
        self.transAxes = BboxTransformTo(self.bbox)
        # Full data -> display pipeline; '+' applies transforms in order.
        self.transData = (self.transProjection +
                          self.transAffine +
                          self.transAxes)

        # X grid lines (constant resistance): x stays in data units; the
        # grid line's y is stretched from [0, 1] to [-10, 10] (y -> 20*y - 10).
        # MyXTick applies the inverse mapping to its per-tick range.
        self._xaxis_transform = (
            Affine2D().scale(1., 20.).translate(0., -10.) + self.transData)
        # Offset x tick labels by +0.5 before the stretch. A label at y = 0
        # therefore lands at data y = 20*0.5 - 10 = 0.
        self._r_label1_position = Affine2D().translate(0., .5)
        self._xaxis_text1_transform = (
            self._r_label1_position + self._xaxis_transform)
        # NOTE(review): not read anywhere in the visible code.
        self._rpad = 0.05

        # Y grid lines (constant reactance): the grid line's x is stretched
        # from [0, 1] to [0, 10] (x -> 10*x). MyYTick applies the inverse.
        self._yaxis_transform = Affine2D().scale(10., 1.) + self.transData
        self._x_label1_position = Affine2D().translate(0., 0.)
        self._yaxis_text1_transform = (
            self._x_label1_position + self._yaxis_transform)

    def get_xaxis_transform(self, which='grid'):
        """Return the transform used for x ticks and grid lines."""
        assert which in ['tick1', 'tick2', 'grid']
        return self._xaxis_transform

    def get_xaxis_text1_transform(self, pad):
        """Return (transform, valign, halign) for x tick labels."""
        return self._xaxis_text1_transform, 'top', 'center'

    def get_yaxis_transform(self, which='grid'):
        """Return the transform used for y ticks and grid lines."""
        assert which in ['tick1', 'tick2', 'grid']
        return self._yaxis_transform

    def get_yaxis_text1_transform(self, pad):
        """Return (transform, valign, halign) for y tick labels."""
        return self._yaxis_text1_transform, 'center', 'right'

    def get_data_ratio(self):
        return 1.0

    def set_xlim(self, *args, **kargs):
        """Pin the x view limits to (0, 100); all arguments are ignored.

        Returns the pinned limits, like ``Axes.set_xlim`` returns new limits.
        """
        self.viewLim.intervalx = (0., 100.)
        return (0., 100.)

    def set_ylim(self, *args, **kargs):
        """Pin the y view limits to (-100, 100); all arguments are ignored.

        Returns the pinned limits, like ``Axes.set_ylim`` returns new limits.
        """
        self.viewLim.intervaly = (-100, 100)
        return (-100, 100)

    def _gen_axes_patch(self):
        # Circular background patch filling the axes box.
        return Circle((0.5, 0.5), 0.5)

    def _gen_axes_spines(self):
        # Single circular spine matching the background patch.
        return {'polar': mspines.Spine.circular_spine(self,
                                                      (0.5, 0.5), 0.5)}

    class SmithTransform(Transform):
        """Map data points (x, y) to ``g = (z - 1) / (z + 1)``, z = x + jy."""
        input_dims = 2
        output_dims = 2
        is_separable = False

        def transform(self, rx):  # Nx2 array
            z = rx[:, 0:1] + 1j * rx[:, 1:2]
            g = (z - 1.) / (z + 1.)
            return concatenate((real(g), imag(g)), 1)

        def transform_path(self, path):
            # Interpolate first so straight data-space segments become curves
            # after the non-linear mapping.
            ipath = path.interpolated(path._interpolation_steps)
            return Path(self.transform(ipath.vertices), ipath.codes)

        def inverted(self):
            return SmithAxes.InvertedSmithTransform()

    class InvertedSmithTransform(Transform):
        """Inverse mapping: ``z = (1 + g) / (1 - g)`` for points on the disc."""
        input_dims = 2
        output_dims = 2
        is_separable = False

        def transform(self, xy):  # Nx2 array
            g = xy[:, 0:1] + 1j * xy[:, 1:2]
            z = (1. + g) / (1. - g)
            return concatenate((real(z), imag(z)), 1)

        def inverted(self):
            return SmithAxes.SmithTransform()

register_projection(SmithAxes)

if __name__ == '__main__':
    # Manual demo: open a window showing an empty Smith chart with its grid.
    import matplotlib.pyplot as plt
    plt.subplot(111, projection="smith")
    plt.grid(True)
    plt.show()
