import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.animation as animation

def histogram_equalization(image_np):
    """Return a histogram-equalized uint8 copy of a grayscale image.

    The normalized cumulative histogram of the input (256 unit-wide bins over
    [0, 256)) is scaled so its maximum is 255 and used as a lookup curve;
    every pixel is mapped through that curve by linear interpolation and the
    result is truncated to ``uint8``. The output has the same shape as
    ``image_np``.
    """
    pixels = image_np.ravel()

    # Normalized histogram and the left/right edges of its 256 bins
    density, edges = np.histogram(pixels, bins=256, range=[0, 256], density=True)

    # Cumulative distribution, rescaled so that its largest value is 255
    cumulative = np.cumsum(density)
    lookup = cumulative * 255 / cumulative.max()

    # Map every intensity through the curve sampled at the left bin edges
    left_edges = edges[:-1]
    mapped = np.interp(pixels, left_edges, lookup)
    return mapped.reshape(image_np.shape).astype(np.uint8)

def create_animation(reference_image, equalized_image, num_frames=30):
    """Animate a linear blend from a reference image to its equalized version.

    The left panel shows the blended image and the right panel shows its
    256-bin histogram as a bar plot. Pixels are blended rank by rank: the
    k-th darkest reference pixel moves linearly toward the k-th darkest
    equalized value. The first frame is the reference image and the last
    frame is exactly the equalized image. The call blocks in ``plt.show()``
    until the window is closed.

    Args:
        reference_image: Grayscale image array. Intensities are displayed and
            binned on a fixed 0..255 scale.
        equalized_image: Array with the same number of pixels as
            ``reference_image``.
        num_frames: Number of animation frames (at least 1).

    Returns:
        The ``FuncAnimation`` object, so callers can keep a reference to it
        or save it after the window is closed.

    Raises:
        ValueError: If ``num_frames`` is less than 1 or the two images do not
            contain the same number of pixels.
    """
    if num_frames < 1:
        raise ValueError("num_frames must be at least 1")
    if reference_image.size != equalized_image.size:
        raise ValueError(
            "reference_image and equalized_image must have the same number of pixels"
        )

    def intensity_counts(values):
        """Count pixels per unit-wide intensity bin over [0, 256)."""
        return np.histogram(values, bins=256, range=[0, 256])[0]

    # Create a figure and axes for both the image and the bar plot
    fig, (ax_img, ax_hist) = plt.subplots(1, 2, figsize=(12, 5))

    # Rank the reference pixels once; the ordering never changes between
    # frames, so there is no reason to re-sort inside the update callback.
    ref_flat = reference_image.flatten()
    order = np.argsort(ref_flat)
    ref_sorted = ref_flat[order]
    eq_sorted = np.sort(equalized_image.flatten())
    # Working buffer; keeps the reference dtype, so blended values are cast
    # to that dtype on assignment.
    image_anim = ref_flat.copy()

    # Pin the colour scale to 0..255. Without explicit limits imshow
    # autoscales to the reference image's own min/max, which would already
    # stretch a low-contrast reference and saturate the later frames.
    # The artist is not marked animated because blitting is disabled below
    # and animated artists are skipped by regular figure draws.
    img_display = ax_img.imshow(reference_image, cmap='gray', vmin=0, vmax=255)
    bars = ax_hist.bar(np.arange(256), np.zeros(256), color='gray')

    # Start with room for the taller of the two endpoint histograms
    peak = max(
        intensity_counts(reference_image).max(),
        intensity_counts(equalized_image).max(),
    )
    ax_hist.set_xlim(0, 255)
    ax_hist.set_ylim(0, peak)
    ax_hist.set_title('Histogram (Bar Plot)')

    # Divide by the index of the last frame so alpha runs from 0 to exactly 1
    # (guarded so a single-frame animation does not divide by zero).
    last_step = max(num_frames - 1, 1)

    def update_frame(frame):
        # Compute the interpolated image for the current frame
        alpha = frame / last_step
        image_anim[order] = (1 - alpha) * ref_sorted + alpha * eq_sorted

        # Reshape back to the original image shape and update the display
        img_display.set_array(image_anim.reshape(reference_image.shape))
        ax_img.set_title(f'Interpolation: Frame {frame + 1}/{num_frames}')

        # Update the bar plot with the new histogram values
        hist_vals = intensity_counts(image_anim)
        for bar, h in zip(bars, hist_vals):
            bar.set_height(h)

        # Intermediate frames can pile more pixels into one bin than either
        # endpoint; grow the axis rather than clipping the bars.
        if hist_vals.max() > ax_hist.get_ylim()[1]:
            ax_hist.set_ylim(0, hist_vals.max())

        return [img_display, *bars]

    # blit=False: the per-frame title lives outside the axes area, which is
    # the only region repainted when blitting, so it would not update.
    ani = animation.FuncAnimation(
        fig, update_frame, frames=num_frames, interval=100, blit=False
    )

    # Display the animation
    plt.tight_layout()
    plt.show()
    return ani

# Run the demo only when executed as a script, so importing this module for
# its functions does not read a file or open a blocking window.
if __name__ == "__main__":
    # Load a reference grayscale image; the context manager closes the
    # underlying file as soon as the pixel data has been copied into NumPy.
    with Image.open('data/boat-original.png') as source_file:
        reference_image = np.array(source_file.convert('L'))

    # Perform histogram equalization on the reference image
    equalized_image = histogram_equalization(reference_image)

    # Create and display the animation
    create_animation(reference_image, equalized_image)
