import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Small weighted demo graph for the step-by-step Prim animation.
# Each entry of `edges` is (u, v, weight).
edges = [
    (0, 1, 2), (0, 2, 3), (0, 3, 1), (1, 3, 4),
    (1, 4, 5), (2, 4, 6), (3, 4, 2),
]
G = nx.Graph()
G.add_edges_from((u, v, {'weight': w}) for u, v, w in edges)

# Step-by-step Prim's algorithm implementation
def prim_mst_steps(graph, start=0):
    """Record every intermediate state of Prim's algorithm on *graph*.

    Starting from ``start``, the tree repeatedly grows by the cheapest edge that
    joins a visited node to an unvisited one.

    Args:
        graph: undirected graph supporting ``len(graph.nodes)`` and ``graph[u]``
            adjacency lookups whose edge data carry a ``'weight'`` key
            (e.g. a ``networkx.Graph``).
        start: node the tree is grown from (node 0 by default, as before).

    Returns:
        A list of ``(visited, edges)`` snapshots. ``visited`` is the set of nodes
        in the tree and ``edges`` is the list of ``(u, v)`` tree edges in the order
        they were added. The first snapshot holds only ``start`` and no edges.
        Each snapshot owns its own copies, so later steps do not alter earlier ones.

    If the graph is disconnected, the algorithm stops once the component that
    contains ``start`` is spanned, instead of looping forever.
    """
    visited = {start}
    tree_edges = []
    steps = [(visited.copy(), tree_edges.copy())]

    total_nodes = len(graph.nodes)
    while len(visited) < total_nodes:
        # Every edge leaving the current tree, in the same order as the original
        # nested loops (this keeps tie-breaking between equal weights unchanged).
        candidates = [
            (u, v, data['weight'])
            for u in visited
            for v, data in graph[u].items()
            if v not in visited
        ]
        if not candidates:
            # Nothing leaves the tree: the remaining nodes are unreachable from start.
            break
        u, v, _ = min(candidates, key=lambda cand: cand[2])
        visited.add(v)
        tree_edges.append((u, v))
        steps.append((visited.copy(), tree_edges.copy()))

    return steps

# Draw one frame of the Prim animation (FuncAnimation callback)
def update_graph(num, steps, pos, ax):
    """Render frame *num* of the Prim animation onto *ax*.

    The whole global graph ``G`` is drawn first, with node labels and edge-weight
    labels. The nodes already in the tree are then repainted light green, and the
    tree edges chosen so far are overlaid in red.

    Args:
        num: frame index into *steps*.
        steps: ``(visited_nodes, tree_edges)`` snapshots, as returned by
            ``prim_mst_steps``.
        pos: node -> position mapping used for every draw call.
        ax: matplotlib axes to draw on; it is cleared first.
    """
    ax.clear()
    tree_nodes, tree_edges = steps[num]

    # Base layer: every node, edge and weight label.
    nx.draw(G, pos, ax=ax, with_labels=True, node_size=700, font_size=24,
            node_color='lightblue')
    weight_labels = nx.get_edge_attributes(G, 'weight')
    nx.draw_networkx_edge_labels(G, pos, ax=ax, edge_labels=weight_labels,
                                 font_size=16)

    # Overlays showing the current state of the algorithm.
    nx.draw_networkx_nodes(G, pos, ax=ax, nodelist=tree_nodes, node_size=700,
                           node_color='lightgreen')
    nx.draw_networkx_edges(G, pos, ax=ax, edgelist=tree_edges, width=2,
                           edge_color='red')

    ax.set_title(f"Step {num + 1}")

# Compute steps for Prim's algorithm
steps = prim_mst_steps(G)

# Figure, plus a layout computed once and shared by every frame.
fig, ax = plt.subplots(figsize=(6, 6))
pos = nx.spectral_layout(G)

# One frame per recorded step, one second apart, played once.
ani = animation.FuncAnimation(
    fig,
    update_graph,
    frames=len(steps),
    fargs=(steps, pos, ax),
    interval=1000,
    repeat=False,
)
update_graph(0, steps, pos, ax)  # Draw frame 0 directly so the initial state is shown

# Set to True to export the animation as MP4. The export needs ffmpeg; when it is
# missing the export is skipped rather than raising from animation.writers['ffmpeg'].
SAVE_PRIM_ANIMATION = False
if SAVE_PRIM_ANIMATION:
    if animation.writers.is_available('ffmpeg'):
        Writer = animation.writers['ffmpeg']
        writer = Writer(fps=1, metadata=dict(artist='Samuel Vaiter'), bitrate=1800)
        ani.save('prim-algorithm.mp4', writer=writer)
    else:
        print("ffmpeg is not available; skipping export of prim-algorithm.mp4")

# Display the animation
plt.show()

### -------------------------------------------------------------------

import random
import numpy as np

# Build a random weighted grid graph (the skeleton from which the maze is carved)
def create_grid_graph(size, rng=None):
    """Build a ``size`` x ``size`` 4-connected grid graph with random integer weights.

    Nodes are ``(i, j)`` tuples with ``0 <= i, j < size``. Each node is joined to
    ``(i + 1, j)`` and ``(i, j + 1)`` when those exist. Each edge gets a ``'weight'``
    drawn with ``randint(1, 10)`` (both ends inclusive).

    Args:
        size: number of rows and columns of the grid.
        rng: optional object with a ``randint`` method (e.g. ``random.Random(seed)``)
            used to draw the weights, for reproducible mazes. Defaults to the
            module-level ``random`` generator, as before.

    Returns:
        The weighted ``networkx.Graph``. Every grid cell is a node, even for
        ``size == 1``, where the grid has no edges.
    """
    if rng is None:
        rng = random

    grid = nx.Graph()
    for i in range(size):
        for j in range(size):
            # Add the cell explicitly so a 1x1 grid keeps its single node. For larger
            # grids the node is already present or is added at the same position as
            # before, so the node order does not change.
            grid.add_node((i, j))
            # Weights are drawn in the original order (down, then right), so the
            # shared `random` stream is consumed exactly as before.
            if i < size - 1:
                grid.add_edge((i, j), (i + 1, j), weight=rng.randint(1, 10))
            if j < size - 1:
                grid.add_edge((i, j), (i, j + 1), weight=rng.randint(1, 10))

    return grid

# Prim's algorithm to generate the maze
def prim_mst(graph):
    """Return the minimum spanning tree of *graph*, computed with Prim's algorithm.

    This is a thin wrapper around ``networkx.minimum_spanning_tree`` with
    ``algorithm='prim'``. The result is a new graph holding the tree, or a
    spanning forest if *graph* is disconnected.
    """
    return nx.minimum_spanning_tree(graph, algorithm='prim')

# Generate the maze structure from the MST
def generate_maze(graph, size):
    """Carve a maze out of a solid block using the MST of a ``size`` x ``size`` grid.

    The maze is a ``(2*size + 1) x (2*size + 1)`` integer array in which 1 is wall
    and 0 is open floor. Grid cell ``(x, y)`` maps to array position
    ``(2x + 1, 2y + 1)``. The wall between two adjacent cells sits at the midpoint
    of their positions, ``(ux + vx + 1, uy + vy + 1)``.

    Args:
        graph: grid graph whose nodes are ``(x, y)`` tuples with
            ``0 <= x, y < size`` (as built by ``create_grid_graph``), weighted by
            the ``'weight'`` edge attribute.
        size: side length of the grid.

    Returns:
        The maze as a 2-D numpy integer array.
    """
    mst = prim_mst(graph)
    maze = np.ones((2 * size + 1, 2 * size + 1), dtype=int)

    # Open every cell of the tree. This includes cells with no tree edge (e.g. a
    # 1x1 grid), which the edge loop alone would leave walled in.
    for x, y in mst.nodes:
        maze[2 * x + 1, 2 * y + 1] = 0

    # Knock down the wall between each pair of cells joined by a tree edge.
    for (ux, uy), (vx, vy) in mst.edges:
        maze[ux + vx + 1, uy + vy + 1] = 0

    return maze

# Maze parameters: a grid_size x grid_size grid of cells with random edge weights.
grid_size = 10
random_maze_graph = create_grid_graph(size=grid_size)
maze = generate_maze(graph=random_maze_graph, size=grid_size)  # final maze as a 2-D array

# Generate the cumulative steps of the maze construction using Prim's algorithm
def generate_cumulative_maze_steps(graph, size):
    """Return one maze snapshot per tree edge, in the order Prim's algorithm adds them.

    The maze layout matches ``generate_maze``: a ``(2*size + 1) x (2*size + 1)``
    integer array where 1 is wall and 0 is open floor. Grid cell ``(x, y)`` maps to
    ``(2x + 1, 2y + 1)``. For each edge added, both endpoint cells and the wall
    between them are opened, and a copy of the maze is appended. The last snapshot
    is therefore the finished maze.

    Args:
        graph: grid graph whose nodes are ``(x, y)`` tuples with
            ``0 <= x, y < size``, weighted by the ``'weight'`` edge attribute.
        size: side length of the grid.

    Returns:
        A list of 2-D numpy arrays, one per spanning-tree edge. Each snapshot is an
        independent copy.
    """
    maze = np.ones((2 * size + 1, 2 * size + 1), dtype=int)
    steps = []

    # minimum_spanning_edges yields the tree edges lazily, in the order Prim's
    # algorithm adds them. Iterating prim_mst(graph).edges would list the same edges
    # in node-insertion order instead, which does not show how the tree grows.
    tree_edges = nx.minimum_spanning_edges(graph, algorithm='prim', data=False)
    for (ux, uy), (vx, vy) in tree_edges:
        maze[2 * ux + 1, 2 * uy + 1] = 0
        maze[2 * vx + 1, 2 * vy + 1] = 0
        maze[ux + vx + 1, uy + vy + 1] = 0
        steps.append(maze.copy())

    return steps

# Maze snapshots for the construction animation, one per carved tree edge.
cumulative_maze_steps = generate_cumulative_maze_steps(graph=random_maze_graph, size=grid_size)

# Draw one frame of the maze-construction animation (FuncAnimation callback)
def update_cumulative_maze(num, steps, ax):
    """Render maze snapshot ``steps[num]`` onto *ax* as a FuncAnimation frame.

    The axes are cleared, and the snapshot is shown as an image with the
    ``'binary'`` colormap and interpolation disabled, so each array cell stays a
    sharp square. Both axes' ticks are removed.
    """
    ax.clear()
    snapshot = steps[num]
    ax.imshow(snapshot, interpolation="none", cmap="binary")
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])

# Setup for animation
fig, ax = plt.subplots(figsize=(8, 8))

# One frame per carved tree edge, 30 ms apart on screen, played once.
ani = animation.FuncAnimation(
    fig,
    update_cumulative_maze,
    frames=len(cumulative_maze_steps),
    fargs=(cumulative_maze_steps, ax),
    interval=30,
    repeat=False,
)

# Export the animation as MP4 at 10 fps. animation.writers['ffmpeg'] raises
# RuntimeError when ffmpeg is not installed, which would abort the script before
# the animation is shown, so check availability first and skip the export instead.
SAVE_MAZE_ANIMATION = True
if SAVE_MAZE_ANIMATION:
    if animation.writers.is_available('ffmpeg'):
        Writer = animation.writers['ffmpeg']
        writer = Writer(fps=10, metadata=dict(artist='Samuel Vaiter'), bitrate=1800)
        ani.save('prim-algorithm-maze.mp4', writer=writer)
    else:
        print("ffmpeg is not available; skipping export of prim-algorithm-maze.mp4")

# Display the animation
plt.show()