import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# Convex set C1: the closed disk of radius 1 centred at the origin.
circle_radius = 1
circle_center = np.array((0, 0))

# Convex set C2: the line {p : line_normal . (p - line_point) = 0}.
# The unit normal points along the 45-degree diagonal, so the line itself
# has slope -1.
line_normal = np.ones(2) / np.sqrt(2)
# Anchor on the diagonal y = x, pulled 0.1 inward in each coordinate from
# (1/sqrt(2), 1/sqrt(2)); the line then sits at distance 1 - 0.1*sqrt(2)
# (about 0.86) from the origin, so it cuts through the disk.
line_point = np.full(2, 1/np.sqrt(2) - 0.1)

def project_to_circle(point, center, radius):
    """Project `point` onto the closed disk |p - center| <= radius.

    Despite the name, the target set is the filled disk, not just its
    boundary circle. A point already inside or on the disk is returned
    unchanged (the very same object). An exterior point is moved radially
    onto the boundary circle.
    """
    offset = point - center
    gap = np.linalg.norm(offset)
    if gap <= radius:
        return point
    # Exterior point: rescale the offset so it has length `radius`.
    return center + radius * offset / gap

def project_to_line(point, line_point, line_normal):
    """Return the orthogonal projection of `point` onto a line in the plane.

    The line is the set {p : line_normal . (p - line_point) = 0}.

    Args:
        point: Coordinates of the point to project (array-like).
        line_point: Any point lying on the line (array-like).
        line_normal: A normal vector of the line (array-like). It need not be
            unit length; the offset is divided by |line_normal|^2, so unit
            normals give the same result as before.

    Returns:
        The closest point on the line to `point`, as a float NumPy array.

    Raises:
        ValueError: If `line_normal` is the zero vector (no line is defined).
    """
    point = np.asarray(point, dtype=float)
    line_normal = np.asarray(line_normal, dtype=float)
    normal_sq = np.dot(line_normal, line_normal)
    if normal_sq == 0:
        raise ValueError("line_normal must be a non-zero vector")
    # Signed offset along the normal. Dividing by |n|^2 (not |n|) lets the
    # subtraction below use the raw normal without prior normalisation.
    offset = np.dot(point - line_point, line_normal) / normal_sq
    return point - offset * line_normal

# Animation setup: a square canvas with equal aspect ratio so the disk is
# drawn round, and every tick and spine hidden for a clean illustration.
fig, ax = plt.subplots(figsize=(6, 6))
for set_limits in (ax.set_xlim, ax.set_ylim):
    set_limits(-1.5, 1.5)
ax.set_aspect('equal')
ax.set_xticks([])
ax.set_yticks([])
for side in ('top', 'right', 'bottom', 'left'):
    ax.spines[side].set_visible(False)
# Debug aids (disabled): ax.grid(True) plus dashed gray axhline/axvline at 0.

# Draw the disk (convex set C1) as a translucent patch, labelled just beyond
# its lower-left corner.
circle = plt.Circle(circle_center, circle_radius, color='blue', alpha=0.2, label='Circle (C1)')
ax.add_patch(circle)
ax.text(circle_center[0] - circle_radius - 0.1, circle_center[1] - circle_radius - 0.1, '$C_1$', color='blue', fontsize=18)

# Draw the line (convex set C2) parametrically as line_point + t * d, where d
# is a unit direction perpendicular to line_normal. Unlike solving for y as a
# function of x, this is correct for any normal, including ones that give
# horizontal or vertical lines.
line_direction = np.array([-line_normal[1], line_normal[0]], dtype=float)
line_direction /= np.linalg.norm(line_direction)
# |t| <= 5 exceeds the diagonal of the [-1.5, 1.5]^2 window, so the segment
# spans the whole view for any anchor inside it; the axis limits clip the rest.
t_vals = np.linspace(-5, 5, 100)
x_vals = line_point[0] + t_vals * line_direction[0]
y_vals = line_point[1] + t_vals * line_direction[1]
ax.plot(x_vals, y_vals, color='red', linewidth=2, alpha=0.7, label='Line (C2)')
ax.text(line_point[0] + 0.1, line_point[1] + 0.1, '$C_2$', color='red', fontsize=18)

# Animated artists, all starting empty: the current iterate (a black dot)
# and one dashed segment for the latest projection onto each set.
(point,) = ax.plot([], [], marker='o', color='k', linestyle='None', markersize=10, label='Current Point')
(line_to_circle,) = ax.plot([], [], color='k', linestyle='--', alpha=0.5)
(line_to_line,) = ax.plot([], [], color='k', linestyle='--', alpha=0.5)

# Legend currently disabled; uncomment to show it:
# ax.legend()

# Starting iterate, outside both convex sets.
start_point = np.array([1.2, -1.2])
current_point = start_point.copy()
# points[0] is the start. After that, odd indices hold projections onto the
# disk and even indices hold projections onto the line.
points = [current_point]

num_steps = 20  # number of (disk, line) projection rounds
for _ in range(num_steps):
    on_disk = project_to_circle(current_point, circle_center, circle_radius)
    next_point = project_to_line(on_disk, line_point, line_normal)
    points += [on_disk, next_point]
    current_point = next_point

def update(frame):
    """Render animation frame `frame` of the alternating projections.

    Frame 0 shows the start point with no projection segments. Frame k > 0
    shows points[k] (clamped to the last computed point) and a dashed segment
    from points[k - 1]. Odd k draw the projection onto the disk; even k draw
    the projection onto the line.

    Returns:
        The modified artists, so the function also works with blit=True.
    """
    if frame == 0:
        # Line2D.set_data needs sequences; bare scalars are deprecated since
        # Matplotlib 3.7 and rejected by later releases.
        point.set_data([start_point[0]], [start_point[1]])
        line_to_circle.set_data([], [])
        line_to_line.set_data([], [])
        return point, line_to_circle, line_to_line

    idx = min(frame, len(points) - 1)
    current_point = points[idx]
    previous_point = points[idx - 1]

    # Show only the segment for the projection that produced points[idx].
    if idx % 2 == 1:
        active, idle = line_to_circle, line_to_line
    else:
        active, idle = line_to_line, line_to_circle
    active.set_data(
        [previous_point[0], current_point[0]],
        [previous_point[1], current_point[1]],
    )
    idle.set_data([], [])

    point.set_data([current_point[0]], [current_point[1]])
    return point, line_to_circle, line_to_line

# Create the animation with one frame per entry of `points`: the start point
# plus both projections of every round. A frame count of num_steps + 1 would
# stop at points[num_steps], halfway through the computed iterates.
fps = 5
ani = FuncAnimation(fig, update, frames=len(points), interval=1000 // fps, repeat=False)

# Set to False to skip writing the video (the 'ffmpeg' writer needs an
# ffmpeg executable available).
SAVE_ANIMATION = True
if SAVE_ANIMATION:
    ani.save('projection-convex-sets.mp4', writer='ffmpeg', fps=fps)

plt.show()
