## Create different plots based on user specifications in 'params_plot.py'

# Import required libraries
from subprocess import call
import numpy as np
import sys,os
import math
import glob
import matplotlib.pyplot as mplt

# Plot: settings for one entry of the 'plot' list, plus loading and drawing of its data
class Plot(object):
  """Settings and data access for a single entry of the ``plot`` list.

  An instance is filled by ``getPlotinfo`` from one parameter dict. It then
  loads x, y data (``getxy``, ``serializeXYdata``) and draws it with
  matplotlib (``plot2d``).
  """

  @staticmethod
  def _load(path):
    """Load a whitespace-separated numeric file as a 2-D array.

    ``ndmin=2`` keeps a single-line file as shape (1, ncols), so that
    ``data[rows, col]`` indexing works for every file, including one-row and
    one-value files.
    """
    return np.loadtxt(path, ndmin=2)

  def getPlotinfo(self, plot):
    """Read the settings of one plot entry into instance attributes.

    plot : dict
      One entry of the ``plot`` list from the parameter file.
      Required keys:
        'kind' : plot type.
        'data' : list of glob patterns.
        'row'  : [first, last]. first is 1-based; last == -1 means the end
                 of the file.
        'col'  : [xcol, ycol, ...], 1-based.
      Every other key is optional and has a default.

    Exits the program with status 2 if no file matches the 'data' patterns.
    """
    # Kind of plot required: FFT/deltaP/xy/polyfit/drawline/rLength
    self.kind = plot['kind']
    print('kind = %s' % (self.kind,))

    # Expand every glob pattern; the files are processed in sorted order
    self.dataFile = sorted(
        path for pattern in plot['data'] for path in glob.glob(pattern))
    if not self.dataFile:
      print('unable to load datafile')
      sys.exit(2)

    # 1-based first row -> 0-based slice start; last row is the slice end
    self.sRow = plot['row'][0] - 1
    self.eRow = plot['row'][1]

    # For 'xy' the rows index the concatenation of all files, so the caller
    # resolves them. Otherwise resolve them against the first data file.
    if self.kind != 'xy':
      nrows = self._load(self.dataFile[0]).shape[0]
      # Last row of data. If '-1', last row = last line of the file
      if self.eRow == -1:
        self.eRow = nrows
      # A negative start row is counted back from eRow
      if self.sRow < 0:
        self.sRow = self.eRow + self.sRow

    # 1-based column numbers -> 0-based indices (one x, one or more y)
    self.xCol = plot['col'][0] - 1
    self.yCol = np.array(plot['col'][1:]) - 1

    # Optional lattice -> physical conversion. The factors apply only when
    # convertPhy is True; otherwise, or if a factor is missing, both are 1.0.
    self.convertPhy = plot.get('convertPhy', False)
    self.xfac = 1.
    self.yfac = 1.
    if self.convertPhy == True:
      try:
        self.xfac, self.yfac = plot['xfac'], plot['yfac']
      except KeyError:
        print('convertPhy is True but xfac/yfac not given; using 1.0')

    # Optional cosmetic settings with their defaults
    self.xlabel = plot.get('xlabel', '')
    self.ylabel = plot.get('ylabel', '')
    self.lw = plot.get('lw', 1.)
    self.ls = plot.get('ls', 'r-')
    # Legend label defaults to the first data file name
    self.label = plot.get('label', self.dataFile[0])
    print('label = %s' % (self.label,))
    self.title = plot.get('title', '')

    # Saving: the legend location and figure name are only read when saving
    self.saveplot = plot.get('saveplot', False)
    if self.saveplot == True:
      self.legLoc = plot.get('legend_loc', 0)
      if 'figname' in plot:
        self.figname = plot['figname']
      else:
        print('figname not found so plot will not be saved')
        self.figname = ''

    # Plot on a new window, default = False
    self.newplot = plot.get('newplot', False)

    # Axis ranges are used only when both ends are given ('' = not set)
    if 'xmin' in plot and 'xmax' in plot:
      self.xmin, self.xmax = plot['xmin'], plot['xmax']
    else:
      self.xmin = self.xmax = ''
    if 'ymin' in plot and 'ymax' in plot:
      self.ymin, self.ymax = plot['ymin'], plot['ymax']
    else:
      self.ymin = self.ymax = ''

  def getxy(self, iData, iCol):
    """Return scaled (x, y) of one data file over rows [sRow, eRow).

    iData : index into self.dataFile
    iCol  : index into self.yCol (which y column to use)
    x is multiplied by xfac and y by yfac.
    """
    data = self._load(self.dataFile[iData])
    rows = slice(self.sRow, self.eRow)
    x = data[rows, self.xCol] * self.xfac
    y = data[rows, self.yCol[iCol]] * self.yfac
    return x, y

  def plot2d(self, x, y):
    """Draw y against x on the current matplotlib figure using the settings."""
    mplt.plot(x, y, self.ls, linewidth=self.lw, label=self.label)

    # Apply axis ranges only when both limits were specified
    if self.xmin != '' and self.xmax != '':
      mplt.xlim(self.xmin, self.xmax)
    if self.ymin != '' and self.ymax != '':
      mplt.ylim(self.ymin, self.ymax)

    # Labels and title are only set when this entry opened a new figure
    if self.newplot == True:
      mplt.xlabel(self.xlabel)
      mplt.ylabel(self.ylabel)
      mplt.title(self.title)

    # The frameless legend is only drawn when saving; the file is written
    # only if a figure name was given
    if self.saveplot == True:
      mplt.legend(loc=self.legLoc, borderaxespad=0).get_frame().set_lw(0.0)
      if self.figname != '':
        mplt.savefig(self.figname)

  def serializeXYdata(self):
    """Concatenate the x column and first y column of all data files, in order.

    All rows of every file are used; sRow/eRow are NOT applied here. The
    result is scaled by xfac / yfac.
    """
    xs = []
    ys = []
    for path in self.dataFile:
      data = self._load(path)
      xs.append(data[:, self.xCol])
      ys.append(data[:, self.yCol[0]])
    if not xs:
      return np.array([]), np.array([])
    # One concatenate instead of repeated np.append (which copies each time)
    x = np.concatenate(xs) * self.xfac
    y = np.concatenate(ys) * self.yfac
    return x, y


# Read plot parameters from the input file given on the command line, or from the default file params_plot.py
if len(sys.argv) > 1:
  ifile = sys.argv[1]
else:
  ifile = 'params_plot.py'

if not os.path.isfile(ifile):
  print('Input file %s is not found in current dir' % ifile)
  sys.exit(1)

print('Reading input file: %s' % ifile)

# The parameter file may have any name, so copy it to a fixed module name in
# the current directory and import it from there. shutil.copy raises on
# failure, instead of silently continuing and importing a stale copy.
import shutil
shutil.copy(ifile, 'params_plottmp.py')
# When the script is run from another directory, sys.path[0] is the script's
# directory, not the cwd. Make sure the copied module can be found.
if os.getcwd() not in sys.path:
  sys.path.insert(0, os.getcwd())
from params_plottmp import *



##################### Start of Main loop ################
for iplot, plotParams in enumerate(plot):
  # NOTE: the Plot instance lives in the global name 'plt' (matplotlib.pyplot
  # is 'mplt'); keep this name, since other code may refer to it.
  plt = Plot()
  plt.getPlotinfo(plotParams)
  if plt.newplot == True:
    mplt.figure(iplot + 1)

  # Fast Fourier Transform plot
  if plt.kind == 'FFT':
    x, y = plt.getxy(iData=0, iCol=0)
    # Use the number of samples actually read. It is smaller than
    # eRow - sRow when the requested rows run past the end of the file.
    nData = len(y)
    if nData < 2:
      print('FFT needs at least two data points')
      sys.exit(2)
    interval = x[1] - x[0]
    fft = 2. * np.absolute(np.fft.fft(y) / nData)
    freq = np.fft.fftfreq(nData, d=interval)
    # Skip the zero-frequency (mean) term when searching for the dominant
    # frequency; otherwise any non-zero mean would be reported as 0
    print('Frequency %s' % abs(freq[1 + fft[1:].argmax()]))
    # Number of low-frequency bins to plot (optional 'nfreq', default 50)
    nfreq = plotParams.get('nfreq', 50)
    x = freq[1:nfreq]
    y = fft[1:nfreq]

  # Pressure difference: first y column of the first file minus the second
  # y column of the second file, over the same rows
  elif plt.kind == 'deltaP':
    if len(plt.dataFile) < 2:
      print('deltaP needs two data files')
      sys.exit(2)
    x, p1 = plt.getxy(iData=0, iCol=0)
    x, p2 = plt.getxy(iData=1, iCol=1)
    y = p1 - p2
    print('average deltaP %s' % np.average(y))
    print('deltaP min: %s max: %s' % (np.min(y), np.max(y)))

  # Simple x, y plot; rows refer to the concatenated data of all files
  elif plt.kind == 'xy':
    x, y = plt.serializeXYdata()
    if plt.eRow == -1:
      plt.eRow = len(x)
    # A negative start row is counted back from eRow
    if plt.sRow < 0:
      plt.sRow = plt.eRow + plt.sRow
    x = x[plt.sRow:plt.eRow]
    y = y[plt.sRow:plt.eRow]

  # Polynomial fit of the data (optional 'order', default 2)
  elif plt.kind == 'polyfit':
    x, y = plt.serializeXYdata()
    func = np.poly1d(np.polyfit(x, y, plotParams.get('order', 2)))
    y = func(x)
    print('roots %s' % (np.roots(func),))
    pnts = plotParams.get('pnts')
    if pnts is None:
      print('No additional points specified')
    else:
      # Extend the fitted curve with its values at the extra points
      x = np.append(x, pnts)
      y = np.append(y, func(np.asarray(pnts, dtype=float)))

  # Horizontal line at y = plotParams['y'] over the x range of the data
  elif plt.kind == 'drawline':
    x, y = plt.serializeXYdata()
    y = np.ones(len(x)) * plotParams['y']

  # Recirculation length: extent of the region where y < 0
  elif plt.kind == 'rLength':
    x, y = plt.serializeXYdata()
    dx = plotParams['dx']
    neg = np.flatnonzero(y < 0)
    if len(neg) == 0:
      print('recirculation length: no negative values found')
    else:
      first, last = neg[0], neg[-1]
      if last + 1 < len(x):
        # End of recirculation: midpoint between the last negative sample
        # and the next one
        recirc_end = (x[last] + x[last + 1]) / 2.0
      else:
        # Recirculation reaches the end of the data
        print('recirculation region extends to the last data point')
        recirc_end = x[last]
      print('recirculation length: %s' % (recirc_end - x[first]))
      print('recirculation length using dx: %s'
            % (recirc_end - x[first] - dx / 2.0))
    y = np.zeros(len(x))

  else:
    print('kind unknown')
    sys.exit(2)

  plt.plot2d(x, y)

# Show all figures if requested. show_plot comes from the star-import of the
# parameter file; a missing setting is treated as False.
if globals().get('show_plot', False):
  mplt.show()

# Delete only the temporary parameter module and its compiled forms: Python 2
# writes params_plottmp.pyc, Python 3 writes __pycache__/params_plottmp.*.pyc.
# Other .pyc files in the directory are left alone.
tmpFiles = ['params_plottmp.py', 'params_plottmp.pyc']
tmpFiles += glob.glob(os.path.join('__pycache__', 'params_plottmp.*.pyc'))
for tmpFile in tmpFiles:
  if os.path.isfile(tmpFile):
    os.remove(tmpFile)
