import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# ===========================================
# Helper functions
# ===========================================
def soft_threshold(z, gamma):
    """Apply the soft-thresholding operator sign(z) * max(|z| - gamma, 0).

    For gamma >= 0 every entry of ``z`` is shrunk toward zero by ``gamma``;
    entries with ``|z| <= gamma`` become zero.
    """
    shrunk = np.maximum(np.abs(z) - gamma, 0.)
    return shrunk * np.sign(z)

def compute_correlation(X, r):
    """Return the inner products of every column of ``X`` with ``r``.

    With ``r`` a residual vector this is ``X^T r``, i.e. the (unnormalized)
    correlation of each feature with the residual.
    """
    X_transposed = X.T
    return X_transposed @ r

# ===========================================
# LARS Implementation (for Lasso)
# ===========================================
def lars_lasso_path(X, y, max_steps=None, eps=1e-12):
    """
    Compute the Lasso solution path with LARS (Lasso modification).

    X and y are centered internally (no scaling), so the coefficients refer
    to the centered problem and no intercept is returned.  At every
    breakpoint the estimate solves

        min_b 0.5 * ||yc - Xc b||^2 + lambda * ||b||_1

    with lambda equal to the maximal absolute correlation of the residual.

    Parameters
    ----------
    X : array_like, shape (n, p)
        Design matrix.
    y : array_like, shape (n,)
        Response vector.
    max_steps : int, optional
        Cap on the number of LARS steps (default ``10 * p``).
    eps : float, optional
        Absolute tolerance for a vanishing correlation and for positive step
        lengths; correlation ties are detected with ``eps * max(1, C)``.

    Returns
    -------
    betas_path : ndarray, shape (p, k + 1)
        Column 0 is the zero vector; column ``i`` is the estimate after
        step ``i`` (``k`` = number of steps taken).
    alphas : ndarray, shape (k,)
        Maximal absolute correlation C (the Lasso lambda) at the start of
        each step, so column ``i`` of ``betas_path`` corresponds to
        ``alphas[i]`` for ``i < k``.
    active_sets : list of list of int
        ``active_sets[0]`` is empty; ``active_sets[i]`` is the active set
        used during step ``i``.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    n, p = X.shape
    if max_steps is None:
        max_steps = p * 10  # generous cap; each step adds or drops variables

    # Center X and y so the intercept drops out (no scaling is applied).
    Xc = X - X.mean(axis=0)
    yc = y - y.mean()

    beta = np.zeros(p)
    mu = np.zeros(n)                 # current fitted values, Xc @ beta
    active_set = []
    is_active = np.zeros(p, dtype=bool)
    alphas = []
    betas_path = [beta.copy()]
    active_sets = [list(active_set)]

    # Correlations of every feature with the residual (same as
    # compute_correlation(Xc, r); inlined to keep this function self-contained).
    c = Xc.T @ (yc - mu)
    entering = None                  # variable whose entry ended the last step
    dropped = False                  # whether the last step removed a variable

    for _ in range(max_steps):
        C = np.max(np.abs(c), initial=0.0)
        if C < eps:
            break  # residual is (numerically) orthogonal to every feature
        tol = eps * max(1.0, C)

        # Grow the active set with every variable tied at C.  Skip this right
        # after a drop: the dropped variable is still tied and must not be
        # re-added immediately.
        if not dropped:
            tied = ~is_active & (np.abs(c) >= C - tol)
            joining = set(np.flatnonzero(tied).tolist())
            if entering is not None:
                joining.add(entering)  # robust to rounding in the tie test
            for j in sorted(joining):
                if not is_active[j]:
                    is_active[j] = True
                    active_set.append(j)

        A = list(active_set)
        if not A:
            break
        XA = Xc[:, A]
        signs = np.sign(c[A])

        # Equiangular direction: solve (XA^T XA) w = signs.  lstsq gives the
        # minimum-norm solution, which also copes with a singular Gram matrix
        # (e.g. duplicated features).  It returns a tuple; keep the solution.
        w = np.linalg.lstsq(XA.T @ XA, signs, rcond=None)[0]
        u = XA @ w                   # direction in output space
        a = Xc.T @ u                 # rate of change of each correlation;
                                     # a[A] == signs when the system is consistent

        # Along mu + g*u the active correlations shrink as (C - g) * signs, so
        # g = C is the full least-squares step on A.  Look for earlier events.
        gamma = C
        enter_idx = drop_idx = None

        # 1) An inactive variable catches up: |c_j - g * a_j| = C - g.
        for j in np.flatnonzero(~is_active):
            for num, den in ((C - c[j], 1.0 - a[j]), (C + c[j], 1.0 + a[j])):
                if num > tol and den > eps and num / den < gamma:
                    gamma, enter_idx = num / den, int(j)

        # 2) Lasso modification: an active coefficient would cross zero.
        for jj, j in enumerate(A):
            if w[jj] != 0.0:
                g = -beta[j] / w[jj]
                if eps < g < gamma:
                    gamma, enter_idx, drop_idx = g, None, j

        # Move to the next breakpoint.
        beta[A] += gamma * w
        mu += gamma * u
        c = Xc.T @ (yc - mu)

        if drop_idx is not None:
            beta[drop_idx] = 0.0     # remove rounding residue
            is_active[drop_idx] = False
            active_set.remove(drop_idx)

        betas_path.append(beta.copy())
        alphas.append(C)
        active_sets.append(A)

        dropped = drop_idx is not None
        entering = enter_idx
        if enter_idx is None and drop_idx is None:
            break  # full least-squares step taken: the path is complete

    return np.array(betas_path).T, np.array(alphas), active_sets


# ===========================================
# Toy data whose Lasso solution is not unique
# ===========================================
np.random.seed(0)  # kept for reproducibility; the data below are deterministic

# Two identical feature columns x1 = x2 = (1, 2, 3), and y equal to them.
X = np.repeat(np.arange(1.0, 4.0)[:, np.newaxis], 2, axis=1)
y = np.arange(1.0, 4.0)

# Because x1 == x2, the fit depends only on beta_1 + beta_2: every beta with
# beta_1 + beta_2 = 1, e.g. (0.5, 0.5) or (0.4, 0.6), fits y exactly, so the
# least-squares (lambda = 0) solution is a whole line.  For every
# 0 < lambda < lambda_max the Lasso solutions form a line segment (both
# coefficients sharing one sign), so the solution is never unique there.


# ===========================================
# Run LARS to get the solution path
# ===========================================
betas_path, alphas, active_sets = lars_lasso_path(X, y)

print("Alphas (correlation values at steps):", alphas)
print("Active sets at each step:", active_sets)
print("Beta path:\n", betas_path)


# ===========================================
# Visualize the Set of Solutions
# ===========================================
# Since x1 == x2, the least-squares (lambda = 0) solutions are the whole line
# beta_1 + beta_2 = 1.  LARS reports a single point on it, reached from the
# origin along the equiangular direction.

fig, ax = plt.subplots(figsize=(6,6))
ax.set_title("Lasso Path for Perfectly Correlated Features")
ax.set_xlabel(r'$\beta_1$')
ax.set_ylabel(r'$\beta_2$')

# Plot the entire path (one marker per LARS breakpoint).
ax.plot(betas_path[0,:], betas_path[1,:], marker='o', label='LARS Steps')
ax.grid(True)

# Overlay the lambda = 0 solution set: beta_1 + beta_2 = 1.
# b1 and b2 stay module-level because the animation below reuses them.
b1 = np.linspace(-0.5, 1.5, 100)
b2 = 1 - b1
ax.plot(b1, b2, 'r--', label='Line of minimal solutions (for lambda=0)')

# Build the legend once, after every labelled artist has been added.
ax.legend()

plt.show()


# ===========================================
# Animate the Iterates of LARS
# ===========================================
fig2, ax2 = plt.subplots(figsize=(6,6))
ax2.set(
    xlim=(-0.1, 1.1),
    ylim=(-0.1, 1.1),
    xlabel=r'$\beta_1$',
    ylabel=r'$\beta_2$',
    title="LARS Solution Path Animation",
)
ax2.grid(True)

# Artist updated frame by frame by the animation callbacks; starts empty.
line_sol, = ax2.plot([], [], 'o-', color='blue', lw=2)
# Reference line beta_1 + beta_2 = 1 (b1, b2 come from the static plot above).
ax2.plot(b1, b2, 'r--', label='Infinite Solutions at lambda=0')
ax2.legend()

def init_animation():
    """Blank the animated path before the first frame (FuncAnimation init_func).

    Returns the modified artist in a tuple, as blitting requires.
    """
    line_sol.set_data([], [])
    artists = (line_sol,)
    return artists

def animate(i):
    """Draw LARS iterates 0..i as a connected polyline in the (beta_1, beta_2) plane.

    Returns the modified artist in a tuple, as blitting requires.
    """
    visible = betas_path[:, :i + 1]
    line_sol.set_data(visible[0], visible[1])
    return (line_sol,)

# Keep a reference in `ani` so the animation is not garbage-collected early.
ani = FuncAnimation(
    fig2,
    animate,
    frames=len(betas_path[0]),  # one frame per stored iterate
    init_func=init_animation,
    blit=True,
    interval=1000,
    repeat=False,
)

plt.show()
