import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

# Lorenz system parameters (sigma, rho, beta) used as the defaults below
sigma, rho, beta = 10.0, 28.0, 8.0 / 3.0


def lorenz_deriv(X, t0, sigma=sigma, beta=beta, rho=rho):
    """Return the time derivative of the Lorenz system at state *X*.

    Parameters
    ----------
    X : sequence of 3 floats
        Current state ``(x, y, z)``.
    t0 : float
        Current time. Not used by the equations (the system does not depend
        on time explicitly); accepted so callers can use an ``f(state, t)``
        signature.
    sigma, beta, rho : float
        System parameters; default to the module-level values above
        (bound once, when the function is defined).

    Returns
    -------
    numpy.ndarray
        ``[dx/dt, dy/dt, dz/dt]``.
    """
    x_val, y_val, z_val = X
    return np.array([
        sigma * (y_val - x_val),
        x_val * (rho - z_val) - y_val,
        x_val * y_val - beta * z_val,
    ])

# Time grid (step dt, end time 25 exclusive) and the starting state
dt = 0.01
t = np.arange(0.0, 25.0, dt)
initial_state = np.array([1.0, 1.0, 1.0])

# Forward-Euler integration, one row per time point:
#   state[k] = state[k-1] + f(state[k-1], t[k-1]) * dt
trajectory = np.empty((t.size, 3))
trajectory[0] = initial_state
for step, t_prev in enumerate(t[:-1], start=1):
    previous = trajectory[step - 1]
    trajectory[step] = previous + lorenz_deriv(previous, t_prev) * dt

# Figure layout on a 2x2 grid: 3D attractor top-left, y-z projection
# top-right, x-y projection bottom-left; the bottom-right cell stays empty.
fig = plt.figure(figsize=(10, 7))

ax3d = fig.add_subplot(221, projection="3d")
ax3d.set(xlim=(-25, 25), ylim=(-35, 35), zlim=(5, 55),
         xlabel="$x$", ylabel="$y$", zlabel="$z$")

# x-y projection: same x/y ranges as the 3D view
ax_xy = fig.add_subplot(223)
ax_xy.set(xlim=(-25, 25), ylim=(-35, 35), xlabel="$x$", ylabel="$y$")

# y-z projection: same y/z ranges as the 3D view
ax_yz = fig.add_subplot(222)
ax_yz.set(xlim=(-35, 35), ylim=(5, 55), xlabel="$y$", ylabel="$z$")


def _marker_and_trail(ax, n_coords):
    """Add an empty blue-dot marker and an empty trail line to *ax*.

    *n_coords* is the number of empty coordinate lists passed to ``plot``:
    3 for the 3D axes, 2 for the flat projections.
    """
    empty = [[] for _ in range(n_coords)]
    marker, = ax.plot(*empty, 'bo', markersize=5)
    trail, = ax.plot(*empty, lw=1)
    return marker, trail


# Per panel: a marker for the current state and a trail of visited states
point_3d, line_3d = _marker_and_trail(ax3d, 3)
point_xy, line_xy = _marker_and_trail(ax_xy, 2)
point_yz, line_yz = _marker_and_trail(ax_yz, 2)

# Initialization function
def init_dual():
    """Clear every animated artist so the animation starts from a blank frame.

    Used as ``init_func`` of the FuncAnimation. The returned tuple lists the
    artists that blitting must redraw.
    """
    # 3D artists need their z data cleared separately from x/y.
    for artist in (point_3d, line_3d):
        artist.set_data([], [])
        artist.set_3d_properties([])

    for artist in (point_xy, line_xy, point_yz, line_yz):
        artist.set_data([], [])

    return point_3d, line_3d, point_xy, line_xy, point_yz, line_yz

# Update function for animation
def update_dual(frame):
    """Draw the state at time index *frame* in all three panels.

    Parameters
    ----------
    frame : int
        Row index into ``trajectory``; FuncAnimation supplies
        ``0 .. len(t) - 1``.

    Returns
    -------
    tuple
        The updated artists, which blitting redraws.
    """
    # Length-1 slices instead of scalar indexing: Line2D.set_data requires
    # sequences (scalars are deprecated since Matplotlib 3.7 and rejected
    # in later releases).
    current = trajectory[frame:frame + 1]
    # Include the current row so the trail ends exactly at the marker
    # (and is not empty on the first frame).
    history = trajectory[:frame + 1]

    # 3D plot update
    point_3d.set_data(current[:, 0], current[:, 1])
    point_3d.set_3d_properties(current[:, 2])
    line_3d.set_data(history[:, 0], history[:, 1])
    line_3d.set_3d_properties(history[:, 2])

    # XY projection update
    point_xy.set_data(current[:, 0], current[:, 1])
    line_xy.set_data(history[:, 0], history[:, 1])

    # YZ projection update
    point_yz.set_data(current[:, 1], current[:, 2])
    line_yz.set_data(history[:, 1], history[:, 2])

    return point_3d, line_3d, point_xy, line_xy, point_yz, line_yz

# Create the animation (interval = on-screen delay between frames, in ms)
ani_dual = FuncAnimation(fig, update_dual, frames=len(t), init_func=init_dual, interval=10, blit=True)

# Set to False to skip writing the MP4 and only show the interactive window.
SAVE_ANIMATION = True

if SAVE_ANIMATION:
    from matplotlib.animation import writers

    # The 'ffmpeg' writer needs an external ffmpeg binary. Skip the export
    # when it is missing instead of raising, so the window below still opens.
    if writers.is_available('ffmpeg'):
        ani_dual.save('lorenz_attractor.mp4', writer='ffmpeg', fps=60)
    else:
        print("ffmpeg writer not available; skipping export of 'lorenz_attractor.mp4'")

# Display the animation
plt.show()
