import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def f(x, y):
    """Evaluate the quartic f(x, y) = -(x^2 + y^2 - 1) * (x^2 - y^2 - 1).

    f vanishes on the unit circle x^2 + y^2 = 1 and on the hyperbola
    x^2 - y^2 = 1; the two curves cross at (+-1, 0).

    Works element-wise on scalars or on broadcastable NumPy arrays
    (the script calls it on meshgrid arrays).
    """
    x_sq = x**2
    y_sq = y**2
    circle_factor = x_sq + y_sq - 1
    hyperbola_factor = x_sq - y_sq - 1
    # Alternative test function kept from earlier experiments:
    # (x**2 + y**2 - 1) * (x**3 - 3*x*(y**2) - x) / 3.
    return -circle_factor * hyperbola_factor

# Define the Łojasiewicz bound for the 2D function
def lojasiewicz_bound_2d(x, y, C, alpha):
    """Return the Łojasiewicz-type bound C * |f(x, y)| ** alpha.

    Parameters
    ----------
    x, y : scalar or array
        Evaluation points, passed straight to ``f``.
    C : number
        Multiplicative constant of the bound.
    alpha : number
        Exponent applied to ``|f(x, y)|``.
    """
    abs_f = np.abs(f(x, y))
    return C * abs_f**alpha

# Define the zero locus of the function where f(x, y) = 0
def zero_locus_2d(x, y, *, atol=1e-2):
    """Return a boolean mask of the points where f(x, y) is numerically zero.

    Grid samples almost never land exactly on the curve f = 0, so a point
    counts as on the zero locus when |f(x, y)| <= atol. Because the band is
    defined on the value of f rather than on distance to the curve, it is
    wider where |grad f| is small (for this f, near the crossings at (+-1, 0),
    where both factors vanish and so does the gradient).

    Parameters
    ----------
    x, y : scalar or array
        Evaluation points, passed straight to ``f``.
    atol : float, keyword-only
        Absolute tolerance on |f|; defaults to 1e-2, the previously
        hard-coded value.

    Returns
    -------
    bool or ndarray of bool
        True where |f(x, y)| <= atol (NaN values compare as False).
    """
    # With b = 0, np.isclose's test |a - b| <= atol + rtol * |b| reduces to
    # |a| <= atol, so rtol plays no role here.
    return np.isclose(f(x, y), 0, atol=atol)

# Constants of the Łojasiewicz bound C * |f|**alpha, picked as a rough
# rule-of-thumb guess rather than derived.
C = 1
alpha = 0.5

# Sample f on a uniform 1000 x 1000 grid covering the square [-2, 2] x [-2, 2].
x_values_2d, y_values_2d = np.meshgrid(
    np.linspace(-2, 2, 1000),
    np.linspace(-2, 2, 1000),
)
# NOTE: despite the name, f is real-valued, so these values are real.
f_complex_values_2d = f(x_values_2d, y_values_2d)

# Pointwise Łojasiewicz bound on the same grid.
lojasiewicz_values_2d = lojasiewicz_bound_2d(x_values_2d, y_values_2d, C, alpha)

# Boolean mask selecting the grid points where f is numerically zero.
zero_locus_points = zero_locus_2d(x_values_2d, y_values_2d)

# Build the 3D figure: the surface of f, the Łojasiewicz bound contours
# projected onto a horizontal plane, and the zero locus drawn at z = 0.
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
# Camera elevation / azimuth (degrees) chosen for the exported image.
ax.view_init(elev=51, azim=-45)

# Surface of f(x, y) over the grid. alpha=0.8 here is the surface opacity,
# not the Łojasiewicz exponent `alpha`.
surf = ax.plot_surface(x_values_2d, y_values_2d, f_complex_values_2d, cmap='viridis', edgecolor='none', alpha=0.8)

# Dashed red level curves of C * |f|**alpha, flattened onto the plane z = -5
# (levels=10 lets Matplotlib pick roughly ten "nice" levels).
# NOTE(review): on this grid f ranges from -9 at (+-2, 0) to 15 at (0, +-2), so
# z = -5 lies inside the surface's z-range and the projected contours cut
# through it. If they are meant as a floor projection, an offset at or below
# f_complex_values_2d.min() may be intended — confirm.
contour = ax.contour(x_values_2d, y_values_2d, lojasiewicz_values_2d, zdir='z', offset=-5, levels=10, colors='red', linestyles='--')

# Zero locus in blue: the masked grid points placed at z = 0, which is where
# they sit on the surface (up to the zero_locus_2d tolerance).
ax.scatter(x_values_2d[zero_locus_points], y_values_2d[zero_locus_points], np.zeros_like(x_values_2d[zero_locus_points]), 
           color='blue', label='Zero Locus')

# Labels, title, colorbar and legend are currently disabled for the exported
# figure; uncomment to restore them.
# ax.set_title("3D Surface Plot with Łojasiewicz Inequality and Zero Locus")
# ax.set_xlabel('x')
# ax.set_ylabel('y')
# ax.set_zlabel('f(x, y)')

# Colorbar for the surface plot.
# fig.colorbar(surf, ax=ax, shrink=0.5, aspect=5)

# plt.legend()

# Set to False to only display the figure without writing it to disk.
SAVE_FIGURE = True
if SAVE_FIGURE:
    # Save before plt.show(): with interactive backends, calling savefig after
    # the window has been closed can write a blank image.
    plt.savefig("lojasiewicz-inequality.png", dpi=300)
plt.show()
