import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Generate a reproducible random vector x in R^10 with entries in [0, 2).
np.random.seed(42)  # Fixed seed for reproducibility
x = np.random.rand(10) * 2  # rand() samples [0, 1), so entries lie in [0, 2)

# Temperatures epsilon, log-spaced and decreasing from 10**0 = 1 to 10**-2 = 0.01.
epsilons = np.logspace(0, -2, 100)  # 100 values from 1 down to 0.01

def softmax_epsilon(x, epsilon):
    """Return the temperature-scaled softmax of ``x``.

    Computes ``exp(x / epsilon) / sum(exp(x / epsilon))`` over all entries.
    For positive ``epsilon``, smaller values sharpen the distribution toward
    the largest entry of ``x`` and larger values flatten it toward uniform.

    The maximum of ``x / epsilon`` is subtracted before exponentiating. This
    leaves the result mathematically unchanged but keeps ``np.exp`` from
    overflowing to ``inf`` when ``x / epsilon`` is large, which would make
    the ratio ``nan``. An example is ``x`` near 2 with ``epsilon = 1e-3``.

    Args:
        x: Array-like of real scores. Normalisation runs over all entries;
            no particular axis is reduced.
        epsilon: Temperature, a non-zero scalar (positive in typical use).

    Returns:
        NumPy array with the same shape as ``x`` whose entries sum to 1.
        An empty input yields an empty array.
    """
    z = np.asarray(x) / epsilon
    if z.size == 0:
        return z
    # Shift so the largest exponent is 0: exp() of the shifted values lies in [0, 1].
    z = z - np.max(z)
    exp_z = np.exp(z)
    return exp_z / np.sum(exp_z)

# Create figure and axis
fig, ax = plt.subplots(figsize=(8, 6))

# 1-based component indices, shared as x-coordinates by every plot below.
positions = range(1, len(x) + 1)

# Bars start at the softmax for epsilons[0] (the largest temperature);
# update() rescales them on every animation frame. Raw string so the LaTeX
# backslashes reach mathtext verbatim instead of being parsed as escapes.
bars = ax.bar(positions, softmax_epsilon(x, epsilons[0]), color='blue', alpha=0.7,
              label=r'$\sigma_{\varepsilon}(x)$')
# Dashed black line with markers shows the raw input vector x for comparison.
line, = ax.plot(positions, x, 'k--o', label='$x$', markersize=10)
# Highlight the maximum entry of x in red: the softmax concentrates on it as epsilon -> 0.
x_max_idx = np.argmax(x)
ax.plot(x_max_idx + 1, x[x_max_idx], 'ro', markersize=15)
# Reference line at y = 1, the upper bound of any softmax component.
ax.axhline(y=1, color='gray', linestyle='--', alpha=0.7)

# Axis label and limits; the title is set per frame in update().
ax.set_xlabel('$i$', fontsize=14)
ax.set_ylim(0, 2)
ax.legend()

def update(frame):
    """Redraw the bar chart for animation frame ``frame``.

    Looks up ``epsilons[frame]``. It then recomputes the softmax of ``x`` at
    that temperature, resets each bar's height to the new value, and shows
    the current epsilon in the axes title.

    Args:
        frame: Integer frame index passed in by ``FuncAnimation``; used to
            index ``epsilons``.

    Returns:
        The bar container whose rectangles were modified. matplotlib only
        consumes this when the animation is created with ``blit=True``.
    """
    epsilon = epsilons[frame]
    softmax_vals = softmax_epsilon(x, epsilon)
    for bar, val in zip(bars, softmax_vals):
        bar.set_height(val)
    # Raw f-string so the LaTeX backslash is not treated as a Python escape;
    # \varepsilon matches the glyph used in the bars' legend label.
    ax.set_title(rf'$\varepsilon$={epsilon:.3f}', fontsize=16)
    return bars

# Build the animation: one frame per epsilon value, redrawn every 50 ms on
# screen, with update() doing the per-frame work.
ani = animation.FuncAnimation(
    fig,
    update,
    frames=len(epsilons),
    interval=50,
    blit=False,
)

# Encode the frames to MP4 at 30 fps using matplotlib's ffmpeg writer
# (needs the ffmpeg executable to be available).
ani.save('softmax-temperature.mp4', writer='ffmpeg', fps=30)

# Finally, open the interactive window.
plt.show()