import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# Seed NumPy's global random generator for reproducibility. Nothing in this
# script currently draws random numbers (the starting board is a fixed
# pattern layout), so the seed only matters if randomness is added.
np.random.seed(seed=0)

# Side length, in cells, of the square simulation board.
grid_size = 40

def initialize_grid(grid_size):
    """Return a ``grid_size`` x ``grid_size`` board seeded with classic patterns.

    The board is an integer array (1 = alive, 0 = dead) holding a showcase of
    well-known Game of Life patterns:

    * still lifes (block, beehive, loaf, boat, tub) along the top edge;
    * small oscillators (blinker, toad, beacon) in a band below them;
    * a pulsar and a pentadecathlon below that;
    * spaceships (LWSS, MWSS, HWSS, glider) on the right-hand side.

    The stationary patterns are spaced apart with empty margins so they start
    out isolated from one another. All spaceships travel to the right (the
    glider diagonally down-right), so they will eventually reach the edge of
    the board.

    Args:
        grid_size: Side length of the square board. The layout needs at
            least 37 cells; larger boards leave the extra area empty.

    Returns:
        A new ``(grid_size, grid_size)`` integer array.

    Raises:
        ValueError: If ``grid_size`` is too small to hold the layout.
    """
    # --- Still lifes: unchanged from one generation to the next. ---
    block = np.array([[1, 1],
                      [1, 1]])
    beehive = np.array([[0, 1, 1, 0],
                        [1, 0, 0, 1],
                        [0, 1, 1, 0]])
    loaf = np.array([[0, 1, 1, 0],
                     [1, 0, 0, 1],
                     [0, 1, 0, 1],
                     [0, 0, 1, 0]])
    boat = np.array([[1, 1, 0],
                     [1, 0, 1],
                     [0, 1, 0]])
    tub = np.array([[0, 1, 0],
                    [1, 0, 1],
                    [0, 1, 0]])

    # --- Oscillators. ---
    blinker = np.array([[1, 1, 1]])                  # period 2
    toad = np.array([[0, 1, 1, 1],
                     [1, 1, 1, 0]])                  # period 2
    beacon = np.array([[1, 1, 0, 0],
                       [1, 1, 0, 0],
                       [0, 0, 1, 1],
                       [0, 0, 1, 1]])                # period 2
    # Pulsar (period 3): 3-cell bars on rows and columns 0, 5, 7 and 12 of a
    # 13 x 13 box, which gives the pattern its four-fold symmetry.
    pulsar = np.zeros((13, 13), dtype=int)
    bar_cells = [2, 3, 4, 8, 9, 10]
    for line in (0, 5, 7, 12):
        pulsar[line, bar_cells] = 1
        pulsar[bar_cells, line] = 1
    # Pentadecathlon (period 15) in its 12-cell phase, standing vertically.
    pentadecathlon = np.array([[0, 1, 0],
                               [0, 1, 0],
                               [1, 0, 1],
                               [0, 1, 0],
                               [0, 1, 0],
                               [0, 1, 0],
                               [0, 1, 0],
                               [1, 0, 1],
                               [0, 1, 0],
                               [0, 1, 0]])

    # --- Spaceships. ---
    # Glider: one cell down and one cell right every 4 generations.
    glider = np.array([[0, 1, 0],
                       [0, 0, 1],
                       [1, 1, 1]])
    # Light/middle/heavyweight spaceships, oriented so that each moves two
    # cells to the right every 4 generations.
    lwss = np.array([[0, 1, 1, 1, 1],
                     [1, 0, 0, 0, 1],
                     [0, 0, 0, 0, 1],
                     [1, 0, 0, 1, 0]])
    mwss = np.array([[0, 1, 1, 1, 1, 1],
                     [1, 0, 0, 0, 0, 1],
                     [0, 0, 0, 0, 0, 1],
                     [1, 0, 0, 0, 1, 0],
                     [0, 0, 1, 0, 0, 0]])
    hwss = np.array([[0, 1, 1, 1, 1, 1, 1],
                     [1, 0, 0, 0, 0, 0, 1],
                     [0, 0, 0, 0, 0, 0, 1],
                     [1, 0, 0, 0, 0, 1, 0],
                     [0, 0, 1, 1, 0, 0, 0]])

    # (pattern, top row, left column). Each pattern is written into a slice
    # of exactly its own shape, so NumPy broadcasting can never silently
    # stretch it (a 1x3 blinker assigned to a 3x3 slice becomes a solid
    # square). Oscillators are spaced so that even their largest phases keep
    # at least two empty rows/columns from their neighbours.
    placements = [
        # Still lifes along the top edge.
        (block, 1, 1),
        (beehive, 1, 5),
        (loaf, 1, 11),
        (boat, 1, 17),
        (tub, 1, 22),
        # Small oscillators.
        (blinker, 8, 2),
        (toad, 8, 7),
        (beacon, 7, 13),
        # Large oscillators.
        (pulsar, 14, 2),
        (pentadecathlon, 17, 22),
        # Spaceships on the right-hand side.
        (lwss, 7, 30),
        (mwss, 14, 30),
        (hwss, 22, 30),
        (glider, 31, 30),
    ]

    required = max(max(row + pattern.shape[0], col + pattern.shape[1])
                   for pattern, row, col in placements)
    if grid_size < required:
        raise ValueError(
            f"grid_size must be at least {required} to fit the pattern "
            f"layout, got {grid_size}"
        )

    grid = np.zeros((grid_size, grid_size), dtype=int)
    for pattern, row, col in placements:
        height, width = pattern.shape
        grid[row:row + height, col:col + width] = pattern

    return grid


# Generation 0 of the simulation. animate() rebinds this module-level name
# to each successive generation.
grid = initialize_grid(grid_size=grid_size)

# Function to update the grid based on the Game of Life rules
def update_grid(grid, *, wrap=False):
    """Return the next generation of ``grid`` under Conway's Game of Life.

    A live cell (value 1) survives when it has two or three live neighbours;
    any other cell becomes (or stays) alive only when it has exactly three.
    Every other cell is dead in the result. ``grid`` itself is not modified.

    Args:
        grid: 2-D array of 0/1 cell states.
        wrap: If False (the default), cells outside the board count as
            permanently dead on every side. If True, the board is treated as
            a torus and opposite edges are neighbours.

    Returns:
        A new array with the same shape and dtype as ``grid``.
    """
    grid = np.asarray(grid)
    rows, cols = grid.shape
    offsets = [(dr, dc)
               for dr in (-1, 0, 1)
               for dc in (-1, 0, 1)
               if (dr, dc) != (0, 0)]

    if wrap:
        # np.roll(grid, (-dr, -dc))[i, j] == grid[i + dr, j + dc] (mod size).
        neighbors = sum(np.roll(grid, (-dr, -dc), axis=(0, 1))
                        for dr, dc in offsets)
    else:
        # Surround the board with a ring of dead cells and add up the eight
        # shifted windows. Slicing the unpadded board with i-1 / j-1 is wrong
        # on the first row/column: grid[-1:2] starts from the *last* row and
        # is usually empty, so those cells would see no neighbours at all.
        padded = np.pad(grid, 1)
        neighbors = sum(padded[1 + dr:1 + dr + rows, 1 + dc:1 + dc + cols]
                        for dr, dc in offsets)

    alive = grid == 1
    next_alive = (neighbors == 3) | (alive & (neighbors == 2))
    return next_alive.astype(grid.dtype)

# Figure with a single image artist for the board; the 'binary' colormap
# draws live cells (1) in black on a white background.
fig, ax = plt.subplots()
im = ax.imshow(grid, interpolation='nearest', cmap='binary')
ax.set_axis_off()


# Update function for animation
def animate(frame):
    """FuncAnimation callback: advance the simulation one step and redraw.

    ``frame`` (the frame index FuncAnimation passes in) is not used; the
    state lives in the module-level ``grid``, which is rebound to its
    successor on every call. Returns the list of changed artists, as
    FuncAnimation expects when blitting.
    """
    global grid
    successor = update_grid(grid)
    grid = successor
    im.set_data(successor)
    return [im]

# Whether to also render the animation to an MP4 file (requires ffmpeg).
SAVE_VIDEO = True

# Snapshot of generation 0, taken before any frame has been drawn.
initial_grid = grid.copy()


def reset_animation():
    """Restore the board to generation 0 and redraw it.

    Passed to FuncAnimation as ``init_func``, which it calls before drawing
    the first frame, both when saving and when the window starts. Each run
    therefore begins from the initial layout instead of wherever the previous
    run left the module-level ``grid``.
    """
    global grid
    grid = initial_grid.copy()
    im.set_data(grid)
    return [im]


anim = FuncAnimation(fig, animate, frames=100, init_func=reset_animation,
                     interval=100, blit=True)

if SAVE_VIDEO:
    anim.save('basic-game-of-life.mp4', writer='ffmpeg', fps=10)

plt.show()
