import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# --- Problem parameters -------------------------------------------------------
epsilon = 0.01           # entropic regularization strength
n_iter = 400             # number of Sinkhorn iterations (one animation frame each)
inner_cloud_size = 30    # number of source points
outer_cloud_size = 20    # number of target points

# --- Point clouds -------------------------------------------------------------
# Source points: uniform samples in the square [-0.5, 0.5]^2.
np.random.seed(0)
inner_cloud = (0.5 * np.random.uniform(-1, 1, (inner_cloud_size, 2))).astype(np.float64)

# Target points: a noisy ring with radius in [0.8, 1.0) at uniform random angles.
theta = np.random.uniform(0, 2 * np.pi, outer_cloud_size)
r = .8 + .2 * np.random.uniform(size=outer_cloud_size)
outer_cloud = np.vstack((r * np.cos(theta), r * np.sin(theta))).T.astype(np.float64)

# --- Squared Euclidean cost: C[i, j] = |y_j|^2 + |x_i|^2 - 2 <x_i, y_j> -------
x2 = np.sum(inner_cloud**2, axis=1)
y2 = np.sum(outer_cloud**2, axis=1)
C = y2[np.newaxis, :] + x2[:, np.newaxis] - 2 * np.dot(inner_cloud, outer_cloud.T)

# --- Uniform empirical marginals (each sums to 1) ----------------------------
a = np.full(inner_cloud_size, 1 / inner_cloud_size)
b = np.full(outer_cloud_size, 1 / outer_cloud_size)

# Sinkhorn function with history tracking
def sinkhorn_with_history(C, a, b, epsilon, n_iter):
    """Run Sinkhorn's matrix-scaling iterations and record every iterate.

    The transport plan implied by scalings ``u``, ``v`` is
    ``P = diag(u) @ K @ diag(v)`` with the Gibbs kernel ``K = exp(-C / epsilon)``.
    Each iteration first rescales ``u`` so the row sums of ``P`` equal ``a``,
    then rescales ``v`` so the column sums equal ``b``.

    Parameters
    ----------
    C : ndarray, shape (len(a), len(b))
        Cost matrix.
    a : ndarray, shape (n,)
        Target row marginal.
    b : ndarray, shape (m,)
        Target column marginal.
    epsilon : float
        Entropic regularization strength. If it is small relative to the
        entries of ``C``, ``exp(-C / epsilon)`` can underflow to 0. Then
        ``K @ v`` or ``K.T @ u`` contains zeros and the scalings become
        inf/nan.
    n_iter : int
        Number of iterations to run.

    Returns
    -------
    K : ndarray, shape (n, m)
        The Gibbs kernel.
    u_history, v_history : list of ndarray
        Copies of the scalings at the *start* of each iteration. Entry 0 is
        all ones. The scalings after the final iteration are not stored.
    err_p : list of float
        For each iteration, ``||u * (K @ v) - a||_1`` (the row-marginal
        violation of ``P``), measured after the ``v`` update.
    err_q : list of float
        For each iteration, ``||v * (K.T @ u) - b||_1`` (the column-marginal
        violation of ``P``), measured after the ``u`` update.
    """
    u = np.ones_like(a)
    v = np.ones_like(b)
    K = np.exp(-C / epsilon)
    u_history, v_history, err_p, err_q = [], [], [], []

    # Each matrix-vector product is needed twice: once for the error metric
    # and once for the next scaling update. Compute it once and reuse it.
    Kv = K @ v
    for _ in range(n_iter):
        u_history.append(u.copy())
        v_history.append(v.copy())

        u = a / Kv
        Ktu = K.T @ u
        err_q.append(np.linalg.norm(v * Ktu - b, 1))

        v = b / Ktu
        Kv = K @ v  # also reused by the next iteration's u update
        err_p.append(np.linalg.norm(u * Kv - a, 1))

    return K, u_history, v_history, err_p, err_q

# Run the Sinkhorn algorithm, keeping every iterate for the animation
K, u_history, v_history, err_p, err_q = sinkhorn_with_history(C, a, b, epsilon, n_iter)

# Figure layout: transport plan on top, marginal-constraint errors below
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 10))

# --- Transport plan panel ---------------------------------------------------
# The Sinkhorn iterations start from u = v = 1, where diag(u) @ K @ diag(v) is
# simply K, so there is no need to build dense diagonal matrices.
initial_P = K.copy()
# aspect='equal' gives square pixels (one pixel per (inner, outer) pair).
im = ax1.imshow(initial_P, animated=True, aspect='equal', cmap='viridis')
ax1.set_xlabel('Outer Cloud Points')
ax1.set_ylabel('Inner Cloud Points')
cbar = fig.colorbar(im, ax=ax1, orientation='vertical')

# --- Constraint-violation panel (log scale) ---------------------------------
# err_p is the row-marginal violation ||P 1 - a||_1 and err_q the
# column-marginal violation ||P^T 1 - b||_1.
line1, = ax2.plot([], [], label=r'$\|P\mathbf{1} - a\|_1$', color='orange')
line2, = ax2.plot([], [], label=r'$\|P^T\mathbf{1} - b\|_1$', color='purple')
ax2.set_yscale('log')
ax2.set_xlabel('Iteration')
ax2.set_ylabel('Constraint Violation')
ax2.set_xlim(0, n_iter)
# a and b each sum to 1, and P's total mass equals 1 whenever one marginal is
# matched, so each L1 error lies in [0, 2]. An upper limit of 1 could clip the
# first iterations.
ax2.set_ylim(1e-8, 2)
ax2.legend()
ax2.grid(True)

# Animation function
def animate(i):
    """Draw animation frame ``i``.

    The frame shows the transport plan built from the scalings recorded at
    the start of iteration ``i``, plus the marginal errors for iterations
    ``0..i``. It reads the module-level ``K``, ``u_history``, ``v_history``,
    ``err_p``, ``err_q`` and updates the artists ``im``, ``line1`` and
    ``line2``.

    Parameters
    ----------
    i : int
        Frame index, expected in ``range(n_iter)``.

    Returns
    -------
    tuple
        The artists modified by this frame.
    """
    u = u_history[i]
    v = v_history[i]
    # diag(u) @ K @ diag(v) via broadcasting: O(nm) instead of two dense matmuls.
    P = u[:, np.newaxis] * K * v[np.newaxis, :]

    im.set_data(P)
    # Frame 0 is the raw kernel K, while later plans have very different
    # magnitudes. Rescale the colour limits each frame so the plan stays
    # visible instead of saturating the colormap fixed at frame 0.
    im.set_clim(P.min(), P.max())

    # The axis limits were fixed in the figure setup (set_xlim/set_ylim turn
    # off autoscaling), so only the line data needs updating.
    iterations = np.arange(i + 1)
    line1.set_data(iterations, err_p[:i + 1])
    line2.set_data(iterations, err_q[:i + 1])

    return im, line1, line2

# Create the animation: one frame per Sinkhorn iteration
ani = animation.FuncAnimation(fig, animate, frames=n_iter, interval=10, blit=False)

SAVE_MP4 = True  # set to False to skip exporting sinkhorn.mp4

# Lay out the figure *before* saving so the exported frames use the same
# layout as the on-screen window.
plt.tight_layout()

if SAVE_MP4:
    if animation.writers.is_available('ffmpeg'):
        ani.save('sinkhorn.mp4', writer='ffmpeg', fps=30)
    else:
        print("ffmpeg is not available; skipping export of 'sinkhorn.mp4'.")

plt.show()
