import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
from matplotlib.cm import viridis

# Function to calculate the midpoints of a triangle's three edges
def midpoints(p1, p2, p3):
    """Return the midpoints of the three edges of triangle (p1, p2, p3).

    The result is the tuple (midpoint of p1-p2, midpoint of p2-p3,
    midpoint of p1-p3). Each midpoint is ``(start + end) / 2``, so any
    operands supporting ``+`` and ``/`` work: NumPy coordinate arrays,
    as used by this script, or plain numbers.
    """
    edges = ((p1, p2), (p2, p3), (p1, p3))
    return tuple((start + end) / 2 for start, end in edges)

# Function to compute all triangles at a given depth
def sierpinski_iteration(vertices, depth):
    """Subdivide a triangle ``depth`` times, Sierpinski style.

    Each pass replaces every triangle with the three corner triangles
    formed by its edge midpoints; the central triangle is dropped. For
    ``depth >= 0`` the result therefore holds ``3 ** depth`` triangles.

    Args:
        vertices: The three corner points of the starting triangle, as any
            3-element unpackable of points. The script passes a (3, 2)
            NumPy array.
        depth: Number of subdivision passes. Values <= 0 perform no passes.

    Returns:
        A list of triangles. With no passes, the single element is the
        ``vertices`` object itself. Otherwise each triangle is a 3-item list
        ``[a, b, c]`` of corner points.
    """
    current = [vertices]
    for _ in range(depth):
        refined = []
        for a, b, c in current:
            ab, bc, ac = midpoints(a, b, c)
            # Corner triangles in a fixed order (at a, at b, at c), so that
            # per-index coloring downstream stays stable between levels.
            refined.extend(([a, ab, ac], [ab, b, bc], [ac, bc, c]))
        current = refined
    return current

# Set up the plot: equal aspect, no axis decorations, and fixed limits
# leaving a 0.1 margin around the side-1 equilateral starting triangle.
fig, ax = plt.subplots()
ax.set(
    aspect='equal',
    xlim=(-0.1, 1.1),
    ylim=(-0.1, np.sqrt(3) / 2 + 0.1),
)
ax.axis('off')

# Corners of the level-0 equilateral triangle (side 1), one row per vertex.
vertices = np.array([
    [0, 0],
    [1, 0],
    [0.5, np.sqrt(3) / 2],
])
# Deepest subdivision level to show; refinement_steps gets max_depth + 1 entries.
max_depth = 6

# refinement_steps[d] lists every triangle present after d subdivisions.
refinement_steps = [
    sierpinski_iteration(vertices, level) for level in range(max_depth + 1)
]

# Update function for animation
def update(frame):
    """Redraw the axes with the triangles of refinement level ``frame``.

    Used as the FuncAnimation frame callback. ``frame`` indexes
    ``refinement_steps``, where 0 is the single starting triangle.

    Args:
        frame: Refinement level to draw, as an index into ``refinement_steps``.
    """
    # ax.clear() resets the axes, so reapply the view used for the figure:
    # equal aspect, hidden axes, and fixed limits.
    ax.clear()
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_xlim(-0.1, 1.1)
    ax.set_ylim(-0.1, np.sqrt(3) / 2 + 0.1)

    triangles = refinement_steps[frame]
    count = len(triangles)
    for i, (p1, p2, p3) in enumerate(triangles):
        # Color by list position, sampling viridis on [0, 1).
        face = viridis(i / count)
        ax.fill(
            [p1[0], p2[0], p3[0]],
            [p1[1], p2[1], p3[1]],
            # Use facecolor, not color: ``color`` sets both face and edge
            # colors, so combining it with ``edgecolor`` depends on the
            # order in which matplotlib applies the keyword arguments.
            facecolor=face,
            edgecolor="black",
            linewidth=0.5,
            alpha=0.7,
        )

# Create the animation: one frame per refinement level, 500 ms apart on
# screen, played once.
ani = animation.FuncAnimation(
    fig, update, frames=len(refinement_steps), interval=500, repeat=False
)

# Set to False to only display the animation without writing a video file.
SAVE_VIDEO = True

if SAVE_VIDEO:
    # Save as MP4. The video's frame rate is independent of ``interval``:
    # fps=0.5 holds each level for 2 s in the file vs 0.5 s on screen.
    # Without ffmpeg, matplotlib silently falls back to the Pillow writer,
    # which cannot produce .mp4. Check first so the on-screen animation
    # below still runs.
    if animation.writers.is_available('ffmpeg'):
        ani.save('sierpinski_triangle.mp4', writer='ffmpeg', fps=0.5)
    else:
        print("ffmpeg is not available; skipping export of sierpinski_triangle.mp4")

# Show the animation
plt.show()