import numpy as np
import matplotlib.pyplot as plt

def binary_entropy(p):
    """Return the binary Shannon entropy H(p), in bits, elementwise.

    H(p) = -p*log2(p) - (1-p)*log2(1-p), using the convention
    0*log2(0) = 0 so that H(0) = H(1) = 0.

    Args:
        p: Probability (scalar or array-like) of one of the two outcomes.

    Returns:
        A float ndarray with the same shape as ``p`` (0-d for scalar input).
    """
    p = np.asarray(p, dtype=float)
    q = 1.0 - p
    # With `where=` and no `out=`, np.log2 leaves the masked-out slots
    # uninitialized. Pre-filling `out` with zeros makes the p == 0 and
    # q == 0 terms evaluate to exactly 0 instead of 0 * <garbage>.
    log_p = np.log2(p, out=np.zeros_like(p), where=p != 0)
    log_q = np.log2(q, out=np.zeros_like(q), where=q != 0)
    return -(p * log_p + q * log_q)


# Define alpha values from 0 to 1
alpha_values = np.linspace(0, 1, 500)

# Calculate entropy for each alpha
entropy_values = binary_entropy(alpha_values)

# Figure toggles. SHOW_DECORATIONS adds axis labels, title, grid and legend
# (off by default, matching the previously commented-out lines);
# SAVE_FIGURE controls writing the PNG to disk.
SHOW_DECORATIONS = False
SAVE_FIGURE = True

# Plot the entropy as a function of alpha with specified ticks
plt.figure(figsize=(4, 3))
plt.plot(alpha_values, entropy_values, label='Entropy H(X)', color='blue')
plt.axvline(x=0.5, color='black', linestyle='--', label=r'$\alpha = 0.5$')
plt.axhline(y=1, color='black', linestyle='--', label=r'Max Entropy (H(X) = 1)')
plt.plot(0.5, 1, 'go')  # Mark the maximum of binary entropy: H(0.5) = 1 bit

# Set specific ticks
plt.xticks([0.0, 0.5, 1.0])
plt.yticks([0.0, 0.5, 1.0])

if SHOW_DECORATIONS:
    plt.xlabel(r'$\alpha$')
    plt.ylabel('Entropy H(X)')
    plt.title(r'Entropy $H(X)$ as a Function of $\alpha$ for $X = \{0,1\}$')
    plt.grid(True)
plt.tight_layout()
if SHOW_DECORATIONS:
    # The `label=` arguments above only appear when a legend is drawn.
    plt.legend()
if SAVE_FIGURE:
    # Save before show(): after the interactive window closes the current
    # figure may be empty, so saving afterwards could write a blank image.
    plt.savefig('entropy-two-elements.png', dpi=300)
plt.show()
