import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

# Corners of the spherical triangle: the unit vectors along +x, +y and +z
# (rows of the 3x3 identity matrix are points A, B and C respectively)
vertices = np.eye(3, dtype=int)

# Sample the unit sphere on a 100 x 50 grid: azimuth runs over [0, 2*pi],
# polar angle (measured from +z) runs over [0, pi]
sphere_phi, sphere_theta = np.mgrid[0.0:2.0 * np.pi:100j, 0.0:np.pi:50j]
sin_sphere_theta = np.sin(sphere_theta)
sphere_x = sin_sphere_theta * np.cos(sphere_phi)
sphere_y = sin_sphere_theta * np.sin(sphere_phi)
sphere_z = np.cos(sphere_theta)

# Single 3D axes, with the camera turned so the triangle faces the viewer
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
ax.view_init(elev=25, azim=6)

# Translucent cyan sphere; only every 5th grid line is used for the surface
ax.plot_surface(
    sphere_x, sphere_y, sphere_z,
    rstride=5, cstride=5,
    color='c', alpha=0.3, linewidth=0,
)

def interpolate_points(p1, p2, num_points=100):
    """Sample the great-circle arc between two points on the unit sphere.

    Uses spherical linear interpolation (slerp), so the samples are evenly
    spaced in angle along the shorter arc from ``p1`` to ``p2``.

    Parameters
    ----------
    p1, p2 : array_like, shape (3,)
        Arc endpoints. Both are expected to be unit vectors; they are not
        normalised here.
    num_points : int, optional
        Number of samples, endpoints included.

    Returns
    -------
    numpy.ndarray, shape (num_points, 3)
        Points along the arc; row 0 is ``p1`` and the last row is ``p2``.

    Raises
    ------
    ValueError
        If the endpoints are antipodal, because infinitely many great
        circles pass through them and the arc is not defined.
    """
    p1 = np.asarray(p1, dtype=float)
    p2 = np.asarray(p2, dtype=float)
    t = np.linspace(0, 1, num_points)

    # Round-off can push the dot product of two unit vectors slightly outside
    # [-1, 1], which would make arccos return NaN; clip it back in range.
    omega = np.arccos(np.clip(np.dot(p1, p2), -1.0, 1.0))
    sin_omega = np.sin(omega)

    if np.isclose(sin_omega, 0.0):
        # sin(omega) vanishes for omega ~ 0 (same point) and omega ~ pi
        # (antipodal points); dividing by it would give NaN or noise.
        if omega < 1.0:
            # Coincident endpoints: the arc degenerates to a single point.
            return np.tile(p1, (num_points, 1))
        raise ValueError(
            "antipodal points do not define a unique great-circle arc"
        )

    weights_p1 = np.sin((1 - t) * omega)[:, None]
    weights_p2 = np.sin(t * omega)[:, None]
    return (weights_p1 * p1 + weights_p2 * p2) / sin_omega

# Draw the three edges of the spherical triangle (A-B, A-C, B-C) in red,
# each one as an arc sampled by interpolate_points
for i, j in ((0, 1), (0, 2), (1, 2)):
    arc = interpolate_points(vertices[i], vertices[j])
    ax.plot(*arc.T, color='r')

# Mark the triangle corners with large blue dots
ax.scatter(*vertices.T, color='b', s=100)

# # Plot a big circle (equator)
# phi = np.linspace(0, 2 * np.pi, 100)
# theta = np.pi / 2
# x = np.cos(phi) * np.sin(theta)
# y = np.sin(phi) * np.sin(theta)
# z = np.cos(theta)
# ax.plot(x, y, z, color='g', linewidth=2)

# Draw a magenta circle of constant height: radius sin(pi/3) ~ 0.866 in the
# plane z = cos(2*pi/3) = -0.5. Because sin(pi/3)**2 + cos(2*pi/3)**2 == 1 the
# circle lies on the unit sphere, but it is a small circle parallel to the
# equator, not a great circle.
# NOTE(review): z uses cos(2*theta) while x/y use sin(theta); confirm whether
# a tilted great circle was intended here instead.
phi = np.linspace(0, 2 * np.pi, 100)
theta = np.pi / 3  # 60 degrees
ring_radius = np.sin(theta)
x = ring_radius * np.cos(phi)
y = ring_radius * np.sin(phi)
z = np.cos(2*theta)
ax.plot(x, y, z, linewidth=2, color='m')
def circle_through_three_points(p1, p2, p3, num_points=100):
    """Sample the circle on a sphere that passes through three given points.

    The three points span a plane; the circle is the intersection of that
    plane with the sphere. Its centre is the foot of the perpendicular from
    the origin onto the plane, and every returned sample lies both in the
    plane and on the sphere.

    Parameters
    ----------
    p1, p2, p3 : array_like, shape (3,)
        Three distinct points lying on a sphere centred at the origin
        (typically the unit sphere).
    num_points : int, optional
        Number of samples around the circle. The parameter runs over the
        closed interval [0, 2*pi], so the last sample repeats the first and
        the plotted curve closes.

    Returns
    -------
    numpy.ndarray, shape (num_points, 3)
        Points on the circle, starting (and ending) at ``p1``.

    Raises
    ------
    ValueError
        If the points do not span a plane (two or more of them coincide).
    """
    p1 = np.asarray(p1, dtype=float)
    p2 = np.asarray(p2, dtype=float)
    p3 = np.asarray(p3, dtype=float)

    # Unit normal of the plane through the three points.
    normal = np.cross(p2 - p1, p3 - p1)
    normal_length = np.linalg.norm(normal)
    if normal_length < 1e-12:
        raise ValueError(
            "the three points must be distinct to define a circle"
        )
    normal = normal / normal_length

    # Centre of the circle: projection of the sphere centre (the origin)
    # onto the plane. All three points are equidistant from it.
    center = np.dot(normal, p1) * normal
    radial = p1 - center
    radius = np.linalg.norm(radial)

    # Orthonormal basis (u, v) of the plane, with u pointing at p1 so that
    # the parametrisation starts there.
    u = radial / radius
    v = np.cross(normal, u)

    t = np.linspace(0, 2 * np.pi, num_points)
    return center + radius * (np.cos(t)[:, None] * u + np.sin(t)[:, None] * v)

# Draw three pseudo-random directions (seed fixed so every run gives the same
# picture) and scale each row to unit length so it lies on the sphere
np.random.seed(42)
random_points = np.random.randn(3, 3)
row_lengths = np.linalg.norm(random_points, axis=1)
random_points = random_points / row_lengths[:, np.newaxis]

# Draw, in magenta, the curve computed from those three points
circle_points = circle_through_three_points(*random_points)
ax.plot(*circle_points.T, linewidth=2, color='m')

# Same scale on all three axes so the sphere is not squashed
ax.set_box_aspect((1, 1, 1))

# Remove axis lines, ticks and panes from the picture
ax.set_axis_off()

# Write the figure to a PNG with a transparent background;
# set the flag to False to skip saving
SAVE_FIGURE = True
if SAVE_FIGURE:
    fig.savefig('spherical-geometry.png', transparent=True)

plt.show()
