from manim import *

class LpRegularizationIntersection(Scene):
    """Animate a straight line against growing L2 (circle) and L1 (diamond) balls.

    The scene draws square 2-D axes and a line ``y = LINE_SLOPE * x``. It then
    grows an L2 ball and an L1 ball centred on the origin, labelling each one
    in a colour-matched legend. Extents, radii and the slope are class
    attributes, so a subclass can vary them without rewriting ``construct``.
    """

    # Half-extent of both axes, in axis (data) units.
    AXIS_EXTENT = 3
    # Slope of the drawn line; deliberately not +/-1 so the line is not
    # parallel to the diamond's edges.
    LINE_SLOPE = 1.2
    # (initial, final) radius of the L2 ball, in axis units.
    L2_RADII = (1, 2 ** 0.5)
    # (initial, final) vertex distance of the L1 ball, in axis units.
    L1_RADII = (1, 2)

    def construct(self):
        """Build and play the scene: axes, line, L2 ball growth, L1 ball growth."""
        extent = self.AXIS_EXTENT
        # Equal data spans mapped to equal on-screen lengths give a 1:1
        # aspect ratio, so the circle looks round and the diamond square.
        axes = Axes(
            x_range=[-extent, extent],
            y_range=[-extent, extent],
            axis_config={"color": GRAY},
            x_length=6,
            y_length=6,
        )
        axis_labels = axes.get_axis_labels(x_label="x_1", y_label="x_2")
        self.play(Create(axes), Write(axis_labels))

        # NOTE(review): y = LINE_SLOPE * x passes through the origin, so it
        # crosses both balls at their centre instead of touching them; add an
        # intercept if a true tangency with the final balls is intended.
        tangent_line = self._line_through_origin(axes)
        self.play(Create(tangent_line))

        # L2 ball (circle) growing from its initial to its final radius.
        l2_start, l2_end = self.L2_RADII
        l2_ball = self._l2_ball(axes, l2_start)
        # Labels live in a top-left legend, coloured like their shapes. Placed
        # under the circle, the L2 label overlapped the final diamond's vertex.
        l2_label = Tex(r"$\|w\|_2$", font_size=36).set_color(BLUE).to_corner(UL)
        self.play(Create(l2_ball))
        self.play(Transform(l2_ball, self._l2_ball(axes, l2_end)))
        self.play(Write(l2_label))

        # L1 ball (diamond) growing from its initial to its final radius.
        l1_start, l1_end = self.L1_RADII
        l1_ball = self._l1_ball(axes, l1_start)
        l1_label = (
            Tex(r"$\|w\|_1$", font_size=36)
            .set_color(GREEN)
            .next_to(l2_label, DOWN, aligned_edge=LEFT)
        )
        self.play(Create(l1_ball))
        self.play(Transform(l1_ball, self._l1_ball(axes, l1_end)))
        self.play(Write(l1_label))

        # Final hold for viewing.
        self.wait(3)

    def _line_through_origin(self, axes):
        """Return the line ``y = LINE_SLOPE * x`` clipped to the visible axes square."""
        extent = self.AXIS_EXTENT
        slope = self.LINE_SLOPE
        # Stop where the line leaves [-extent, extent]^2. Running x to +/-extent
        # would put y at +/-extent*slope, beyond the y-range when |slope| > 1.
        x_max = extent if slope == 0 else min(extent, extent / abs(slope))
        return Line(
            start=axes.c2p(-x_max, -slope * x_max),
            end=axes.c2p(x_max, slope * x_max),
            color=YELLOW,
        )

    @staticmethod
    def _l2_ball(axes, radius):
        """Return a blue circle of ``radius`` axis units centred on the axes origin."""
        # Circle(radius=...) is measured in scene units. Convert via the axes so
        # the radius matches the c2p-based diamond at any axis scaling.
        unit = axes.c2p(1, 0)[0] - axes.c2p(0, 0)[0]
        return Circle(radius=radius * unit, color=BLUE).move_to(axes.c2p(0, 0))

    @staticmethod
    def _l1_ball(axes, radius):
        """Return a green diamond {|x1| + |x2| = radius} in axis coordinates."""
        return Polygon(
            axes.c2p(-radius, 0),
            axes.c2p(0, radius),
            axes.c2p(radius, 0),
            axes.c2p(0, -radius),
            color=GREEN,
        )
