import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from skimage.restoration import denoise_tv_chambolle
from skimage import color, data

# Alternative input (disabled): a downsampled, noisy grayscale astronaut photo.
# f = color.rgb2gray(data.astronaut())
# noisy = f + 0.8 * f.std() * np.random.randn(*f.shape)
# f = f[::2, ::2]
# noisy = noisy[::2, ::2]

# Clean reference image: slice 3 of the skimage brain volume, divided by the
# volume-wide maximum. Presumably the intensities are non-negative, so f lies
# in [0, 1], matching the vmin=0 / vmax=1 display range used below.
# The volume is loaded once (it was previously read twice).
brain = data.brain()
f = brain[3, ...] / np.max(brain)

# Additive Gaussian noise with standard deviation 1.4 * std(f). A fixed seed
# makes the rendered video reproducible from run to run.
NOISE_SEED = 0
rng = np.random.default_rng(NOISE_SEED)
noisy = f + 1.4 * f.std() * rng.standard_normal(f.shape)

# Regularization weights swept by the animation, one per frame: 20 values
# spaced logarithmically from 1e-3 to 1.
# (Linear alternative: np.linspace(0.001, 0.3, 20).)
weights = np.logspace(-3, 0, 20)

fig, axes = plt.subplots(1, 3, figsize=(18, 6))
for ax in axes:
    ax.axis('off')

# Panels: original | noisy | TV-denoised. The fixed [0, 1] display range
# keeps brightness comparable across panels and across frames. Noisy values
# outside that range are clipped on screen.
axes[0].imshow(f, cmap=plt.cm.gray, vmin=0, vmax=1)
# axes[0].set_title("Original", fontsize=16)
axes[1].imshow(noisy, cmap=plt.cm.gray, vmin=0, vmax=1)
# axes[1].set_title("Noisy", fontsize=16)
# The third panel starts out showing the noisy image.
# update() replaces its data on each frame.
im = axes[2].imshow(noisy, cmap=plt.cm.gray, vmin=0, vmax=1)
# axes[2].set_title(r"$\lambda$ = 0.01", fontsize=16)
fig.tight_layout()

# Function to update the frame
def update(frame):
    """Render one animation frame.

    Denoises the module-level ``noisy`` image with Chambolle's total-variation
    algorithm at regularization weight ``weights[frame]``. The result is shown
    in the third panel.

    Parameters
    ----------
    frame : int
        Index into ``weights``, supplied by ``FuncAnimation``.

    Returns
    -------
    list
        The artists modified for this frame (only the denoised image).
    """
    lam = weights[frame]
    # Tight stopping tolerance and a high iteration cap, presumably so each
    # frame runs close to convergence.
    denoised = denoise_tv_chambolle(
        noisy,
        weight=lam,
        eps=1e-8,
        max_num_iter=20000,
        channel_axis=None,  # no colour-channel axis in the input
    )
    im.set_data(denoised)
    # axes[2].set_title(rf"$\lambda$ = {lam:.3f}", fontsize=16)
    return [im]

# Build the animation: one frame per entry in `weights`. `interval` (ms) only
# sets the pace of interactive playback. The saved file uses `fps` below.
ani = FuncAnimation(fig, update, frames=len(weights), interval=150)

# True: write the animation to an .mp4 via the ffmpeg writer (the ffmpeg
# executable must be installed). False: display it in a window instead.
SAVE_ANIMATION = True

if SAVE_ANIMATION:
    ani.save('rudin-osher-fatemi.mp4', writer='ffmpeg', fps=4)
else:
    plt.show()
