import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Smoothstep function for interpolation
def smoothstep(t):
    """Evaluate the cubic ease curve ``3*t**2 - 2*t**3``.

    On [0, 1] it rises from 0 to 1 with zero slope at both ends, which is why it
    is used as the interpolation weight between lattice corners. Works on scalars
    and element-wise on NumPy arrays.
    """
    square = t**2
    cube = t**3
    return 3 * square - 2 * cube

# Lattice configuration: grid_size is (rows, cols); arrows and gradients are
# scaled to vector_length and drawn in vector_color.
grid_size = (10, 10)
n_rows, n_cols = grid_size
vector_length = 0.5
vector_color = 'red'

# One random direction per lattice node, uniform on [0, 2*pi). The directions are
# unit vectors; U/V (used both for the arrows and as the noise gradients) are those
# directions scaled by vector_length, so they have length vector_length, not 1.
angles = np.random.uniform(0, 2 * np.pi, grid_size)
U, V = vector_length * np.cos(angles), vector_length * np.sin(angles)

# Integer node coordinates: X[i, j] == j (column index), Y[i, j] == i (row index).
x, y = np.arange(n_cols), np.arange(n_rows)
X, Y = np.meshgrid(x, y)

# Fine sampling grid spanning the lattice. Using (n - 1) * res + 1 samples per axis
# (fence-post: the endpoints are included) gives a uniform step of
# 1 / fine_grid_resolution, so every res-th sample lands on a lattice node and each
# cell gets the same number of samples.
fine_grid_resolution = 10
x_fine = np.linspace(0, grid_size[1] - 1, (grid_size[1] - 1) * fine_grid_resolution + 1)
y_fine = np.linspace(0, grid_size[0] - 1, (grid_size[0] - 1) * fine_grid_resolution + 1)
X_fine, Y_fine = np.meshgrid(x_fine, y_fine)

# Nearest lattice node for each fine sample. np.round sends exact .5 ties to the
# even integer; either neighbour is equally near, so the offset magnitude is the same.
nearest_X = np.round(X_fine).astype(int)
nearest_Y = np.round(Y_fine).astype(int)

# Contribution of the nearest node's gradient only: dot(gradient, offset from node).
# This piecewise field (constant gradient per nearest-node region) is the one drawn
# in the "Dot products" frame; it is exactly zero at every sampled lattice node.
nearest_U = U[nearest_Y, nearest_X]
nearest_V = V[nearest_Y, nearest_X]
delta_X = X_fine - nearest_X
delta_Y = Y_fine - nearest_Y
dot_products = nearest_U * delta_X + nearest_V * delta_Y

# Lower-left lattice corner of the cell containing each fine sample. Clipping to
# n - 2 keeps samples on the last row/column inside the final cell, so the +1
# neighbours used below stay in bounds.
x0 = np.clip(np.floor(X_fine).astype(int), 0, grid_size[1] - 2)
y0 = np.clip(np.floor(Y_fine).astype(int), 0, grid_size[0] - 2)


def _corner_dot(ix, iy):
    """Dot the gradient at node (ix, iy) with the offset from that node to every fine sample."""
    return U[iy, ix] * (X_fine - ix) + V[iy, ix] * (Y_fine - iy)


# Gradient contributions from the four corners of each sample's cell.
dot00 = _corner_dot(x0, y0)
dot10 = _corner_dot(x0 + 1, y0)
dot01 = _corner_dot(x0, y0 + 1)
dot11 = _corner_dot(x0 + 1, y0 + 1)


def _lerp(a, b, w):
    """Linear blend a*(1 - w) + b*w."""
    return a * (1 - w) + b * w


# Smoothstep-weighted bilinear blend: first along x on the bottom and top edges,
# then along y between those two results.
sx = smoothstep(X_fine - x0)
sy = smoothstep(Y_fine - y0)
interp_x0 = _lerp(dot00, dot10, sx)
interp_x1 = _lerp(dot01, dot11, sx)
interpolated = _lerp(interp_x0, interp_x1, sy)

# Create a figure for animation
fig, ax = plt.subplots(figsize=(6, 6))


def _pixel_extent(xs, ys):
    """Return an ``imshow`` extent whose pixel centres sit on the 1-D coordinates *xs*, *ys*.

    ``imshow`` extents describe the outer pixel *edges*, so each side is padded by
    half a sample step. An axis with a single sample is given a width of one unit.
    """
    def _half_step(coords):
        if len(coords) > 1:
            return (coords[-1] - coords[0]) / (2 * (len(coords) - 1))
        return 0.5

    hx, hy = _half_step(xs), _half_step(ys)
    return (xs[0] - hx, xs[-1] + hx, ys[0] - hy, ys[-1] + hy)


# The scalar fields are sampled on x_fine/y_fine (which span 0..n-1), so place the
# image there; the view limits set in update() stay at (-1, n) around the lattice.
_field_extent = _pixel_extent(x_fine, y_fine)

# One symmetric colour scale shared by both scalar fields and the colorbar, so a
# colour means the same value in every frame and zero maps to the middle of the
# diverging 'coolwarm' map. Falls back to 1.0 if both fields are identically zero.
_field_limit = float(max(np.abs(dot_products).max(), np.abs(interpolated).max())) or 1.0


def _show_field(field):
    """Draw *field* (sampled on the fine grid) on ``ax`` with the shared extent and colour scale."""
    return ax.imshow(field, extent=_field_extent, origin='lower', cmap='coolwarm',
                     alpha=0.6, vmin=-_field_limit, vmax=_field_limit)


# The colorbar is built from an initial image; that image is removed by the first
# ax.clear(), but the colorbar keeps the shared cmap/limits used by every frame.
fig.colorbar(_show_field(dot_products))

# frame -> (scalar field to overlay or None, title)
_FRAME_PANELS = {
    0: (None, "Random unit vectors"),
    1: (dot_products, "Dot products"),
    2: (interpolated, "Interpolated noise (smoothstep)"),
}


def update(frame):
    """Redraw ``ax`` for animation *frame*.

    Frame 0 shows the lattice gradients alone, frame 1 overlays the nearest-node
    dot products and frame 2 the smoothstep-interpolated noise. Any other frame
    number draws only the empty lattice.
    """
    ax.clear()  # Clear the axis to prevent overlapping
    ax.set_xlim(-1, grid_size[1])
    ax.set_ylim(-1, grid_size[0])
    ax.set_aspect('equal', adjustable='box')
    ax.set_xticks(np.arange(0, grid_size[1], 1))
    ax.set_yticks(np.arange(0, grid_size[0], 1))
    ax.grid(True, which='both', color='gray', linestyle='--', linewidth=0.5)

    panel = _FRAME_PANELS.get(frame)
    if panel is None:
        return
    field, title = panel
    # angles='xy', scale_units='xy', scale=1: arrows drawn in data units at their true length.
    ax.quiver(X, Y, U, V, color=vector_color, angles='xy', scale_units='xy', scale=1)
    if field is not None:
        _show_field(field)
    ax.set_title(title)

# Create the animation: three frames, 2000 ms apart, looping in the interactive window.
ani = animation.FuncAnimation(fig, update, frames=3, interval=2000, blit=False, repeat=True)

# Set to False to skip writing the video file (the 'ffmpeg' writer needs the ffmpeg executable).
SAVE_ANIMATION = True

if SAVE_ANIMATION:
    # Save as mp4 at 0.5 fps, i.e. one frame every 2 s, matching the on-screen interval.
    ani.save('perlin_noise.mp4', writer='ffmpeg', fps=0.5)

plt.show()