import numpy as np
import matplotlib.pyplot as plt
import ot
# necessary for 3d plot even if not used
from mpl_toolkits.mplot3d import Axes3D  # noqa
from matplotlib.collections import PolyCollection
from matplotlib.animation import FuncAnimation

n = 100  # number of histogram bins

# Bin positions 0, 1, ..., n-1 as float64.
x = np.arange(n, dtype=np.float64)

# Two 1-D Gaussian histograms over the n bins (m = mean, s = std, in bins).
a1 = ot.datasets.make_1D_gauss(n, m=20, s=5)
a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)

# One distribution per column: A has shape (n, 2).
A = np.array([a1, a2]).T
n_distributions = A.shape[1]

# Squared-distance ground cost between bins, rescaled so its maximum is 1.
M = ot.utils.dist0(n)
M = M / M.max()

# Top panel: the two input distributions; bottom panel: the barycenter.
f, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True, num=1)
for column, colour in zip(A.T, ("orange", "purple")):
    ax1.plot(x, column, color=colour)
ax1.set_title('Distributions')

# Initial interpolation weight: alpha = 0 puts all mass on the first histogram.
alpha = 0.0
weights = np.array([1.0 - alpha, alpha])

# Regularization passed to ot.bregman.barycenter for the initial barycenter.
reg = 1e-3
bary_wass = ot.bregman.barycenter(A, M, reg, weights)

# Keep a handle on the line so the animation callback can update it in place.
line, = ax2.plot(x, bary_wass, 'g', label='Wasserstein')
ax2.set_title('Barycenters')

def update(frame):
    """Recompute and redraw the barycenter line for one animation frame.

    Parameters
    ----------
    frame : float
        Step value supplied by ``FuncAnimation``. ``frame / 100`` is used as
        the weight of the second distribution, so frames in [0, 100] sweep
        the barycenter from the first distribution to the second.

    Returns
    -------
    tuple
        The artists modified in this frame (required when ``blit=True``).
    """
    # Local import keeps this block self-contained; matplotlib is already a
    # dependency of this script.
    from matplotlib.colors import to_rgb

    alpha = frame / 100
    weights = np.array([1 - alpha, alpha])
    bary_wass = ot.bregman.barycenter(A, M, reg, weights)

    # Blend linearly between the exact named colors used for the two input
    # distributions in ax1, so the line starts at the first distribution's
    # color and ends at the second's. Matplotlib's "orange" is
    # (1.0, 0.647, 0.0), not (1, 0.5, 0).
    start = np.asarray(to_rgb("orange"))
    end = np.asarray(to_rgb("purple"))
    color = tuple((1 - alpha) * start + alpha * end)

    line.set_color(color)
    line.set_ydata(bary_wass)
    return line,

# Drive `update` over 50 evenly spaced steps from frame 0 to frame 100
# (alpha = 0 to alpha = 1). Keep a reference so the animation is not
# garbage-collected.
ani = FuncAnimation(f, update, frames=np.linspace(0, 100, num=50), blit=True, repeat=True)

# Hide ticks on both panels; only the shapes of the histograms matter here.
ax1.set_xticks([])
ax1.set_yticks([])
ax2.set_xticks([])
ax2.set_yticks([])

# Save the animation as an mp4 file (requires ffmpeg on PATH).
# This must happen BEFORE plt.show(): show() blocks until the window is
# closed, and once the figure's GUI is torn down Animation.save can no
# longer render the frames.
ani.save('wasserstein_barycenter.mp4', writer='ffmpeg', fps=10)
plt.show()