import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# Define the quadratic function and its gradient
def f(x, y):
    """Evaluate the quadratic objective ``x**2 + 2*y**2 - x*y``.

    Works element-wise when ``x`` and ``y`` are NumPy arrays (e.g. a meshgrid).
    """
    quadratic = x**2 + 2 * y**2
    return quadratic - x * y

def grad_f(x, y):
    """Return the analytic gradient of ``f`` as the NumPy array ``[df/dx, df/dy]``."""
    df_dx = 2 * x - y
    df_dy = 4 * y - x
    return np.array([df_dx, df_dy])

# Gradient Descent update rule
def gd_update(x, y, alpha, *, grad_fn=None):
    """Take one plain gradient-descent step from ``(x, y)``.

    Args:
        x, y: current position.
        alpha: learning rate.
        grad_fn: optional callable ``(x, y) -> (gx, gy)``; defaults to ``grad_f``
            so existing callers keep optimising the module's objective.

    Returns:
        Tuple ``(new_x, new_y)`` = ``(x, y) - alpha * grad(x, y)``.
    """
    if grad_fn is None:
        grad_fn = grad_f
    gx, gy = grad_fn(x, y)
    return x - alpha * gx, y - alpha * gy

# Nesterov Accelerated Gradient (NAG) update rule
def nag_update(x, y, v_x, v_y, alpha, gamma, *, grad_fn=None):
    """Take one Nesterov accelerated gradient (NAG) step.

    The gradient is evaluated at the look-ahead point ``p + gamma * v``, folded
    into the velocity, and the position then moves by that new velocity::

        v_new = gamma * v - alpha * grad(p + gamma * v)
        p_new = p + v_new

    Args:
        x, y: current position.
        v_x, v_y: current velocity (momentum) components.
        alpha: learning rate.
        gamma: momentum coefficient.
        grad_fn: optional callable ``(x, y) -> (gx, gy)``; defaults to ``grad_f``.

    Returns:
        Tuple ``(new_x, new_y, new_v_x, new_v_y)``.
    """
    if grad_fn is None:
        grad_fn = grad_f
    gx, gy = grad_fn(x + gamma * v_x, y + gamma * v_y)
    new_v_x = gamma * v_x - alpha * gx
    new_v_y = gamma * v_y - alpha * gy
    # Step by the full new velocity so accumulated momentum actually moves the
    # point; stepping by -alpha * grad alone would discard the momentum term.
    return x + new_v_x, y + new_v_y, new_v_x, new_v_y

# Stochastic Gradient Descent (SGD) update rule, simulated by adding Gaussian noise to the exact gradient
def sgd_update(x, y, alpha, *, noise_scale=0.5, grad_fn=None, rng=None):
    """Take one simulated stochastic-gradient step from ``(x, y)``.

    Stochasticity is simulated by adding isotropic Gaussian noise with standard
    deviation ``noise_scale`` to the exact gradient before stepping.

    Args:
        x, y: current position.
        alpha: learning rate.
        noise_scale: standard deviation of the gradient noise (default 0.5).
        grad_fn: optional callable ``(x, y) -> (gx, gy)``; defaults to ``grad_f``.
        rng: optional ``numpy.random.Generator``. When omitted, the noise is
            drawn from NumPy's global RNG via ``np.random.randn``.

    Returns:
        Tuple ``(new_x, new_y)``.
    """
    if grad_fn is None:
        grad_fn = grad_f
    if rng is None:
        noise = noise_scale * np.random.randn(2)
    else:
        noise = noise_scale * rng.standard_normal(2)
    gx, gy = grad_fn(x, y)
    return x - alpha * (gx + noise[0]), y - alpha * (gy + noise[1])

# Hyperparameters for each optimiser
alpha_gd = 0.05
alpha_nag = 0.05
# NOTE(review): gamma_nag = 0.001 gives almost no momentum, so NAG will track
# plain GD closely; momentum coefficients around 0.9 are more typical — confirm intent.
gamma_nag = 0.001
alpha_sgd = 0.05
n_steps = 100  # number of update steps taken by every method

# Seed NumPy's global RNG (sgd_update draws its noise from it) so the SGD
# trajectory, and any saved animation, is reproducible between runs.
np.random.seed(0)

# Starting point for each method, plus the initial NAG velocity
x_gd, y_gd = -2, 3
x_nag, y_nag = 2, 3
x_sgd, y_sgd = 2, -3
v_x, v_y = 0, 0

# Each path starts with its initial point and gains one point per step,
# so every path ends up with n_steps + 1 points.
path_gd = [(x_gd, y_gd)]
path_nag = [(x_nag, y_nag)]
path_sgd = [(x_sgd, y_sgd)]

# Perform updates and track paths
for _ in range(n_steps):
    x_gd, y_gd = gd_update(x_gd, y_gd, alpha_gd)
    path_gd.append((x_gd, y_gd))

    x_nag, y_nag, v_x, v_y = nag_update(x_nag, y_nag, v_x, v_y, alpha_nag, gamma_nag)
    path_nag.append((x_nag, y_nag))

    x_sgd, y_sgd = sgd_update(x_sgd, y_sgd, alpha_sgd)
    path_sgd.append((x_sgd, y_sgd))

# Turn each recorded list of (x, y) tuples into an (n_points, 2) array so the
# x and y columns can be sliced directly when animating.
path_gd, path_nag, path_sgd = map(np.array, (path_gd, path_nag, path_sgd))

# Sample the objective on a 400 x 400 grid spanning the plotted window [-3, 3]^2
x_vals = np.linspace(-3, 3, 400)
y_vals = np.linspace(-3, 3, 400)
X, Y = np.meshgrid(x_vals, y_vals)
Z = f(X, Y)

# Figure with a fixed [-3, 3] x [-3, 3] view and the objective's contour lines
fig, ax = plt.subplots()
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)
contour = ax.contour(X, Y, Z, levels=20, cmap='viridis')


def _trail_and_marker(line_fmt, marker_fmt, label):
    """Add an empty labelled trail line and an empty head marker to ``ax``."""
    (trail,) = ax.plot([], [], line_fmt, label=label)
    (marker,) = ax.plot([], [], marker_fmt)
    return trail, marker


# One (trail, current-position marker) pair per optimiser, filled in by animate()
line_gd, point_gd = _trail_and_marker('r-', 'ro', 'GD')
line_nag, point_nag = _trail_and_marker('b-', 'bo', 'NAG')
line_sgd, point_sgd = _trail_and_marker('g-', 'go', 'SGD')

def init():
    """Clear every trail and marker; used as FuncAnimation's ``init_func``.

    Returns the six artists (GD, NAG, SGD trail/marker pairs) for blitting.
    """
    artists = (line_gd, point_gd, line_nag, point_nag, line_sgd, point_sgd)
    for artist in artists:
        artist.set_data([], [])
    return artists

def animate(i):
    """Draw frame ``i``: the first ``i`` points of each path plus a marker on point ``i - 1``.

    Returns the six artists in the same order as ``init`` for blitting.
    """
    updated = []
    pairs = (
        (line_gd, point_gd, path_gd),
        (line_nag, point_nag, path_nag),
        (line_sgd, point_sgd, path_sgd),
    )
    for trail, marker, path in pairs:
        trail.set_data(path[:i, 0], path[:i, 1])
        # A one-row slice keeps set_data fed with sequences rather than scalars.
        marker.set_data(path[i - 1:i, 0], path[i - 1:i, 1])
        updated.extend((trail, marker))
    return tuple(updated)

# animate(i) draws path[:i], so frames must run up to len(path_gd) inclusive
# for the final frame to include the last recorded point.
ani = FuncAnimation(fig, animate, frames=range(1, len(path_gd) + 1), init_func=init, blit=True)
ax.legend()

# Set to True to write the animation to disk; .mp4 output needs a movie writer
# matplotlib can use (e.g. ffmpeg) to be installed.
SAVE_ANIMATION = False
if SAVE_ANIMATION:
    ani.save("gd_nag_sgd.mp4", fps=10)

plt.show()