from collections import deque

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import networkx as nx

# A balanced binary tree of height 2 (7 nodes): node 0 is the root, nodes 1-2
# are its children and nodes 3-6 are the leaves. The hand-placed ``pos``
# layout below draws exactly this shape, and on a tree BFS (level by level)
# and DFS (branch by branch) visit the nodes in visibly different orders.
G = nx.balanced_tree(2, 2)

# Traversal orders, filled in by bfs() and dfs() below and replayed by the
# animation.
bfs_order = []
dfs_order = []

# BFS Traversal
def bfs(graph, start, order=None):
    """Breadth-first traversal of ``graph`` starting at ``start``.

    Nodes are recorded in the order they are first reached, level by level;
    neighbours are explored in the order ``graph.neighbors`` yields them.

    Args:
        graph: Any object exposing ``neighbors(node)`` (e.g. a networkx graph).
        start: Node to start the traversal from.
        order: List to append visited nodes to. Defaults to the module-level
            ``bfs_order`` list, which the animation reads.

    Returns:
        The list the visit order was appended to (``order``).
    """
    if order is None:
        order = bfs_order
    # Mark nodes as visited when they are enqueued, so each node enters the
    # queue at most once. This yields the same order as marking on dequeue,
    # but keeps the queue O(V) instead of O(E).
    visited = {start}
    queue = deque([start])
    while queue:
        node = queue.popleft()  # O(1), unlike list.pop(0)
        order.append(node)
        for neighbor in graph.neighbors(node):
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append(neighbor)
    return order

# DFS Traversal
def dfs(graph, start, visited=None, order=None):
    """Depth-first (preorder) traversal of ``graph`` starting at ``start``.

    Produces exactly the order a recursive DFS would: each node's neighbours
    are explored in ``graph.neighbors`` order. It uses an explicit stack, so
    deep graphs cannot exceed Python's recursion limit.

    Args:
        graph: Any object exposing ``neighbors(node)`` (e.g. a networkx graph).
        start: Node to start the traversal from.
        visited: Optional set of nodes to treat as already visited; it is
            updated in place. A fresh set is used when omitted.
        order: List to append visited nodes to. Defaults to the module-level
            ``dfs_order`` list, which the animation reads.

    Returns:
        The list the visit order was appended to (``order``).
    """
    if visited is None:
        visited = set()
    if order is None:
        order = dfs_order
    visited.add(start)
    order.append(start)
    # Each stack entry is the partially consumed neighbour iterator of a node
    # on the current path. Resuming it after a child finishes mirrors
    # returning from a recursive call.
    stack = [iter(graph.neighbors(start))]
    while stack:
        for neighbor in stack[-1]:
            if neighbor not in visited:
                visited.add(neighbor)
                order.append(neighbor)
                stack.append(iter(graph.neighbors(neighbor)))
                break
        else:
            # All neighbours of the node on top of the stack are done: backtrack.
            stack.pop()
    return order

# Fixed drawing coordinates for each node: node 0 sits on the top row (y=3),
# nodes 1-2 on the middle row (y=2) and nodes 3-6 on the bottom row (y=1).
pos = {
    0: (0, 3),
    1: (-1, 2),
    2: (1, 2),
    3: (-1.5, 1),
    4: (-0.5, 1),
    5: (0.5, 1),
    6: (1.5, 1),
}

# Fill the module-level bfs_order / dfs_order lists, both rooted at node 0.
bfs(G, 0)
dfs(G, 0)


# Animation function
def update(frame):
    """Redraw the figure for animation frame ``frame`` (0-based).

    The left panel shows BFS progress and the right panel DFS progress. In
    each, the first ``frame + 1`` nodes of the traversal are highlighted.
    Once a traversal is exhausted, slicing clamps, so its panel keeps showing
    every node of that traversal highlighted.
    """
    plt.clf()
    # Panel titles are intentionally left off.
    _draw_traversal_panel(plt.subplot(121), bfs_order[:frame + 1], "magenta")
    _draw_traversal_panel(plt.subplot(122), dfs_order[:frame + 1], "orange")


def _draw_traversal_panel(ax, visited, color):
    """Draw ``G`` on ``ax``, highlight ``visited`` in ``color`` and list them below the axes."""
    nx.draw(G, pos, ax=ax, with_labels=True, node_color="lightgray", node_size=700, edge_color="gray", font_size=18)
    nx.draw_networkx_nodes(G, pos, ax=ax, nodelist=visited, node_color=color, node_size=700)
    ax.text(0.5, -0.1, f"Visited: {visited}", transform=ax.transAxes, ha="center", fontsize=18)

# Total frames: one per step of the longer traversal.
num_frames = max(len(bfs_order), len(dfs_order))

# Playback rate. The on-screen interval is derived from it so that the live
# preview and the saved video run at the same speed.
FPS = 1
SAVE_VIDEO = True  # set to False to only preview the animation

if __name__ == "__main__":
    # Only build, save and show the animation when run as a script, so that
    # importing this module does not open a window or write a file.
    fig = plt.figure(figsize=(12, 6))
    # Keep a reference to the animation; otherwise it may be garbage-collected.
    ani = animation.FuncAnimation(fig, update, frames=num_frames, interval=1000 // FPS, repeat=False)

    if SAVE_VIDEO:
        # The 'ffmpeg' writer needs the ffmpeg executable to be installed.
        ani.save('bfs-vs-dfs.mp4', writer='ffmpeg', fps=FPS)

    plt.show()
