import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset

def weierstrass_function(x, a, b, n_terms):
    """Evaluate the truncated Weierstrass series element-wise at *x*.

    Computes ``sum(a**n * cos(b**n * pi * x) for n in range(n_terms))``.

    Parameters
    ----------
    x : array_like
        Sample points (scalar, list or ndarray). Integer or boolean input is
        promoted to float64; floating and complex input keeps its dtype.
    a, b : number
        Series parameters. They are not validated here.
    n_terms : int
        Number of series terms to sum. Zero or a negative value yields an
        all-zero result.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as *x* (0-d for scalar input).
    """
    x = np.asarray(x)
    # The accumulator below takes x's dtype. An integer accumulator cannot
    # absorb the float cosine terms in place (numpy raises a casting error),
    # so promote non-floating input first.
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(float)
    result = np.zeros_like(x)
    for n in range(n_terms):
        result += a**n * np.cos(b**n * np.pi * x)
    return result

# Series parameters: amplitude ratio a, frequency ratio b, and the number of
# series terms that are actually summed.
a, b = 0.5, 3
n_terms = 100

# 100,000 evenly spaced sample points on the closed interval [-2, 2].
x = np.linspace(start=-2, stop=2, num=100_000)

# Evaluate the truncated series at every sample point.
y = weierstrass_function(x, a=a, b=b, n_terms=n_terms)

def _add_zoom_inset(parent, xs, ys, lo, hi, *, loc, color, edgecolor):
    """Plot (xs, ys) in a tickless 40% x 40% inset of *parent* and link it.

    The inset is placed at *loc* inside *parent*. Its x-limits are [lo, hi]
    and its y-limits are fitted to *ys*. mark_inset then outlines the zoomed
    region on *parent* and draws connector lines (colour *edgecolor*) between
    that region and the inset. Returns the new inset Axes.
    """
    inset = inset_axes(parent, width="40%", height="40%", loc=loc)
    inset.plot(xs, ys, color=color)
    inset.set_xlim(lo, hi)
    inset.set_ylim(ys.min(), ys.max())
    inset.set_xticks([])
    inset.set_yticks([])
    mark_inset(parent, inset, loc1=2, loc2=4, fc="none", ec=edgecolor)
    return inset


# Plot the Weierstrass function over the full sampled domain.
fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(x, y, label='Weierstrass Function')

# Lay out the main axes now, before any axes_grid1 insets exist. Those insets
# are positioned by axes locators that tight_layout cannot handle (it warns
# that the figure includes incompatible Axes). The insets are placed relative
# to their parent at draw time, so they follow the adjusted main axes.
fig.tight_layout()

# First zoom: x in [-0.1, 0.1], shown in the upper-right corner of the main
# axes and connected to the corresponding region of the main plot.
zoom_start, zoom_end = -0.1, 0.1
zoom_indices = (x >= zoom_start) & (x <= zoom_end)  # boolean mask over x
axins = _add_zoom_inset(
    ax, x[zoom_indices], y[zoom_indices], zoom_start, zoom_end,
    loc='upper right', color='orange', edgecolor="0.5",
)

# Second zoom: x in [-0.02, 0.02], nested at the bottom of the first inset.
# Its connector lines join it to the first inset, not to the main plot.
zoom_start_inner, zoom_end_inner = -0.02, 0.02
zoom_indices_inner = (x >= zoom_start_inner) & (x <= zoom_end_inner)
axins2 = _add_zoom_inset(
    axins, x[zoom_indices_inner], y[zoom_indices_inner],
    zoom_start_inner, zoom_end_inner,
    loc='lower center', color='red', edgecolor="0.2",
)

# Save before showing. plt.show() blocks until the window is closed, and on
# some backends the figure is torn down on close, so saving afterwards is
# unreliable.
save_figure = True
if save_figure:
    fig.savefig("weierstrass.png", dpi=300)
plt.show()