# Integrator for KdV using FFT
# 
# an explanation of the method is linked from the course webpage
# 
# try different values of N (the program asks for N when it starts);
# you can also experiment with completely different initial conditions
# (edit u0 below) and see what happens...
# Program originally written by Sam Webster (2007); small modifications
# by PED in 2010

from __future__ import division
from scipy import *
import numpy.fft as fft
import Gnuplot
import time


# Initial configuration: u(x,0) = N(N+1) sech^2(x), where N is the
# module-level parameter entered by the user.
def u0(x):
    """Return the initial profile N(N+1)/cosh(x)**2 evaluated at x."""
    amplitude=N*(N+1)
    return amplitude/cosh(x)**2

def plot(uu):
    """Draw the M samples in uu as a line in the Gnuplot window g.

    Sample j is plotted at x = xmin + j*h, so the grid covers
    -L/2 <= x < L/2; the point x = xmax is omitted because, for periodic
    data, it coincides with x = xmin.
    """
    points=[[xmin+j*h,uu[j]] for j in xrange(M)]
    g.plot(Gnuplot.Data(points,with_='lines'))

# Right-hand side of the ODE for Uhat.  The FFT routines work with
# integer wavenumbers k, i.e. with data of period 2 pi; the factors of
# 2*pi/L rescale these to the physical wavenumbers for period L.

def f(tt,uu):
    """Return d(Uhat)/dt at time tt, where uu holds Uhat.

    uu is multiplied by the phase exp(i (2 pi/L)^3 k^3 tt) to recover the
    FT of u.  u is then squared in real space and transformed back.  The
    result is multiplied by -3i (2 pi k/L) exp(-i (2 pi/L)^3 k^3 tt), i.e.
    the transform of -3 d(u^2)/dx carried back into the Uhat variables.
    """
    cube_scale=(2*pi/L)**3
    phase=exp(cube_scale*1j*(k**3)*tt)
    u_space=fft.ifft(phase*uu)      # u(x,tt) on the grid (complex dtype)
    u_sq_hat=fft.fft(u_space**2)    # FT of u^2
    return -3j*(1/phase)*k*u_sq_hat*2*pi/L

# The value of N, which parameterises the initial configuration
global N
N=2.25
while 1:
 answer=raw_input("Value of N? ")
 try:
  N=float(answer)
  break
 except ValueError:
  print "that wasn't a number"

# Number of lattice points for the Fourier transform
# (a power of 2 is the most efficient size for the FFT)
M=512

# Spatial domain: one period of length L centred on x=0, spacing h
L=40
xmax=L/2
xmin=-xmax
h=(xmax-xmin)/M

# Open a persistent Gnuplot window on the X11 terminal
g=Gnuplot.Gnuplot(persist=1)
g('set terminal x11')

# Set up initial configuration (U=u at t=0), sampled on the grid
# x_i = xmin + i*h, i = 0, ..., M-1
u=[u0(xmin+i*h) for i in xrange(M)]

# Uhat(k,t) = exp(-i(2 pi k/L)^3 t) uhat(k,t)
# where uhat(k,t) is the FT of u(x,t).  (f() and the snapshot code in the
# integration loop undo this by multiplying by exp(+i(2 pi k/L)^3 t).)
# Therefore Uhat(k,0)=uhat(k,0), so we just
# have to Fourier Transform the initial data:
Uhat=fft.fft(u)

# The function is periodic, hence the FT is discrete.
# The FFT treats the data as having period 2 pi, so the integer
# wavenumbers k below are multiplied by 2 pi/L wherever they are used
# (see f()) to give the physical wavenumbers for period L.
# For even M they run from -M/2+1 to M/2-1 and are stored in the order
# the FFT routine uses:
# k=[0,1,2,....,M/2-1,0,-M/2+1,-M/2+2,...,-1]
# The Nyquist entry (index M/2, nominally k=+-M/2) is set to 0.  For
# real data that mode has no well-defined odd derivative.  Because f()
# multiplies by k, its Uhat coefficient simply stays at its initial value.
k=zeros(M)
for i in xrange(M):
    if i<M/2:
        k[i]=i          # non-negative wavenumbers
    elif i==M/2:
        k[i]=0          # Nyquist mode removed
    else:
        k[i]=(i-M)      # negative wavenumbers, wrapped around

#############################
#    INTEGRATION ROUTINE    #
#############################

# Range of time, and number of time steps
tmin=0.0
tmax=1.0
TSTEPS=12000
dt=(tmax-tmin)/TSTEPS

# Store every Kth configuration for plotting
u_data=[]
K=TSTEPS/200
c=0
opd=0
OK=1

print 'Integrating...'
# Start the loop
for t in arange(tmin,tmax+dt,dt):
    
# Solve (d/dt)Uhat(k,t)=f(t,Uhat) using
# a 4th-order Runge-Kutta in time, where
# f(,) is given by the above routine

    k1=f(t,Uhat)
    k2=f(t+0.5*dt,Uhat+0.5*dt*k1)
    k3=f(t+0.5*dt,Uhat+0.5*dt*k2)
    k4=f(t+dt,Uhat+dt*k3)
    Uhat=Uhat+(dt/6)*(k1+2*k2+2*k3+k4)


# Store every Kth configuration in the u array
    if c%K==0:
        s=(2*pi/L)**3
        e=exp(s*1j*(k**3)*t)    
        uhat=e*Uhat
        u=fft.ifft(uhat)
        u_data.append([t,u.real])
    c+=1

# Has it blown up?
    um=abs(u[:])
    um.sort()
    if not um[-1]<1000:
        print 'Unstable... decrease the step size!'
        OK=0
    if not OK: break

# How far have we got?
    pd=c*100/TSTEPS
    if pd>=opd+10:
        print int(pd),'% done'
        opd=pd

# Now play the movie, in correct time:
print 'Plotting...'
for data in u_data:
# Check the time:
    ti=time.time()
# Retrieve the u(t) data
    t=data[0]
    u=data[1]
# Plot it
    ts=str(t)
    ts+='0'*(6-len(ts))
    g('set yrange['+str(-0.5*N*(N+1))+':'+str(1.5*N*(N+1))+']')
    g.title('N='+str(N)+', t='+ts)
    plot(u.real)
# Wait for time ti_K*dt, slowed down by a factor of 5
    while (time.time()<ti+5*K*dt): 1
 

