from manim import *
import numpy as np

class ConstructMorleyTriangle(Scene):
    """Animate Morley's trisector theorem on a continuously moving triangle.

    Each vertex of triangle ABC travels around its own small circle (two full
    turns over ten seconds). Every frame, the two angle trisectors adjacent to
    each side are intersected; the three intersection points are marked with
    red dots and joined into the shaded Morley triangle.
    """

    def construct(self):
        # Centres of the circular orbits followed by A, B and C.
        A_start = np.array([2, -1, 0])
        B_start = np.array([-2, 2, 0])
        C_start = np.array([3, 2, 0])

        # Orbit radii and phase offsets for A, B and C.
        R_A, R_B, R_C = 1.0, 0.8, 0.6
        phase_A, phase_B, phase_C = 0.0, PI / 3, -PI / 3

        # Time parameter for the animation: the angle (radians) travelled
        # along every orbit.
        t = ValueTracker(0)

        def orbit_position(center, radius, phase):
            """Return the point on the orbit about ``center`` at angle t + phase."""
            angle = t.get_value() + phase
            return center + radius * np.array([np.cos(angle), np.sin(angle), 0])

        # Moving dots for A, B, C, created on their orbits so the static state
        # already matches what the updaters will produce.
        dot_A = Dot(point=orbit_position(A_start, R_A, phase_A), color=BLUE)
        dot_B = Dot(point=orbit_position(B_start, R_B, phase_B), color=BLUE)
        dot_C = Dot(point=orbit_position(C_start, R_C, phase_C), color=BLUE)
        self.add(dot_A, dot_B, dot_C)

        # Labels for A, B, C that follow the dots
        label_A = MathTex("A").add_updater(lambda m: m.next_to(dot_A, DOWN))
        label_B = MathTex("B").add_updater(lambda m: m.next_to(dot_B, LEFT))
        label_C = MathTex("C").add_updater(lambda m: m.next_to(dot_C, RIGHT))
        self.add(label_A, label_B, label_C)

        # Triangle sides, rebuilt every frame from the current dot positions.
        line_AB = always_redraw(lambda: Line(dot_A.get_center(), dot_B.get_center(), color=YELLOW))
        line_BC = always_redraw(lambda: Line(dot_B.get_center(), dot_C.get_center(), color=YELLOW))
        line_CA = always_redraw(lambda: Line(dot_C.get_center(), dot_A.get_center(), color=YELLOW))
        self.add(line_AB, line_BC, line_CA)

        def trisector_directions(vertex_dot, first_dot, second_dot):
            """Return the vertex and the unit trisector directions nearest each side.

            The first direction is the trisector adjacent to the side running
            towards ``first_dot``; the second is adjacent to the side running
            towards ``second_dot``. The labelling is the same whether the
            triangle is oriented clockwise or counter-clockwise. If either
            neighbour coincides with the vertex, both directions are ``None``.
            """
            vertex = vertex_dot.get_center()
            to_first = first_dot.get_center() - vertex
            to_second = second_dot.get_center() - vertex
            len_first = np.linalg.norm(to_first)
            len_second = np.linalg.norm(to_second)
            if len_first < 1e-12 or len_second < 1e-12:
                return vertex, None, None  # degenerate corner: no angle to trisect
            u_first = to_first / len_first
            u_second = to_second / len_second

            angle = np.arccos(np.clip(np.dot(u_first, u_second), -1.0, 1.0))
            # z-component of u_first x u_second, written out explicitly because
            # np.cross on 2-element vectors is deprecated in NumPy 2.0.
            cross_z = u_first[0] * u_second[1] - u_first[1] * u_second[0]
            # +1 when turning from u_first to u_second is counter-clockwise.
            turn = 1.0 if cross_z >= 0 else -1.0

            # Rotate each edge direction one third of the way towards the other edge.
            near_first = rotate_vector(u_first, turn * angle / 3)
            near_second = rotate_vector(u_second, -turn * angle / 3)
            return vertex, near_first, near_second

        def intersect_trisectors(p, u, q, v):
            """Intersect the line through ``p`` along ``u`` with the line through ``q`` along ``v``.

            Returns ``None`` if a direction is missing or the lines are parallel.
            """
            if u is None or v is None:
                return None
            return find_intersection(Line(p, p + u * 10), Line(q, q + v * 10))

        def place_segment(line, start, end):
            """Make ``line`` run from ``start`` to ``end``.

            ``set_points_as_corners`` also accepts ``start == end``. In some manim
            versions, ``put_start_and_end_on`` cannot reliably re-position a
            zero-length line.
            """
            line.set_points_as_corners([start, end])

        # Trisector segments. X1/X2 are the two trisectors drawn at vertex X:
        # A2 and B1 meet near side AB, B2 and C1 near side BC, C2 and A1 near side CA.
        trisector_line_A1 = Line(color=GREEN)
        trisector_line_A2 = Line(color=GREEN)
        trisector_line_B1 = Line(color=GREEN)
        trisector_line_B2 = Line(color=GREEN)
        trisector_line_C1 = Line(color=GREEN)
        trisector_line_C2 = Line(color=GREEN)

        # Morley points and triangle, defined before the updater that moves them.
        dot_I1 = Dot(color=RED)
        dot_I2 = Dot(color=RED)
        dot_I3 = Dot(color=RED)
        morley_triangle = VGroup()

        def update_trisectors(mob):
            """Recompute trisectors, Morley points and Morley triangle from the dots."""
            vertex_A, A_toward_B, A_toward_C = trisector_directions(dot_A, dot_B, dot_C)
            vertex_B, B_toward_C, B_toward_A = trisector_directions(dot_B, dot_C, dot_A)
            vertex_C, C_toward_A, C_toward_B = trisector_directions(dot_C, dot_A, dot_B)

            # Each Morley point is where the two trisectors adjacent to one side meet.
            intersection_1 = intersect_trisectors(vertex_A, A_toward_B, vertex_B, B_toward_A)  # side AB
            intersection_2 = intersect_trisectors(vertex_B, B_toward_C, vertex_C, C_toward_B)  # side BC
            intersection_3 = intersect_trisectors(vertex_C, C_toward_A, vertex_A, A_toward_C)  # side CA

            pairings = (
                (trisector_line_A2, vertex_A, trisector_line_B1, vertex_B, intersection_1, dot_I1),
                (trisector_line_B2, vertex_B, trisector_line_C1, vertex_C, intersection_2, dot_I2),
                (trisector_line_C2, vertex_C, trisector_line_A1, vertex_A, intersection_3, dot_I3),
            )
            for line_p, p, line_q, q, point, marker in pairings:
                if point is None:
                    # No intersection: collapse both trisectors onto their vertices.
                    place_segment(line_p, p, p)
                    place_segment(line_q, q, q)
                else:
                    place_segment(line_p, p, point)
                    place_segment(line_q, q, point)
                    marker.move_to(point)

            # Update Morley triangle
            if all(pt is not None for pt in [intersection_1, intersection_2, intersection_3]):
                morley_triangle.become(Polygon(
                    intersection_1, intersection_2, intersection_3,
                    color=RED, fill_opacity=0.3
                ))
            else:
                morley_triangle.become(VGroup())

        # Group all trisector lines and add updater
        trisector_lines = VGroup(
            trisector_line_A1, trisector_line_A2,
            trisector_line_B1, trisector_line_B2,
            trisector_line_C1, trisector_line_C2
        )
        trisector_lines.add_updater(update_trisectors)
        # Run once now so nothing sits at its default position before the first update.
        update_trisectors(trisector_lines)
        self.add(trisector_lines)
        self.add(dot_I1, dot_I2, dot_I3, morley_triangle)

        # Circular motion updaters for dots
        dot_A.add_updater(lambda m: m.move_to(orbit_position(A_start, R_A, phase_A)))
        dot_B.add_updater(lambda m: m.move_to(orbit_position(B_start, R_B, phase_B)))
        dot_C.add_updater(lambda m: m.move_to(orbit_position(C_start, R_C, phase_C)))

        # Two full revolutions of every vertex over ten seconds.
        self.play(t.animate.increment_value(2 * TAU), run_time=10, rate_func=linear)
        self.wait()

def rotate_vector(v, angle):
    """Rotate ``v`` counter-clockwise by ``angle`` radians about the z-axis.

    ``v`` is multiplied on the left by a 3x3 rotation matrix, so it must have
    three components (x, y, z). The z component passes through unchanged.
    This module-level definition comes after ``from manim import *``, so it
    takes precedence over any same-named function that import provides.
    """
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    # Right-handed rotation in the xy-plane; the last row/column keeps z fixed.
    about_z = np.array(
        [[cos_a, -sin_a, 0.0],
         [sin_a, cos_a, 0.0],
         [0.0, 0.0, 1.0]]
    )
    return about_z.dot(v)

def find_intersection(line1, line2):
    """Return the point where the infinite lines through two segments cross.

    Each argument only needs a ``get_start_and_end()`` method that returns two
    points. Only their x and y coordinates are used. The result is
    ``np.array([x, y, 0])``, or ``None`` when the lines are (numerically)
    parallel or coincident. The crossing may lie outside both segments.
    """
    (x1, y1), (x2, y2) = (pt[:2] for pt in line1.get_start_and_end())
    (x3, y3), (x4, y4) = (pt[:2] for pt in line2.get_start_and_end())

    dx12, dy12 = x1 - x2, y1 - y2
    dx34, dy34 = x3 - x4, y3 - y4

    denom = dx12 * dy34 - dy12 * dx34
    if abs(denom) < 1e-10:
        return None  # parallel or coincident: no unique crossing

    # 2x2 determinants of each segment's endpoints (standard line-line formula).
    det12 = x1 * y2 - y1 * x2
    det34 = x3 * y4 - y3 * x4

    px = (det12 * dx34 - dx12 * det34) / denom
    py = (det12 * dy34 - dy12 * det34) / denom
    return np.array([px, py, 0])
