import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from scipy.optimize import minimize

# Tilted double-well potential Q(x, y) = (1 - y^2)^2 + x*y.
# At x = 0 it has two equal minima, at y = -1 and y = +1. A nonzero x tilts
# the landscape so that one of the two wells becomes the global minimum.
def Q(x, y):
    """Return (1 - y**2)**2 + x*y.

    Evaluated element-wise when ``y`` is a NumPy array, which is how the
    animation draws the whole curve y -> Q(x, y) in one call.
    """
    double_well = (1 - y**2) ** 2
    tilt = x * y
    return double_well + tilt

# Grid of y values used to draw the curve y -> Q(x, y) in every frame.
y = np.linspace(-2, 2, 400)
# x values swept by the animation; the middle sample (index 150) is x = 0.
x_range = np.linspace(-1, 1, 301)
min_y_values = []

# For each x, find the argmin of y -> Q(x, y) numerically.
# y = 0 must not be the starting point. Since dQ/dy = -4y(1 - y^2) + x, the
# gradient at y = 0 is exactly x. At x = 0 the start is therefore a
# stationary point (the central local maximum), and the optimizer would stop
# there and report argmin 0.
# Instead, start inside the well that the tilt x*y favours: y < 0 for x > 0,
# y > 0 otherwise. For |x| <= 1 both wells exist, so descending from that
# start reaches the global minimum. At x = 0 this reports the y = +1
# minimum (both minima are equally deep there).
for x in x_range:
    start = -1.0 if x > 0 else 1.0
    # minimize passes a length-1 array; evaluate Q at its single entry.
    result = minimize(lambda yv: Q(x, yv[0]), start)
    min_y_values.append(result.x[0])

# Two stacked panels:
#   top    - the curve y -> Q(x, y) for the current x (animated);
#   bottom - the argmin y as a function of x (static, with a moving marker).
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 12))

# Top panel: y -> Q(x, y).
# The lower y-limit follows the data so the red minimum marker is never
# clipped: near x = +/-1 the minimum of Q is about -1.06, which a fixed lower
# limit of -1 would cut off. The upper limit of 5 crops the steep outer
# walls (Q reaches 11 at y = +/-2 when x = +/-1).
q_floor = min(Q(xv, yv) for xv, yv in zip(x_range, min_y_values))
ax1.set_xlim(-2, 2)
ax1.set_ylim(min(-1, q_floor - 0.25), 5)
ax1.set_xlabel("y")
ax1.set_ylabel("Q(x, y)")
line, = ax1.plot([], [], 'k')   # the curve, filled in per frame
dots, = ax1.plot([], [], 'ro')  # marker(s) at the minimum

# Bottom panel: x -> argmin_y Q(x, y), drawn once; only x_dot moves.
ax2.set_xlim(-1, 1)
ax2.set_ylim(min(min_y_values) - 0.5, max(min_y_values) + 0.5)
ax2.set_xlabel("x")
ax2.set_ylabel("y (argmin of Q)")
x_line, = ax2.plot(x_range, min_y_values, 'b-', label="y minimizing Q(x, y)")
x_dot, = ax2.plot([], [], 'ro')
ax2.axvline(x=0, color='r', linestyle='--', label="Discontinuity at x = 0")
ax2.legend()

# Frame schedule, as indices into x_range:
#   - one frame per x sample, at a constant rate;
#   - the middle sample (x = 0, where the argmin jumps between wells) is
#     repeated PAUSE_FRAMES times, so the animation holds there briefly.
PAUSE_FRAMES = 20
n_samples = len(x_range)
mid_index = n_samples // 2
frames = list(range(mid_index))                    # left half, x < 0
if n_samples:
    frames += [mid_index] * PAUSE_FRAMES           # hold at x = 0
frames += list(range(mid_index + 1, n_samples))    # right half, x > 0

# Reset the animated artists before the first frame is drawn.
def init():
    """Empty every animated artist and return them for FuncAnimation."""
    animated = (line, dots, x_dot)
    for artist in animated:
        artist.set_data([], [])
    return animated

# Draw frame i. The top panel shows the curve y -> Q(x, y) for
# x = x_range[i] with its minimum marked. The bottom panel marks the
# matching point on the argmin curve.
def animate(i):
    """Update the animated artists for frame index ``i``.

    Parameters
    ----------
    i : int
        Index into ``x_range`` and ``min_y_values`` (taken from ``frames``).

    Returns
    -------
    tuple
        The artists modified in this frame.
    """
    x = x_range[i]
    line.set_data(y, Q(x, y))
    # Use a tolerance instead of ``x == 0`` so the symmetric case does not
    # depend on linspace producing an exact floating-point zero.
    if abs(x) < 1e-9:
        # At x = 0 the wells are symmetric: Q has equal minima (value 0) at
        # y = -1 and y = +1, so mark both.
        dots.set_data([-1, 1], [0, 0])
        x_dot.set_data([0, 0], [-1, 1])
    else:
        y_min = min_y_values[i]
        # Line2D.set_data expects sequences. Bare scalars are deprecated
        # since Matplotlib 3.7 and rejected by newer releases.
        dots.set_data([y_min], [Q(x, y_min)])
        x_dot.set_data([x], [y_min])
    ax1.set_title(f"x = {x:.2f}")
    return line, dots, x_dot

# Blitting is disabled: with blit=True, FuncAnimation re-blits only each
# axes' bounding box. The per-frame title set via ax1.set_title (drawn above
# the axes) would therefore never refresh on screen.
# Keep a reference in `ani` so the animation is not garbage-collected.
ani = animation.FuncAnimation(fig, animate, init_func=init, frames=frames,
                              interval=50, blit=False)

plt.show()
