import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation
from scipy.spatial import ConvexHull
from scipy.spatial import Delaunay

def make_zonotope(generators):
    """
    Return every signed sum of the rows of a generator matrix.

    ``generators`` is an ``(n, d)`` array holding one generator per row.
    The result is a ``(2**n, d)`` float array: in row ``i`` generator ``j``
    is added when bit ``j`` of ``i`` is set and subtracted otherwise.  The
    zonotope is the convex hull of these points, so its vertices are among
    the returned rows.
    """
    count, dim = generators.shape
    codes = np.arange(1 << count)
    points = np.zeros((codes.size, dim))
    for idx in range(count):
        # +1 where bit ``idx`` of the row number is set, -1 elsewhere.
        signs = np.where((codes >> idx) & 1, 1.0, -1.0)
        points += signs[:, np.newaxis] * generators[idx]
    return points

def animate_zonotope_construction(n):
    """
    Animate a family of 2D zonotopes with a growing number of generators.

    Frame ``k`` (counting from 0) uses ``2*k + 3`` generators spaced evenly
    around a circle of radius ``1 / (k + 1)`` together with the two standard
    unit vectors.  Every generator and its negation is drawn as a red arrow
    from the origin, and the outline of the zonotope they span is drawn on
    top.

    The animation is saved to ``zonotope_construction.mp4`` at 1 fps with
    the ``ffmpeg`` writer and then displayed with ``plt.show()``.

    ``n`` is the number of frames.  Frame ``k`` enumerates ``2**(2*k + 5)``
    candidate points, so the cost grows quickly with ``n``.
    """
    fig, ax = plt.subplots()
    ax.set_aspect('equal')
    ax.set_axis_off()

    # Draw arrows in data coordinates at their true length.
    arrow_style = dict(angles='xy', scale_units='xy', scale=1, color='r')

    def update(frame):
        # 2*frame + 3 directions evenly spaced around the full circle.
        count = 2 * frame + 3
        angles = (2 * np.pi / count) * np.arange(count)
        # Shrink the ring generators by 1 / (frame + 1) as their number
        # grows, to keep the zonotope's size comparable between frames.
        radius = 1.0 / (frame + 1)
        ring = radius * np.column_stack([np.cos(angles), np.sin(angles)])
        # The two standard unit vectors are always part of the generator set.
        generators = np.vstack([ring, np.eye(2)])

        # ax.clear() also resets the aspect ratio and axis visibility.
        ax.clear()
        ax.set_aspect('equal')
        ax.set_axis_off()

        # One quiver call per arrow keeps matplotlib's default arrow width
        # independent of how many generators the frame has.
        for gen in generators:
            ax.quiver(0, 0, gen[0], gen[1], **arrow_style)
            ax.quiver(0, 0, -gen[0], -gen[1], **arrow_style)

        # Enumerate the candidate points once; they feed both the outline
        # and the axis limits.
        vertices = make_zonotope(generators)

        # In 2D, ConvexHull.vertices is ordered counterclockwise, so
        # repeating the first index closes the outline.
        hull = ConvexHull(vertices)
        outline = np.append(hull.vertices, hull.vertices[0])
        ax.plot(vertices[outline, 0], vertices[outline, 1],
                color=plt.cm.viridis(frame / len(generators)))

        # Pad the bounding box by 10% of the overall coordinate range.
        margin = 0.1 * (vertices.max() - vertices.min())
        ax.set_xlim(vertices[:, 0].min() - margin,
                    vertices[:, 0].max() + margin)
        ax.set_ylim(vertices[:, 1].min() - margin,
                    vertices[:, 1].max() + margin)

    anim = FuncAnimation(fig, update, frames=n,
                         interval=1000, repeat=False)
    # Write a 1 fps video before opening the interactive window.
    anim.save('zonotope_construction.mp4', writer='ffmpeg', fps=1)
    plt.show()

if __name__ == "__main__":
    # Script entry point: render and save a seven-frame animation.
    animate_zonotope_construction(n=7)