import matplotlib.pyplot as plt
import numpy as np
import matplotlib.animation as animation

# Constants: domain bounds, grid sizes and mesh ratios
x0 = 0  # first point in the x direction
xf = 1  # last point in the x direction
y0 = 0  # first point in the y direction
yf = 1  # last point in the y direction
t0 = 0  # first point in the t direction
tf = 0.08  # last point in the t direction
Nx = 25  # number of x points
Ny = 25  # number of y points
Nt = 2501  # number of t points
l = (xf - x0) / (Nx - 1)  # x grid spacing
m = (yf - y0) / (Ny - 1)  # y grid spacing
n = (tf - t0) / (Nt - 1)  # time step
A = n / (l**2)  # mesh ratio dt/dx**2
B = n / (m**2)  # mesh ratio dt/dy**2
# The explicit update weights the centre value by (1 - 2A - 2B). Once that
# weight turns negative (A + B > 1/2) the scheme is unstable and the solution
# blows up, so fail fast rather than silently producing a garbage animation.
if A + B > 0.5:
    raise ValueError(
        f"Unstable explicit scheme: A + B = {A + B:.4g} > 0.5; "
        "increase Nt or decrease Nx/Ny."
    )
x = np.linspace(x0, xf, Nx)
y = np.linspace(y0, yf, Ny)
t = np.linspace(t0, tf, Nt)



# Initial condition
def f(y, x):
    """Return the initial temperature profile.

    In terms of the parameter names this evaluates
    cos(2*pi*x) * cos(pi*y) - 0.7*x, element-wise, so it accepts scalars
    or (broadcastable) NumPy arrays.

    NOTE(review): the parameters are ordered (y, x). A caller that passes
    (x_value, y_value) positionally therefore gets
    cos(2*pi*y_value) * cos(pi*x_value) - 0.7*y_value. Confirm the swap is
    intended.
    """
    phase_x = 2 * np.pi * x
    phase_y = np.pi * y
    return np.cos(phase_x) * np.cos(phase_y) - 0.7 * x

# Initial temperature field u, indexed [x, y] with shape (Nx, Ny).
u = np.zeros((Nx, Ny))
# Interior nodes: u[i, j] = f(x[i], y[j]). Column/row views broadcast to the
# (Nx-2, Ny-2) interior in a single vectorized call instead of a double loop.
u[1:-1, 1:-1] = f(x[1:-1, np.newaxis], y[np.newaxis, 1:-1])
# Neumann (zero normal derivative) boundary: copy the adjacent interior
# column/row onto each edge. Columns are done first, then rows, so each
# corner ends up with the value of its diagonal interior neighbour.
u[:, 0] = u[:, 1]
u[:, -1] = u[:, -2]
u[0, :] = u[1, :]
u[-1, :] = u[-2, :]



# Next time step
def u_next(u, a=None, b=None):
    """Advance the temperature field by one explicit (FTCS) time step.

    Parameters
    ----------
    u : array_like
        2-D field indexed [x, y]. It is only read, never modified.
    a, b : float, optional
        Mesh ratios dt/dx**2 and dt/dy**2. They default to the module-level
        ``A`` and ``B``.

    Returns
    -------
    numpy.ndarray
        A new float array with the same shape as ``u``. The interior is
        updated with the five-point stencil, and the edges are then set by
        the zero-flux (Neumann) copy rule.
    """
    if a is None:
        a = A
    if b is None:
        b = B
    u = np.asarray(u)
    un = np.zeros(u.shape)
    # Five-point stencil on all interior nodes at once. The terms are summed
    # in the same order as the scalar formula u[i][j], u[i-1][j], u[i+1][j],
    # u[i][j-1], u[i][j+1].
    un[1:-1, 1:-1] = (
        (1 - 2 * a - 2 * b) * u[1:-1, 1:-1]
        + a * u[:-2, 1:-1]
        + a * u[2:, 1:-1]
        + b * u[1:-1, :-2]
        + b * u[1:-1, 2:]
    )
    # Neumann boundary: copy the adjacent interior column/row onto each edge
    # (columns first, then rows, so corners take the diagonal neighbour).
    un[:, 0] = un[:, 1]
    un[:, -1] = un[:, -2]
    un[0, :] = un[1, :]
    un[-1, :] = un[-2, :]
    return un


# Full space-time solution: matrise[:, :, k] holds the field at time t[k].
matrise = np.zeros((Nx, Ny, Nt))
matrise[..., 0] = u
for i in range(Nt - 1):
    # Read the previously stored level and store the next one after it.
    matrise[..., i + 1] = u_next(matrise[..., i])




# Animation settings
T_ani = 4  # animation length in seconds
fps = 40  # frames per second
total_frames = T_ani * fps
fig, ax = plt.subplots(dpi=200, subplot_kw={"projection": "3d"})
# Default ("xy") indexing: meshX/meshY have shape (len(y), len(x)).
meshX, meshY = np.meshgrid(x, y)

def update(i):
    """Redraw the 3-D surface for stored time level ``i``.

    Frame callback for ``FuncAnimation``: clears ``ax``, plots
    ``matrise[:, :, i]`` over the (``meshX``, ``meshY``) grid and returns
    the new surface artist.
    """
    ax.clear()
    # matrise[:, :, i] is indexed [x, y] (shape (Nx, Ny)), whereas
    # np.meshgrid(x, y) with default "xy" indexing is indexed [y, x]
    # (shape (Ny, Nx)). Transposing makes the shapes agree for any Nx, Ny
    # and puts x on the plot's x-axis and y on its y-axis.
    surf = ax.plot_surface(meshX, meshY, matrise[:, :, i].T)
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")
    ax.set_zlabel("$T$")
    ax.set_xlim((x0, xf))
    ax.set_ylim((y0, yf))
    # Use one z-range for every frame, derived from the whole stored history.
    # The view then stays fixed, and no frame is drawn outside the z-limits
    # (the initial profile dips well below -1.1).
    zmin = float(matrise.min())
    zmax = float(matrise.max())
    pad = 0.05 * (zmax - zmin) or 0.1  # avoid identical limits for a flat field
    ax.set_zlim((zmin - pad, zmax + pad))
    return surf




# Build the animation from total_frames evenly spaced stored time levels
# (first and last included), save it as a GIF, then show it interactively.
ani = animation.FuncAnimation(
    fig,
    update,
    frames=np.linspace(0, Nt - 1, total_frames).astype(int),
    interval=1000 / fps,
    repeat=True,
)
ani.save("animation2d_oblig.gif")  # no writer given: Matplotlib picks its default
plt.show()


