import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm, expon, uniform, powerlaw

def get_distribution(name, size):
    """Draw random samples from one of a few named distributions.

    Samples come from NumPy's legacy global random state (``np.random``),
    so ``np.random.seed`` controls reproducibility.

    Parameters
    ----------
    name : str
        Which distribution to sample:

        * ``'uniform'``     -- uniform with ``low=-1``, ``high=1``
        * ``'exponential'`` -- exponential with ``scale=1``
        * ``'powerlaw'``    -- ``np.random.power`` with shape ``a=5``
        * ``'normal'``      -- normal with ``loc=0``, ``scale=1``
    size : int or tuple of ints
        Output shape, passed straight through to the NumPy sampler.

    Returns
    -------
    numpy.ndarray
        Samples of shape ``size``, as returned by the NumPy sampler.

    Raises
    ------
    ValueError
        If ``name`` is not one of the supported distribution names.
    """
    # Ordered (name, sampler) pairs, tested with ``==`` in the same order as
    # an if/elif chain would test them.
    samplers = (
        ('uniform', lambda shape: np.random.uniform(low=-1, high=1, size=shape)),
        ('exponential', lambda shape: np.random.exponential(scale=1, size=shape)),
        # Power-function distribution with shape parameter a=5.
        ('powerlaw', lambda shape: np.random.power(a=5, size=shape)),
        ('normal', lambda shape: np.random.normal(loc=0, scale=1, size=shape)),
    )
    for key, draw in samplers:
        if name == key:
            return draw(size)
    raise ValueError(f"Unknown distribution: {name}")

# Compare empirical CDFs of standardized sample means with the standard normal
# CDF. As the number of averaged variables n grows, the curves should approach
# the red dashed reference (central limit theorem; Berry-Esseen bounds the rate).

# Parameters
n_samples = 10000  # Number of independent sample means drawn per experiment
n_experiments = [1, 2, 10, 100]  # Number of i.i.d. variables averaged into each mean
distribution_name = 'exponential'  # Choose a distribution: 'uniform', 'exponential', 'powerlaw', 'normal'
show_title = False  # Set True to put a descriptive title on the figure
save_figure = True  # Write the figure to `output_path`
# NOTE(review): "essen" looks like a typo for "Esseen"; the name is kept so
# anything referencing the existing output file keeps working.
output_path = "berry-essen.png"

# Grid on which all CDFs are evaluated
x = np.linspace(-2, 2, 1000)
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

# Plot CDF of the standard normal distribution for reference
ax.plot(x, norm.cdf(x), 'r--', lw=4, label='Normal CDF')

for n in n_experiments:
    # (n, n_samples) draws: each column is averaged into one sample mean.
    samples = get_distribution(distribution_name, size=(n, n_samples))
    empirical_means = np.mean(samples, axis=0)

    # Standardize the means to zero mean and unit variance, using the
    # sample mean and sample standard deviation of the means themselves.
    normalized_means = (empirical_means - np.mean(empirical_means)) / np.std(empirical_means)

    # Empirical CDF at each grid point = fraction of means <= xi.
    # searchsorted(side='right') on the sorted means returns exactly that count
    # for every grid point at once, instead of rescanning all means per point.
    sorted_means = np.sort(normalized_means)
    empirical_cdf = np.searchsorted(sorted_means, x, side='right') / sorted_means.size

    ax.plot(x, empirical_cdf, lw=2, label=f'n={n}')

# Titles, labels, legend and grid
if show_title:
    ax.set_title(
        f'Empirical CDFs vs Normal CDF for {distribution_name.capitalize()} Distribution\n'
        'Illustrating Berry-Esseen Theorem',
        fontsize=14,
    )
ax.set_xlabel('x', fontsize=12)
ax.set_ylabel('CDF', fontsize=12)
ax.legend()
ax.grid(True)

# Save before plt.show(). With interactive backends show() blocks until the
# window is closed and the figure is closed with it, so saving first
# guarantees the file is written regardless of how the window is dismissed.
if save_figure:
    fig.savefig(output_path)
plt.show()