import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import matplotlib.gridspec as gridspec

# Singular perturbation parameter (small epsilon)
epsilon = 0.1


# Define the system of ODEs
def van_der_pol(t, state, eps=None):
    """Right-hand side of the van der Pol oscillator in singularly perturbed form.

    Encodes ``eps * x'' - (1 - x**2) * x' + x = 0`` as the first-order system::

        dx/dt = y
        dy/dt = ((1 - x**2) * y - x) / eps

    Because ``dy/dt`` is scaled by ``1/eps``, ``y`` is the fast variable when
    ``eps`` is small.

    Args:
        t: Current time. Unused (the system is autonomous) but required by the
            ``solve_ivp`` callback signature ``fun(t, y, *args)``.
        state: Sequence ``[x, y]``.
        eps: Perturbation parameter. Defaults to the module-level ``epsilon``,
            looked up at call time. It can be supplied per solve with
            ``solve_ivp(..., args=(eps,))``.

    Returns:
        list: ``[dx/dt, dy/dt]``.
    """
    if eps is None:
        eps = epsilon
    x, y = state
    dx_dt = y
    # Dividing by eps puts y on the fast time scale.
    dy_dt = ((1 - x**2) * y - x) / eps
    return [dx_dt, dy_dt]

# Initial conditions: [x, y] -- start at x = 3 with y = dx/dt = 0
initial_conditions = [3, 0]

# Time span for the simulation (start time, end time)
tmax = 10
t_span = (0, tmax)  # Simulate over tmax time units
# Points at which the solution is reported; each one becomes an animation frame.
t_eval = np.linspace(0, tmax, 2000)

# Solve the system.
# NOTE: for much smaller epsilon the system becomes stiff, and an implicit
# method such as 'Radau' or 'LSODA' may be needed instead of 'RK45'.
solution = solve_ivp(van_der_pol, t_span, initial_conditions, method='RK45', t_eval=t_eval)
if not solution.success:
    # solve_ivp reports failure through `success`/`message` rather than raising;
    # without this check a truncated trajectory would be plotted silently.
    raise RuntimeError(f"ODE integration failed: {solution.message}")

# Create a figure with gridspec for layout control: the left column holds two
# stacked time-series panels, and the right (twice as wide) holds the phase plot.
fig = plt.figure(figsize=(10, 8))
gs = gridspec.GridSpec(2, 2, width_ratios=[1, 2], height_ratios=[1, 1])


def _add_time_panel(row, series, label, color, title, ylabel):
    """Add a time-series panel in column 0 of ``gs`` and return ``(ax, marker)``.

    The whole ``series`` is drawn against ``solution.t``. ``marker`` is an
    initially empty red point that the animation moves along the curve.
    """
    ax = fig.add_subplot(gs[row, 0])
    ax.plot(solution.t, series, label=label, color=color)
    marker, = ax.plot([], [], 'ro')
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlim(t_span)
    ax.grid(True)
    return ax, marker


# Top-left: x (slow variable) vs time, with its red current-position marker.
ax1, current_x = _add_time_panel(
    0, solution.y[0], 'x (slow variable)', 'blue', 'Slow Dynamics (x vs Time)', 'x'
)

# Bottom-left: y (fast variable) vs time, with its red current-position marker.
# Both left panels span the same time range (t_span), so the time axis is
# labelled once here on the bottom panel.
ax2, current_y = _add_time_panel(
    1, solution.y[1], 'y (fast variable)', 'orange', 'Fast Dynamics (y vs Time)', 'y'
)
ax2.set_xlabel('Time')

# Phase plot: x vs y in the large plot on the right, spanning both rows.
ax3 = fig.add_subplot(gs[:, 1])
ax3.plot(solution.y[0], solution.y[1], label='Phase plot (x vs y)', color='green')
point, = ax3.plot([], [], 'ro')  # Red point to animate the current (x, y) position
ax3.set_title('Phase Plot: Van der Pol Oscillator')
ax3.set_xlabel('x (slow variable)')
ax3.set_ylabel('y (fast variable)')
ax3.grid(True)

# Animation update function
def update(frame):
    """Move the three red markers to sample ``frame`` of the solution.

    Args:
        frame: Integer index into ``solution.t`` (supplied by FuncAnimation).

    Returns:
        tuple: The modified artists, as required for ``blit=True``.
    """
    t_now = solution.t[frame]
    x_now = solution.y[0][frame]
    y_now = solution.y[1][frame]

    # Line2D.set_data expects sequences. Passing bare scalars is deprecated
    # since Matplotlib 3.7 and rejected by newer releases, so wrap each value
    # in a one-element list.
    point.set_data([x_now], [y_now])      # phase plot
    current_x.set_data([t_now], [x_now])  # x vs time panel
    current_y.set_data([t_now], [y_now])  # y vs time panel

    return point, current_x, current_y

# Animate one frame per sample in solution.t, 10 ms apart, redrawing only the
# moving markers (blitting). The FuncAnimation object is bound to `ani` so it
# is not garbage-collected while the window is open.
frame_count = solution.t.size
ani = FuncAnimation(
    fig,
    update,
    frames=frame_count,
    interval=10,
    blit=True,
)

# Tidy the subplot spacing, then open the interactive window.
plt.tight_layout()
plt.show()
