import numpy as np
import matplotlib.pyplot as plt

# High-dimensional quadratic function: f(x) = 1/2 * ||Ax - b||^2
def high_dim_quadratic(x, A, b):
    """Evaluate the least-squares objective f(x) = 1/2 * ||Ax - b||^2.

    Args:
        x: Point at which the objective is evaluated.
        A: Matrix applied to ``x``.
        b: Target that ``A x`` is compared against.

    Returns:
        Half the sum of the squared entries of the residual ``A x - b``.
    """
    residual = np.dot(A, x) - b
    return 0.5 * np.sum(np.square(residual))

# Gradient of the high-dimensional quadratic function
def grad_high_dim_quadratic(x, A, b):
    """Evaluate the gradient A^T (A x - b) of 1/2 * ||Ax - b||^2.

    Args:
        x: Point at which the gradient is evaluated.
        A: Matrix applied to ``x``.
        b: Target that ``A x`` is compared against.

    Returns:
        The residual ``A x - b`` multiplied on the left by ``A.T``.
    """
    residual = A @ x - b
    return A.T @ residual

# Problem setup: a random overdetermined system with twice as many rows as columns
dim = 256  # Number of unknowns
A = np.random.standard_normal((2 * dim, dim))
b = np.random.standard_normal(2 * dim)

# Random starting point shared by both methods
x0 = np.random.standard_normal(dim)

# Step size is the inverse of the squared spectral norm of A (1/Lipschitz constant);
# momentum is the Heavy Ball coefficient applied to the previous velocity
spectral_norm = np.linalg.norm(A, ord=2)
learning_rate = 1 / spectral_norm ** 2
momentum = 0.7
iterations = 100

# Run both methods from the same starting point, recording each iterate
# right after its update (the starting point itself is not recorded).

# Gradient Descent (baseline for comparison)
x_gd = x0.copy()
gd_trajectory = []
for _ in range(iterations):
    x_gd = x_gd - learning_rate * grad_high_dim_quadratic(x_gd, A, b)
    gd_trajectory.append(x_gd)

# Heavy Ball method: the velocity keeps a fraction `momentum` of its previous
# value and subtracts the scaled gradient; the iterate then moves by the velocity
x_hb = x0.copy()
v_hb = np.zeros_like(x0)
hb_trajectory = []
for _ in range(iterations):
    scaled_grad = learning_rate * grad_high_dim_quadratic(x_hb, A, b)
    v_hb = momentum * v_hb - scaled_grad
    # Rebind instead of updating in place so every stored iterate is its own array
    x_hb = x_hb + v_hb
    hb_trajectory.append(x_hb)

# Stack the recorded iterates into arrays (one row per iteration) for plotting
gd_trajectory = np.array(gd_trajectory)
hb_trajectory = np.array(hb_trajectory)

# Reference minimizer of 1/2 * ||Ax - b||^2. lstsq solves the least-squares
# problem directly instead of forming A.T @ A, which would square the
# condition number and fail outright when A is rank-deficient.
sol, *_ = np.linalg.lstsq(A, b, rcond=None)

# The optimal value does not depend on the iterate, so evaluate it only once
f_opt = high_dim_quadratic(sol, A, b)

# Objective gap f(x_k) - f(x*) for every recorded iterate of each method
gd_objectives = [high_dim_quadratic(x, A, b) - f_opt for x in gd_trajectory]
hb_objectives = [high_dim_quadratic(x, A, b) - f_opt for x in hb_trajectory]

plt.figure(figsize=(8, 6))

# Log-scale y axis so geometric decay of the gap shows up as a straight line
plt.semilogy(range(iterations), gd_objectives, label="Gradient Descent", color="red", linewidth=2)
plt.semilogy(range(iterations), hb_objectives, label="Heavy Ball Method", color="blue", linewidth=2)
plt.xlabel('Iteration', fontsize=16)
plt.ylabel('Lack of Optimality', fontsize=16)
plt.legend(fontsize=12)
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)

# Set to False to skip writing the PNG; must run before plt.show()
save_figure = True
if save_figure:
    plt.savefig("heavy-ball.png", dpi=300)
plt.show()
