import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# Underdetermined linear system A x = b: fewer equations (m) than unknowns (n).
m = 3   # number of equations
n = 10  # number of variables (m < n)

# Fixed seed so the same system is drawn on every run.
np.random.seed(3)
A = np.random.standard_normal(size=(m, n))
A /= np.sqrt(m)  # scale the Gaussian entries by 1/sqrt(m)
b = np.random.standard_normal(size=m)

# Define the quadratic loss function and its gradient
def loss(x):
    """Return the least-squares objective 0.5 * ||A x - b||^2 at ``x``.

    Uses the module-level system matrix ``A`` and right-hand side ``b``.
    """
    residual = A @ x - b
    return 0.5 * np.sum(np.square(residual))

def gradient(x):
    """Return the gradient of the least-squares loss, A^T (A x - b), at ``x``.

    Uses the module-level system matrix ``A`` and right-hand side ``b``.
    """
    residual = A @ x - b
    return A.T @ residual

# Step size: 0.3 / L, where L is the squared spectral norm of A
# (the Lipschitz constant of the gradient).
L = np.linalg.norm(A, ord=2) ** 2
learning_rate = 0.3 / L
num_iterations = 50
num_initializations = 7  # how many starting points to draw

# Draw the starting points as A^T z with Gaussian z (fixed seed for
# reproducibility), so every start is a combination of the rows of A.
np.random.seed(0)
initial_points = [A.T @ np.random.randn(m) for _ in range(num_initializations)]

# Run plain gradient descent from every starting point. Each run is stored as
# a (num_iterations + 1, n) array whose row k is the iterate after k steps.
trajectories = []
for start in initial_points:
    path = np.empty((num_iterations + 1, start.size))
    path[0] = start
    for k in range(num_iterations):
        current = path[k]
        path[k + 1] = current - learning_rate * gradient(current)
    trajectories.append(path)

# Reference point for the animation: the pseudoinverse (minimal norm) solution.
x_min_norm = np.linalg.pinv(A) @ b

# Set up the figure; only the first two coordinates of each point are drawn
fig, ax = plt.subplots(figsize=(8, 8))
ax.set_xlim(-1, 2)
ax.set_ylim(0, 2)
# Optional decorations, left disabled for a bare plot:
# ax.set_title("Implicit Bias of Gradient Descent: Minimal Norm Solution")
# ax.set_xlabel("$x_1$")
# ax.set_ylabel("$x_2$")
ax.set_xticks([])
ax.set_yticks([])
# ax.grid(True)

# Plot the feasible set: sample one line through x_min_norm along the first
# null-space direction of A. All samples are drawn with a single plot call so
# the axes hold one artist instead of 200 single-point artists.
u, s, vt = np.linalg.svd(A)
# Rows of vt beyond the singular values span the null space
# (assumes A has full row rank, i.e. none of the values in s is zero).
null_space = vt[s.size:].T
ts = np.linspace(-2.5, 2.5, 200)
feasible_points = x_min_norm + np.outer(ts, null_space[:, 0])
ax.plot(feasible_points[:, 0], feasible_points[:, 1], 'b.', alpha=0.2)

# Initialize points and trajectory lines for each initialization; they start
# empty and are filled in by the animation update function.
points = [ax.plot([], [], 'ro')[0] for _ in range(num_initializations)]
lines = [ax.plot([], [], '-', lw=2)[0] for _ in range(num_initializations)]
min_norm_point, = ax.plot(x_min_norm[0], x_min_norm[1], 'go', markersize=10, label='Minimal Norm Solution')

# Animation update function
def update(frame):
    """Move every marker and extend every path to iteration ``frame``.

    Args:
        frame: Index of the gradient-descent iterate to display, i.e. the row
            of each trajectory array to show.

    Returns:
        The list of modified artists (all markers followed by all path
        lines), which FuncAnimation needs when ``blit=True``.
    """
    for point, line, trajectory in zip(points, lines, trajectories):
        # set_data expects sequences; recent Matplotlib releases reject bare
        # scalars, so the single position is wrapped in one-element lists.
        point.set_data([trajectory[frame, 0]], [trajectory[frame, 1]])
        line.set_data(trajectory[:frame + 1, 0], trajectory[:frame + 1, 1])
    return points + lines

# Create the animation: one frame per stored iterate (the initial point plus
# every gradient step). The reference is kept in `ani` so the animation object
# stays alive while it is saved and shown.
ani = FuncAnimation(fig, update, frames=num_iterations + 1, interval=100, blit=True)

# Hide the axes frame (the legend is intentionally left off)
# ax.legend()
for spine in ax.spines.values():
    spine.set_visible(False)

# Set to False to skip writing the video file; saving uses the ffmpeg writer,
# so ffmpeg must be available to Matplotlib.
SAVE_ANIMATION = True
if SAVE_ANIMATION:
    ani.save('implicit_bias_animation.mp4', writer='ffmpeg')

plt.show()
