import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# Define the quadratic function and its gradient
def f(x, y):
    """Evaluate the quadratic objective x^2 + 2*y^2 - x*y at (x, y).

    Only arithmetic operators are used, so scalars and NumPy arrays
    (evaluated elementwise) are both accepted.
    """
    x_squared = x**2
    y_term = 2*y**2
    cross_term = x * y
    return x_squared + y_term - cross_term

def grad_f(x, y):
    """Return the gradient of x^2 + 2*y^2 - x*y at (x, y).

    The result is a NumPy array ordered as [d/dx, d/dy].
    """
    df_dx = 2*x - y
    df_dy = 4*y - x
    return np.array([df_dx, df_dy])

# Gradient Descent update rule
def gd_update(x, y, alpha):
    """Take one gradient-descent step from (x, y) with step size alpha.

    Returns the new position as an (x, y) tuple, obtained by moving against
    the gradient reported by grad_f.
    """
    slope_x, slope_y = grad_f(x, y)
    next_x = x - alpha * slope_x
    next_y = y - alpha * slope_y
    return next_x, next_y

# Parameters
alphas = [0.05, 0.1, 0.3]  # Step sizes to compare
num_steps = 100  # Gradient-descent updates performed for each step size

# Every run starts from the same point so the trajectories are comparable
initial_position = (-2, 3)

# One array of shape (num_steps + 1, 2) per step size, holding the visited
# (x, y) positions; row 0 is the starting point.
paths = []

for alpha in alphas:
    x, y = initial_position
    path = [(x, y)]
    for _ in range(num_steps):
        x, y = gd_update(x, y, alpha)
        path.append((x, y))
    paths.append(np.array(path))

# Sample the square [-3, 3] x [-3, 3] on a 400 x 400 grid
x_vals = np.linspace(-3, 3, 400)
y_vals = np.linspace(-3, 3, 400)
X, Y = np.meshgrid(x_vals, y_vals)

# Compute the function values on the grid
Z = f(X, Y)

# Create plot and set limits (a sub-window of the sampled grid)
fig, ax = plt.subplots()
ax.set_xlim(-3, 1)
ax.set_ylim(-2, 3)

# Plot contour lines
contour = ax.contour(X, Y, Z, levels=30, cmap='viridis')

lines = []
points = []
colors = ['r', 'g', 'b', 'm']
labels = [f'GD alpha={alpha}' for alpha in alphas]

# Create one trail line and one current-position marker per step size.
# Colors are reused cyclically so that a step size is never silently left
# without artists when there are more step sizes than colors.
for idx, label in enumerate(labels):
    color = colors[idx % len(colors)]
    line, = ax.plot([], [], color + '-', label=label)
    point, = ax.plot([], [], color + 'o')
    lines.append(line)
    points.append(point)

def init():
    """Blank out every trail line and marker before the animation starts.

    Returns the list of artists (all lines followed by all points), which is
    what FuncAnimation expects from init_func when blitting.
    """
    artists = lines + points
    for artist in artists:
        artist.set_data([], [])
    return artists

def animate(i):
    """Draw frame i: the first i points of every path plus a marker on the latest.

    Returns the list of updated artists (all lines followed by all points).
    """
    for trajectory, trail, marker in zip(paths, lines, points):
        visited = trajectory[:i]
        # A length-one slice (not a scalar index) keeps the marker data as
        # sequences, which set_data requires.
        current = trajectory[i-1:i]
        trail.set_data(visited[:, 0], visited[:, 1])
        marker.set_data(current[:, 0], current[:, 1])
    return lines + points

# Frame i draws the first i points of each path (see animate), so the frame
# index must run through len(path) inclusive; stopping one short would leave
# the final iterate undrawn.
ani = FuncAnimation(
    fig,
    animate,
    frames=range(1, len(paths[0]) + 1),
    init_func=init,
    blit=True,
)
plt.legend()

# Set to False to preview the animation without writing the video file.
# NOTE: writing .mp4 relies on a Matplotlib movie writer (typically ffmpeg)
# being available on the system.
save_animation = True
if save_animation:
    ani.save("gradient-descent-stepsize.mp4", fps=10)

plt.show()
