import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# Shearing matrix whose singular value decomposition is animated below.
M = np.array([[2, 1], [0, 1]])

# Factor M = U @ diag(Sigma) @ Vt.  V is kept alongside Vt so the code can
# refer to either orientation by name.
U, Sigma, Vt = np.linalg.svd(M)
V = Vt.T

# 100 samples of the unit circle stored column-wise: row 0 is x, row 1 is y.
theta = np.linspace(0, 2 * np.pi, 100)
circle = np.vstack((np.cos(theta), np.sin(theta)))

# Columns are the canonical basis vectors e1 and e2.
unit_vectors = np.identity(2)

# Full transform M applied in one step (the animation's end state).
circle_transformed = M @ circle
unit_vectors_transformed = M @ unit_vectors

# Stage 1: rotate/reflect by V^T (identical to V.T since V = Vt.T).
circle_Vt = Vt @ circle
unit_vectors_Vt = Vt @ unit_vectors

# Stage 2: scale row i by Sigma[i] -- the same as left-multiplying by diag(Sigma).
circle_scaled = Sigma[:, np.newaxis] * circle_Vt
unit_vectors_scaled = Sigma[:, np.newaxis] * unit_vectors_Vt

# Stage 3: rotate/reflect by U, which lands the basis vectors on M's columns.
unit_vectors_final = U @ unit_vectors_scaled

# Setup the plot: a square 6x6-inch figure with fixed [-3, 3] limits on both
# axes so the shapes do not rescale the view while animating.
fig, ax = plt.subplots(figsize=(6, 6))
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)
ax.set_aspect('equal')
ax.grid(True)

# Initialize plot elements (the animated ones start empty and are filled in
# by the animation's update callback).
circle_plot, = ax.plot([], [], 'b-', lw=2, label="Unit Circle")
ellipse_plot, = ax.plot([], [], 'r-', lw=2, label="Transformed Circle")
vec1_plot, = ax.plot([], [], 'g-', lw=2, label="Canonical Vectors")
vec2_plot, = ax.plot([], [], 'm-', lw=2)

# The animation never touches ellipse_plot, so give it its data once: the
# target shape M @ circle, shown as a static reference for the blue curve.
# Previously this artist stayed empty and its label described nothing.
ellipse_plot.set_data(circle_transformed[0, :], circle_transformed[1, :])

# The label= arguments above only take effect once a legend is created.
ax.legend(loc='upper left')

# Frames spent on each of the three SVD stages (V^T, Sigma, U).  The
# FuncAnimation below is created with 150 = 3 * PHASE_FRAMES frames.
PHASE_FRAMES = 50


def _lerp(start, end, t):
    """Return the linear blend of ``start`` and ``end`` at parameter ``t``.

    ``t == 0`` yields ``start`` and ``t == 1`` yields ``end``.
    """
    return (1 - t) * start + t * end


# Animation update function
def update(frame):
    """Render frame ``frame`` of the decomposition M = U @ diag(Sigma) @ Vt.

    The animation consists of three consecutive phases of ``PHASE_FRAMES``
    frames each:

    1. "Rotation V^T": circle / basis vectors -> ``V.T @ ...``
    2. "Scaling Sigma": ``V.T @ ...`` -> ``diag(Sigma) @ V.T @ ...``
    3. "Rotation U": scaled state -> ``M @ ...``

    Each phase linearly blends its start state into its end state.  A linear
    blend of two orthogonal transforms is not itself orthogonal, so during the
    V^T and U phases the circle temporarily shrinks mid-phase rather than
    rigidly rotating.  (np.linalg.svd may also return U / V with determinant
    -1, i.e. reflections -- NOTE(review): confirm for the chosen M.)

    Args:
        frame: Zero-based frame index supplied by ``FuncAnimation``.

    Returns:
        The line artists modified by this frame.
    """
    # Select the start/end states of the current phase.  The intermediate
    # stages are the module-level precomputed arrays, so nothing is
    # re-multiplied per frame.
    if frame < PHASE_FRAMES:
        phase_start = 0
        title = r"Rotation $V^\top$"
        circle_from, circle_to = circle, circle_Vt
        vec_from, vec_to = unit_vectors, unit_vectors_Vt
    elif frame < 2 * PHASE_FRAMES:
        phase_start = PHASE_FRAMES
        title = r"Scaling $\Sigma$"  # raw string: "\S" is an invalid escape
        circle_from, circle_to = circle_Vt, circle_scaled
        vec_from, vec_to = unit_vectors_Vt, unit_vectors_scaled
    else:
        phase_start = 2 * PHASE_FRAMES
        title = r"Rotation $U$"
        circle_from, circle_to = circle_scaled, circle_transformed
        vec_from, vec_to = unit_vectors_scaled, unit_vectors_final

    # Divide by (PHASE_FRAMES - 1) so t reaches exactly 1.0 on the last frame
    # of every phase: the final frame then shows M @ circle, and phases meet
    # without a jump.  (Dividing by PHASE_FRAMES stopped at t = 0.98.)
    # Clamping holds any frame past the end at the final state instead of
    # extrapolating beyond it.
    t = min((frame - phase_start) / max(PHASE_FRAMES - 1, 1), 1.0)

    circle_intermediate = _lerp(circle_from, circle_to, t)
    vec_intermediate = _lerp(vec_from, vec_to, t)

    ax.set_title(title, fontsize=20)

    # Update plot data; each basis vector is drawn as a segment from the origin.
    circle_plot.set_data(circle_intermediate[0, :], circle_intermediate[1, :])
    vec1_plot.set_data([0, vec_intermediate[0, 0]], [0, vec_intermediate[1, 0]])
    vec2_plot.set_data([0, vec_intermediate[0, 1]], [0, vec_intermediate[1, 1]])
    return circle_plot, vec1_plot, vec2_plot

# Create animation: 150 frames (three 50-frame SVD stages), shown 50 ms apart
# in the interactive window.  `anim` must stay referenced or the animation
# is garbage-collected before it plays.
anim = FuncAnimation(fig, update, frames=150, interval=50, blit=False)

# Set to False to skip writing the video file (writer='ffmpeg' needs an
# ffmpeg executable that matplotlib can locate).
SAVE_ANIMATION = True

if SAVE_ANIMATION:
    # NOTE(review): fps=30 differs from interval=50 ms (20 fps) above, so the
    # saved video plays faster than the on-screen preview -- confirm intended.
    anim.save('svd.mp4', writer='ffmpeg', fps=30)

# Turns off the grid enabled during setup.  This runs after the save, so it
# only affects the interactive window, not svd.mp4.
plt.grid(False)
plt.show()
