import warnings

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FFMpegWriter, FuncAnimation

# Shared sample grid: 500 evenly spaced points from 0 to 1, both ends included.
x = np.linspace(0.0, 1.0, num=500)

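# Smoothstep variants: each maps 0 -> 0 and 1 -> 1, and is written with
# operations JAX can trace so that jax.grad can differentiate it.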
def jax_cubic_polynomial_smoothstep(x):
    """Classic cubic smoothstep 3x^2 - 2x^3, evaluated as x^2 * (3 - 2x).

    Works elementwise on scalars or arrays; the input is not clamped.
    """
    squared = x * x
    return squared * (3.0 - 2.0 * x)

def jax_quartic_polynomial_smoothstep(x):
    """Quartic smoothstep 2x^2 - x^4, evaluated as x^2 * (2 - x^2).

    Works elementwise on scalars or arrays; the input is not clamped.
    """
    squared = x * x
    return squared * (2.0 - squared)

def jax_quintic_polynomial_smoothstep(x):
    """Quintic smootherstep 6x^5 - 15x^4 + 10x^3.

    The polynomial is evaluated as x^3 times a Horner-form quadratic.
    Works elementwise on scalars or arrays; the input is not clamped.
    """
    cubed = x * x * x
    horner = (x * 6.0 - 15.0) * x + 10.0
    return cubed * horner

def jax_quadratic_rational_smoothstep(x):
    """Quadratic rational smoothstep x^2 / (2x^2 - 2x + 1).

    The denominator has no real roots, so the ratio is defined for every x.
    Works elementwise on scalars or arrays; the input is not clamped.
    """
    numerator = x * x
    denominator = 2.0 * x * x - 2.0 * x + 1.0
    return numerator / denominator

def jax_cubic_rational_smoothstep(x):
    """Cubic rational smoothstep x^3 / (3x^2 - 3x + 1).

    The denominator has no real roots, so the ratio is defined for every x.
    Works elementwise on scalars or arrays; the input is not clamped.
    """
    numerator = x * x * x
    denominator = 3.0 * x * x - 3.0 * x + 1.0
    return numerator / denominator

def jax_trigonometric_smoothstep(x):
    """Half-cosine smoothstep (1 - cos(pi * x)) / 2, returned as a JAX array.

    Uses jnp.cos, so NumPy inputs come back as JAX arrays. With JAX's default
    settings the result is float32.
    """
    phase = x * jnp.pi
    return 0.5 - 0.5 * jnp.cos(phase)

# One row per animation frame: (smoothstep callable, display name, LaTeX formula).
# Both registries below are derived from this single table. A curve therefore
# cannot be added to the animation without also getting a title formula; the
# frame update looks formulas up by function object.
_SMOOTHSTEP_TABLE = [
    (jax_cubic_polynomial_smoothstep, "Cubic Polynomial",
     r"$f(x)=x^2(3-2x)$"),
    (jax_quartic_polynomial_smoothstep, "Quartic Polynomial",
     r"$f(x)=x^2(2-x^2)$"),
    (jax_quintic_polynomial_smoothstep, "Quintic Polynomial",
     r"$f(x)=x^3(6x^2-15x+10)$"),
    (jax_quadratic_rational_smoothstep, "Quadratic Rational",
     r"$f(x)=\frac{x^2}{2x^2-2x+1}$"),
    (jax_cubic_rational_smoothstep, "Cubic Rational",
     r"$f(x)=\frac{x^3}{3x^2-3x+1}$"),
    (jax_trigonometric_smoothstep, "Trigonometric",
     r"$f(x)=\frac{1}{2}-\frac{1}{2}\cos(\pi x)$"),
]

# (function, display name) pairs, indexed by animation frame number.
jax_functions = [(fn, name) for fn, name, _ in _SMOOTHSTEP_TABLE]

# LaTeX formula for each smoothstep function, keyed by the function object.
equations = {fn: formula for fn, _, formula in _SMOOTHSTEP_TABLE}

# Set up the plot with two subplots: curve and inverse on the left,
# derivatives on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

# Faint dotted y = x diagonal; a curve and its inverse mirror each other across it.
ax1.plot([0, 1], [0, 1], 'k:', linewidth=0.5, alpha=0.5)

# Configure subplot 1: smoothstep and its inverse (data is filled in per frame).
line1, = ax1.plot([], [], linewidth=4, color="#17becf", label="Smoothstep")
line2, = ax1.plot([], [], linewidth=4, color="#9467bd", label="Inverse", linestyle="--")
ax1.legend()
ax1.set_xlim(0, 1)
ax1.set_ylim(0, 1)
ax1.set_title("Smoothstep and Inverse")
ax1.set_xlabel("$x$")

# Configure subplot 2: first and second derivatives (data is filled in per frame).
line3, = ax2.plot([], [], linewidth=4, color="#e377c2", label="First Derivative")
line4, = ax2.plot([], [], linewidth=4, color="#bcbd22", label="Second Derivative", linestyle="--")
ax2.legend()
ax2.set_xlim(0, 1)
# Fixed y-range shared by every frame. The steepest curve is the cubic
# rational. Its second derivative, 6u(1-2x)/(1-3u)^3 with u = x(1-x), peaks
# at about +/-13.5 near x ~ 0.38 and x ~ 0.62, so a +/-10 range would clip it.
ax2.set_ylim(-15, 15)
ax2.set_title("Derivatives of Smoothstep")
ax2.set_xlabel("$x$")

# Update function for animation
def update_with_jax_derivatives(frame):
    """Redraw both panels for animation frame ``frame``.

    Steps:
      * Pick the ``frame``-th (function, name) pair from ``jax_functions``.
      * Plot the curve and its inverse on ``ax1``.
      * Plot its first and second derivatives, obtained with ``jax.grad``,
        on ``ax2``.
      * Retitle both axes.

    Returns the four line artists that were updated.
    """
    fn, name = jax_functions[frame]

    # Left panel: the curve, then its inverse. Swapping the roles of the axes
    # mirrors the curve across y = x. Sorting by the new abscissa keeps the
    # polyline ordered, which presumes the curve is monotone on the grid.
    curve = fn(x)
    line1.set_data(x, curve)
    order = np.argsort(curve)
    line2.set_data(curve[order], x[order])

    # Right panel: jax.grad differentiates the scalar function, nesting it
    # gives the second derivative, and jax.vmap maps each over the sample grid.
    first = jax.grad(fn)
    second = jax.grad(first)
    d1_values, d2_values = (jax.vmap(deriv)(x) for deriv in (first, second))
    line3.set_data(x, d1_values)
    line4.set_data(x, d2_values)

    ax1.set_title(f"{name}: {equations[fn]}")
    ax2.set_title("1st & 2nd derivatives")
    return line1, line2, line3, line4

# Create animation: one frame per smoothstep variant, 2000 ms between frames.
ani_jax = FuncAnimation(
    fig,
    update_with_jax_derivatives,
    frames=len(jax_functions),
    interval=2000,
    blit=False,
)

plt.tight_layout()

# Set to False to skip writing the video file.
SAVE_ANIMATION = True

if SAVE_ANIMATION:
    # save() needs the external ffmpeg binary. Without it, save() raises and
    # the interactive window below never opens, so warn and skip instead.
    if FFMpegWriter.isAvailable():
        ani_jax.save('smoothstep.mp4', writer='ffmpeg', fps=1)
    else:
        warnings.warn("ffmpeg is not available; skipping export of smoothstep.mp4")

plt.show()
