import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.patches import FancyArrowPatch

# Rotated quadratic objective: a*x^2 + b*y^2 + c*x*y.
# The cross-term c*x*y rotates the level sets; a != b makes them anisotropic.
# Coefficients are kept small so the surface fits the -2..2 plotting window.
def f(x, y, a=0.1, b=0.3, c=0.05):
    """Evaluate the rotated quadratic a*x^2 + b*y^2 + c*x*y at (x, y)."""
    axis_aligned_part = a * x * x + b * y * y
    rotation_part = c * x * y
    return axis_aligned_part + rotation_part

# Analytic gradient of the rotated quadratic f.
def grad_f(x, y, a=0.1, b=0.3, c=0.05):
    """Return the gradient [df/dx, df/dy] of f at (x, y) as a numpy array."""
    return np.array([
        2 * a * x + c * y,  # partial derivative with respect to x
        2 * b * y + c * x,  # partial derivative with respect to y
    ])

# One step of vanilla gradient descent: x_{k+1} = x_k - lr * grad f(x_k).
def gradient_descent_step(xy, learning_rate, a=0.1, b=0.3, c=0.05):
    """Return the point reached from `xy` after one gradient-descent step."""
    x, y = xy[0], xy[1]
    descent_direction = -grad_f(x, y, a, b, c)
    return xy + learning_rate * descent_direction

# Set up figure and axes for animation
fig, ax = plt.subplots()
x_vals = np.linspace(-2, 2, 400)  # Rescale x-axis range
y_vals = np.linspace(-2, 2, 400)  # Rescale y-axis range
X, Y = np.meshgrid(x_vals, y_vals)
Z = f(X, Y)  # evaluate the quadratic on the grid (uses the default a, b, c)

# Contour plot of the rotated quadratic function
ax.contour(X, Y, Z, levels=50, cmap='viridis')

# Initial points for gradient descent
x0, y0 = 1.5, 1.5  # Starting point
learning_rate = 0.1  # Adjust learning rate for the smaller scale
a, b, c = 0.1, 0.3, 0.05  # Rescaled coefficients for anisotropy and rotation

# The limit point (minimum of the quadratic function)
# For a positive-definite quadratic with no linear term the minimum is the origin.
x_lim = np.array([0, 0])

# To store the gradient descent path (updated in place by update())
path_x, path_y = [x0], [y0]

# Animated artists; each starts empty and is filled in by init()/update().
point, = ax.plot([], [], 'ro', label='Current Point')  # point on the contour
grad_line, = ax.plot([], [], 'b--', label='Gradient Direction')  # gradient line
path_line, = ax.plot([], [], 'r-', label='Gradient Descent Path')  # full path

secant_arrow = None  # Will hold the arrow object (replaced every frame by update())

# Set the axis limits to be between -2 and 2
ax.set_xlim([-2, 2])
ax.set_ylim([-2, 2])
ax.set_xlabel('x')
ax.set_ylabel('y')

# Placeholder secant arrow to include in the legend
# (FancyArrowPatch artists are not picked up by the legend automatically,
# so a dummy patch is added, listed in the legend, then hidden.)
secant_patch = FancyArrowPatch((0, 0), (1, 1), color='g', label='Secant Vector')
ax.add_patch(secant_patch)

# Create a legend that includes the secant vector
ax.legend(handles=[point, grad_line, path_line, secant_patch], loc='upper left')
secant_patch.set_visible(False)  # hide the dummy; only its legend entry remains

# Reset all animated artists before the first frame is drawn.
def init():
    """Clear the point, gradient line, and path line; return the artists."""
    for artist in (point, grad_line, path_line):
        artist.set_data([], [])
    return point, grad_line, path_line

# Unit vector pointing from the limit point x_lim toward (x0, y0).
def compute_secant_vector(x0, y0, x_lim):
    """Return the normalized secant direction from x_lim to (x0, y0).

    Returns the zero vector when (x0, y0) coincides with x_lim, so the
    caller never divides by zero.
    """
    secant = np.array([x0, y0]) - x_lim
    length = np.linalg.norm(secant)
    return secant / length if length != 0 else np.array([0, 0])

# Update function for the animation: advance one descent step and redraw.
def update(frame):
    """Advance gradient descent by one step and refresh all artists.

    Parameters
    ----------
    frame : int
        Frame index supplied by FuncAnimation (unused; state is kept in
        the module-level globals x0, y0, path_x, path_y, secant_arrow).

    Returns
    -------
    tuple of artists redrawn this frame.
    """
    global x0, y0, secant_arrow
    # Update points using gradient descent
    xy0_new = gradient_descent_step(np.array([x0, y0]), learning_rate, a, b, c)

    # Append the new position to the path
    path_x.append(xy0_new[0])
    path_y.append(xy0_new[1])

    # Update the descent point
    point.set_data([xy0_new[0]], [xy0_new[1]])

    # Update the gradient direction line (negative gradient = descent
    # direction). A straight segment only needs its two endpoints — the
    # original 100-point linspace added nothing visually.
    grad = grad_f(x0, y0, a, b, c)
    grad_line.set_data([x0, x0 - grad[0]], [y0, y0 - grad[1]])

    # Replace the previous secant arrow: FancyArrowPatch positions cannot
    # be updated in place, so remove the old patch and add a fresh one.
    if secant_arrow is not None:
        secant_arrow.remove()

    secant_vec = compute_secant_vector(x0, y0, x_lim)
    secant_arrow = FancyArrowPatch(
        (x0, y0), (x0 + secant_vec[0], y0 + secant_vec[1]),
        color='g', mutation_scale=10
    )
    ax.add_patch(secant_arrow)

    # Update the full path line
    path_line.set_data(path_x, path_y)

    # Shift state for the next frame
    x0, y0 = xy0_new[0], xy0_new[1]

    return point, grad_line, path_line, secant_arrow

# Create the animation: 200 frames, 200 ms between frames.
# Keep a reference to `ani` — FuncAnimation is garbage-collected (and the
# animation stops) if the return value is discarded.
ani = FuncAnimation(fig, update, frames=200, init_func=init, interval=200)

# Save the animation as an mp4 file (requires ffmpeg on the PATH).
# The original wrapped this in a pointless `if True:` and mislabeled the
# output as a gif; the actual output file is mp4.
ani.save('thom_conjecture.mp4', writer='ffmpeg', fps=10)