import matplotlib.pyplot as plt
import numpy as np

def _gaussian_bump(grid, center, spread, height):
    """Return ``height * exp(-(grid - center)**2 / spread)`` evaluated on *grid*."""
    return height * np.exp(-((grid - center) ** 2) / spread)


# Discretise [0, 1] into n evenly spaced grid points.
n = 100
x = np.linspace(0, 1, n)

# Source: a small bump near 0.3 plus a taller, wider bump near 0.7.
alpha = _gaussian_bump(x, 0.3, 0.01, 0.2) + _gaussian_bump(x, 0.7, 0.02, 0.5)
# Target: two equal-height bumps near 0.25 and 0.8.
beta = _gaussian_bump(x, 0.25, 0.015, 0.5) + _gaussian_bump(x, 0.8, 0.01, 0.5)

# Rescale to unit total mass. Each array becomes a discrete probability mass
# function on the grid (its entries sum to 1), not a density integrating to 1.
alpha = alpha / alpha.sum()
beta = beta / beta.sum()

# Vertical gap between y=0 and each distribution's baseline, so the source
# (drawn upward) and the target (drawn mirrored downward) do not touch.
offset = 0.04

plt.figure(figsize=(8, 6))


def _draw_distribution(weights, side, color, legend_label, tag):
    """Draw *weights* above (side=+1) or mirrored below (side=-1) the gap.

    The curve is placed at ``side * (weights + offset)`` and shaded back to
    its baseline ``side * offset``. *tag* is written just right of x=1, a
    little beyond the baseline. *legend_label* is attached to the line; it
    only shows up if a legend is drawn (none is drawn in the visible code).
    """
    curve = side * (weights + offset)
    plt.plot(x, curve, color=color, label=legend_label, linewidth=3)
    plt.fill_between(x, curve, side * offset, color=color, alpha=0.3)
    plt.text(1.02, side * (offset + 0.02), tag, color=color, fontsize=15)


# Source alpha on top in red, target beta mirrored underneath in blue.
_draw_distribution(alpha, 1, 'red', r'Source $\alpha$', r'$\alpha$')
_draw_distribution(beta, -1, 'blue', r'Target $\beta$', r'$\beta$')

# Discrete monotone (CDF-matching) transport. Each sampled source point i is
# sent to the target grid point j whose cumulative mass is closest to the
# source's cumulative mass at i.
cdf_alpha = np.cumsum(alpha)  # cumulative mass of the source
cdf_beta = np.cumsum(beta)    # cumulative mass of the target

# Only every 5th source grid point gets an arrow, to keep the figure legible.
sources = np.arange(0, n, 5)
# Row k holds |cdf_beta - cdf_alpha[sources[k]]|. argmin along each row returns
# the first closest target index, exactly as a per-point np.argmin would.
gaps = np.abs(cdf_beta[np.newaxis, :] - cdf_alpha[sources, np.newaxis])
targets = gaps.argmin(axis=1)

for i, j in zip(sources, targets):
    # Arrow from the top of the source curve down to the mirrored target curve.
    plt.annotate(
        '',
        xy=(x[j], -(beta[j] + offset)),
        xytext=(x[i], alpha[i] + offset),
        arrowprops=dict(arrowstyle='->', color='magenta', alpha=0.6, linewidth=2),
    )

# Hide the axes, ticks and frame; only the curves, tags and arrows remain.
plt.axis('off')

# Set to False to preview the figure without writing the image file.
# (Replaces a dead `if True:` guard with a named, discoverable toggle.)
SAVE_FIGURE = True
FIGURE_PATH = "monge-problem.png"

if SAVE_FIGURE:
    # Save before show(): once show() returns, the figure may already be
    # closed, and saving afterwards could produce an empty image.
    plt.savefig(FIGURE_PATH, dpi=300)
plt.show()