import numpy as np
import matplotlib.pyplot as plt

# Function z = e^x * sin(y)
def z_function(x, y):
    """Evaluate z = e^x * sin(y).

    Both arguments are passed straight to NumPy ufuncs, so scalars and
    arrays (with NumPy broadcasting) are accepted.

    Args:
        x: Value(s) used as the exponent of e.
        y: Value(s) whose sine is taken.

    Returns:
        The element-wise product ``exp(x) * sin(y)``.
    """
    exponential_part = np.exp(x)
    sine_part = np.sin(y)
    return exponential_part * sine_part

# Define a star-shaped domain (polar coordinates)
def r_star(theta):
    """Return the boundary radius of the star-shaped domain at angle *theta*.

    The boundary is the polar curve r = 1 + 0.3 * sin(5 * theta): a unit
    circle modulated by a five-lobed sine term, so the radius stays within
    [0.7, 1.3].

    Args:
        theta: Polar angle(s) in radians; scalar or NumPy array.

    Returns:
        The boundary radius for each angle, with the same shape as *theta*.
    """
    return 1 + 0.3 * np.sin(5 * theta)

# Sample the boundary curve of the domain: one radius per angle in [0, 2*pi]
num_points = 2000
theta = np.linspace(0, 2 * np.pi, num_points)
r = r_star(theta)

# Boundary curve expressed in Cartesian coordinates
x_star, y_star = r * np.cos(theta), r * np.sin(theta)

# Square grid over [-1.5, 1.5] x [-1.5, 1.5]; the boundary radius never
# exceeds 1.3, so this square contains the whole domain
grid_size = 1000
x_grid, y_grid = np.meshgrid(
    np.linspace(-1.5, 1.5, grid_size),
    np.linspace(-1.5, 1.5, grid_size),
)

# A grid node belongs to the domain when its distance from the origin does
# not exceed the boundary radius in the node's own direction
r_grid = np.sqrt(x_grid**2 + y_grid**2)
theta_grid = np.arctan2(y_grid, x_grid)
mask = r_grid <= r_star(theta_grid)

# Start from an all-NaN height field (NaN nodes are not drawn) and fill in
# the function values only at the nodes inside the domain
z = np.full(x_grid.shape, np.nan)
z[mask] = z_function(x_grid[mask], y_grid[mask])

# Plot the graph of the function over the star-shaped domain
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111, projection='3d')

# Draw the plane z = 0 on the grid the surface already uses; building a
# second, identical meshgrid just for the plane is unnecessary
ax.plot_surface(x_grid, y_grid, np.zeros_like(x_grid), color='gray', alpha=0.1, edgecolor='none')

# Draw the contour of the star-shaped domain in the plane z = 0
ax.plot(x_star, y_star, np.zeros_like(x_star), 'k--', linewidth=1)

# Plot the surface exactly once and keep the handle: the colorbar below is
# attached to this artist. Calling plot_surface again only to obtain a
# mappable would draw the surface twice and stack the translucent layers.
surface = ax.plot_surface(x_grid, y_grid, z, cmap='viridis', edgecolor='none', alpha=0.4)

# Circle of radius 1/2 centred at the origin, and its lift onto the graph
circle_radius = 0.5
theta_circle = np.linspace(0, 2 * np.pi, 2000)
x_circle = circle_radius * np.cos(theta_circle)
y_circle = circle_radius * np.sin(theta_circle)
z_circle = z_function(x_circle, y_circle)

# Circle boundary in the plane z = 0
# NOTE: the label= strings below only become visible if ax.legend() is called.
ax.plot(x_circle, y_circle, np.zeros_like(x_circle), 'r--', label='Circle boundary (r=0.5)')

# Fill the disk bounded by the circle. plot_surface needs a 2-D mesh, so the
# disk is parametrised in polar form (radius x angle). The 1-D boundary
# coordinates alone describe a curve, not an area, and cannot fill the disk.
# Two radial samples suffice because the disk is flat; shade=False keeps the
# colour uniform.
disk_r, disk_theta = np.meshgrid(
    np.linspace(0, circle_radius, 2),
    np.linspace(0, 2 * np.pi, 200),
)
ax.plot_surface(
    disk_r * np.cos(disk_theta),
    disk_r * np.sin(disk_theta),
    np.zeros_like(disk_r),
    color='red',
    alpha=0.8,
    edgecolor='none',
    shade=False,
)

# Image of the circle on the graph of the function
ax.plot(x_circle, y_circle, z_circle, 'ro', markersize=2, label='Circle projection')

# Draw the point (0,0)
ax.scatter(0, 0, 0, color='blue', s=100, label='Point (0,0)')

# Locate the maximum and minimum of the function on the sampled circle;
# the extreme values are read back through the indices instead of scanning
# the array a second time
max_index = np.argmax(z_circle)
min_index = np.argmin(z_circle)
z_circle_max = z_circle[max_index]
z_circle_min = z_circle[min_index]
x_max, y_max = x_circle[max_index], y_circle[max_index]
x_min, y_min = x_circle[min_index], y_circle[min_index]

# Scatter the points showing the maximum and minimum
ax.scatter(x_max, y_max, z_circle_max, color='green', s=100, label='Max on circle')
ax.scatter(x_min, y_min, z_circle_min, color='purple', s=100, label='Min on circle')

# Set plot limits and labels
ax.set_xlim([-1.5, 1.5])
ax.set_ylim([-1.5, 1.5])
ax.set_zlim([np.nanmin(z), np.nanmax(z)])  # Ignore NaNs in limits
ax.set_xlabel('$x$')
ax.set_ylabel('$y$')
# ax.set_title('Function z = e^x sin(y) on a Star-shaped Domain')
ax.set_box_aspect([1, 1, 0.5])  # Aspect ratio is 1:1:0.5
ax.set_xticks([])  # Hide x-axis ticks
ax.set_yticks([])  # Hide y-axis ticks
ax.set_zticks([])  # Hide z-axis ticks

# Colorbar for the surface drawn above (reuses the existing artist)
fig.colorbar(surface, ax=ax, shrink=0.5, aspect=10)

fig.tight_layout()

# Set to False to skip writing the PNG and only show the interactive window
SAVE_FIGURE = True
if SAVE_FIGURE:
    fig.savefig('harmonic-function.png', dpi=300)

plt.show()
