import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Sampling grid: `res` evenly spaced samples per axis over [-1.5, 1.5],
# combined into a res x res lattice of coordinates. meshgrid uses "xy"
# indexing, so rows follow y and columns follow x.
res = 400  # samples per axis
x = np.linspace(start=-1.5, stop=1.5, num=res)
y = np.linspace(start=-1.5, stop=1.5, num=res)
X, Y = np.meshgrid(x, y, indexing="xy")

def sdf_circle(X, Y):
    """Return the signed distance from (X, Y) to the unit circle at the origin.

    The result is negative inside the circle, zero on it and positive
    outside. It is evaluated element-wise, so X and Y may be scalars or
    broadcastable NumPy arrays.
    """
    radial = np.sqrt(X**2 + Y**2)  # Euclidean distance from the origin
    return radial - 1  # subtract the radius (1)

def sdf_square(X, Y):
    """Return the signed distance from (X, Y) to an axis-aligned square.

    The square is centred at the origin with edge length 2, so it spans
    [-1, 1] on both axes. The result is negative inside, zero on the
    boundary and positive outside. It is evaluated element-wise.
    """
    # By symmetry, fold into the first quadrant and measure how far the
    # point lies beyond each edge (negative means still inside that edge).
    qx = np.abs(X) - 1
    qy = np.abs(Y) - 1

    # Outside the square, only the components that stick out contribute
    # to the Euclidean distance to the nearest boundary point.
    exterior = np.hypot(np.maximum(qx, 0), np.maximum(qy, 0))
    # Inside the square, the (negative) distance to the nearest edge.
    interior = np.maximum(qx, qy)

    is_outside = interior > 0
    return np.where(is_outside, exterior, interior)

def smooth_min(d1, d2, k):
    """Return a smooth (log-sum-exp) minimum of two distance values/fields.

    Computes ``-k * log(exp(-d1 / k) + exp(-d2 / k))``. For ``k > 0`` the
    result lies in ``[min(d1, d2) - k*log(2), min(d1, d2)]``. It approaches
    ``min(d1, d2)`` as ``k`` shrinks, which blends two SDF shapes with a
    rounded seam. Works element-wise on scalars or broadcastable arrays.

    Args:
        d1, d2: Signed distances to combine.
        k: Smoothing width; expected to be positive (``k == 0`` divides by
            zero).

    Returns:
        The smoothed minimum, with the same broadcast shape as ``d1``/``d2``.
    """
    # Evaluating exp() directly overflows to inf when d/k is very negative
    # (result -inf) and underflows to 0 when d/k is very positive, so
    # log(0) gives a result of +inf. logaddexp computes log(exp(a) + exp(b))
    # without forming the exponentials, so the result stays finite.
    return -k * np.logaddexp(-d1 / k, -d2 / k)

# Square 6x6-inch figure with ticks hidden, showing the same [-1.5, 1.5]
# window that the sampling grid covers.
fig, ax = plt.subplots(figsize=(6, 6))
ax.set(xticks=[], yticks=[], xlim=(-1.5, 1.5), ylim=(-1.5, 1.5))

# Initial filled contour of an all-zero field on the grid; the animation's
# update() draws subsequent frames onto the same axes.
contour = ax.contourf(
    X,
    Y,
    np.zeros_like(X),
    levels=np.linspace(-1, 1, 50),
    cmap="coolwarm",
)

def update(frame, n_frames=101):
    """Render one frame of the circle-to-square SDF blend.

    Args:
        frame: Frame index passed in by ``FuncAnimation``.
        n_frames: Total number of frames. The blend factor runs linearly from
            0 (pure circle) at ``frame == 0`` to 1 (pure square) at
            ``frame == n_frames - 1``. The default of 101 reproduces the
            original ``frame / 100`` schedule.

    Returns:
        The contour set drawn for this frame.
    """
    # Linear blend weight; max() guards n_frames == 1 against division by 0.
    alpha = frame / max(n_frames - 1, 1)
    # Linear interpolation between the two distance fields.
    sdf_blended = (1 - alpha) * sdf_circle(X, Y) + alpha * sdf_square(X, Y)

    # contourf() adds new artists to the axes on every call. Remove whatever
    # earlier frames (and the initial placeholder plot) left behind, so
    # filled contours don't accumulate and slow down every later draw.
    # NOTE(review): assumes the axes hold nothing but these contour plots.
    for artist in list(ax.collections):
        artist.remove()

    # Fixed levels keep the colour mapping identical from frame to frame.
    # An int such as levels=20 re-picks levels from each frame's data range.
    # extend="both" still colours values outside [-1, 1], such as the grid
    # corners, which lie about 1.12 from the unit circle.
    cont = ax.contourf(
        X,
        Y,
        sdf_blended,
        levels=np.linspace(-1, 1, 50),
        cmap="coolwarm",
        extend="both",
    )
    return cont

# Frame rate shared by the saved video and the on-screen preview, so both
# play back at the same speed.
FPS = 30

# Create the animation: 101 frames give blend factors 0.00, 0.01, ..., 1.00.
# Keep the reference in `ani`; matplotlib stops an animation whose object
# has been garbage-collected.
ani = animation.FuncAnimation(
    fig, update, frames=101, interval=round(1000 / FPS)
)

# Save the animation (dpi=300 at a 6x6-inch figure gives 1800x1800 px frames).
# NOTE(review): writing .mp4 relies on an ffmpeg movie writer being available.
ani.save("signed-distance-function-blend.mp4", fps=FPS, dpi=300)

# Display animation
plt.show()