import numpy as np
from scipy import integrate
from matplotlib.pylab import *
import math
from scipy.optimize import fsolve
import matplotlib.pyplot as mplt

def MaxwellStefan(t,y):
  """Right-hand side of the 1-D Maxwell-Stefan system for a ternary mixture.

  Evaluates the derivatives of the mole fractions of components 1 and 2
  with respect to the independent variable t (the output file written by
  this script labels that coordinate "x"):

      dx_i/dt = sum_{j != i} (x_i*N_j - x_j*N_i) / F_ij

  where x3 = 1 - x1 - x2 is eliminated through the summation constraint.
  The fluxes N1 and N2 are read from module globals (set by ode45); N3 is
  fixed to 0.

  Parameters
  ----------
  t : float
      Independent variable.  Unused (the system is autonomous) but required
      by the scipy.integrate.ode callback signature.
  y : array_like of length 2
      Current mole fractions (x1, x2).

  Returns
  -------
  numpy.ndarray, shape (2,)
      (dx1/dt, dx2/dt) as a flat float array, matching the shape of y as
      scipy.integrate.ode expects.
  """
  global N1, N2, N3, ref_diff
  # Component 3 carries zero flux; this also resets the module-level N3.
  N3 = 0
  # Binary coefficients scaled by ref_diff.  (Commented-out alternative
  # values in the original source: F12 = 1.286e-3, F13 = 2.081e-3,
  # F23 = 3.019e-3.)  Units are not stated in the source.
  F12 = 8.48e-6 / ref_diff
  F13 = 13.72e-6 / ref_diff
  F23 = 19.91e-6 / ref_diff

  x1 = y[0]
  x2 = y[1]
  x3 = 1 - x1 - x2  # closure: mole fractions sum to one

  dx1 = (x1*N2 - x2*N1)/F12 + (x1*N3 - x3*N1)/F13
  dx2 = (x2*N1 - x1*N2)/F12 + (x2*N3 - x3*N2)/F23
  return np.array([dx1, dx2], dtype=float)

def ode45(N):
  """Integrate the Maxwell-Stefan system from t = 0 to t = 1 for trial fluxes.

  Despite the MATLAB-style name, this uses scipy's 'vode' integrator with
  the BDF method.  Output is sampled every delta_t (a module global).

  Parameters
  ----------
  N : sequence of float
      Trial flux values (N1, N2).  They are stored in the module globals
      N1 and N2, which MaxwellStefan reads.

  Returns
  -------
  t : numpy.ndarray, shape (k, 1)
      Sample points, starting at t = 0.
  y : numpy.ndarray, shape (k, 2)
      Mole fractions (x1, x2) at each sample point.

  k equals floor((t_final - t_start)/delta_t) + 1 on success.  If the
  integrator reports failure, a RuntimeWarning is issued and the arrays are
  truncated to the rows actually computed, so callers never see fake
  all-zero rows.
  """
  global N1, N2, N3, delta_t

  N1 = N[0]
  N2 = N[1]

  r = integrate.ode(MaxwellStefan).set_integrator('vode', method='bdf')

  # set time range
  t_start = 0.0
  t_final = 1.0

  # Number of samples: one extra for the initial condition.  np.floor
  # returns a float, and np.zeros requires an integer shape, so convert.
  num_steps = int(np.floor((t_final - t_start) / delta_t)) + 1

  # initial condition (x1, x2)
  y_initial = np.array([0.319, 0.528])

  r.set_initial_value(y_initial, t_start)

  # Storage for the trajectories.
  t = np.zeros((num_steps, 1))
  y = np.zeros((num_steps, 2))
  t[0] = t_start
  y[0, :] = y_initial

  # Integrate across each delta_t step, recording every result.
  k = 1
  while r.successful() and k < num_steps:
    r.integrate(r.t + delta_t)
    t[k] = r.t
    y[k, :] = r.y
    k += 1

  if not r.successful():
    import warnings
    warnings.warn('vode integration stopped early at t = %g' % r.t,
                  RuntimeWarning)

  # Rows >= k were never written; dropping them prevents trailing zeros
  # from being mistaken for a converged shooting residual.
  return t[:k], y[:k]


def shooting(N):
  """Shooting residual: mole fractions (x1, x2) at the end of the domain.

  The script passes this to fsolve, which searches for fluxes N = (N1, N2)
  that drive both end-point mole fractions to zero.

  Returns a new length-2 numpy array.
  """
  _, trajectory = ode45(N)
  last_row = trajectory[-1]
  return np.array([last_row[0], last_row[1]])


# Module-level parameters read by MaxwellStefan / ode45.  (A `global`
# statement at module scope is a no-op, so none is needed here.)
ref_diff = 1e-6  # alternative value in the original source: 0.0065940902021772935
N3 = 0
nPoints = 100.0
delta_t = 1.0 / nPoints
print('nPoints ', nPoints)
print('delta_t ', delta_t)

# Initial guess for the fluxes (N1, N2) handed to the shooting solver.
N_guess = np.array([delta_t, delta_t])
N = fsolve(shooting, N_guess, xtol=1e-13)
# NOTE(review): N holds the solved fluxes N1, N2, not mole fractions; the
# label below is kept unchanged for output compatibility.
print('Molefraction N1, N2', N)

t, y = ode45(N)

# write molefraction to file; `with` guarantees the file is closed even if
# a write fails.
filename = 'maxwell_stefan_reference.res'
with open(filename, "w") as fid:
  fid.write('#x \t Acetone \t Methanol \t Air \n')
  for t_row, y_row in zip(t, y):
    fid.write(str(t_row[0]) + '\t' + str(y_row[0]) + '\t' + str(y_row[1])
              + '\t' + str(1.0 - y_row[0] - y_row[1]) + '\n')

#print(len(y[:,0]))
# All done!  Plot the trajectories:
#plt=mplt.figure(1)
#mplt.plot(t, y[:,0])
#mplt.plot(t, y[:,1])
#mplt.plot(t, 1.0-y[:,0]-y[:,1])
#
#mplt.grid(True,which="major",ls="-")
#mplt.xlim(0.0,1.0)
#mplt.ylim(0.0,1.0)
#mplt.show()
