from manim import *
import math

class AnimateFenchelConjugate(Scene):
    """Animate the (local) Fenchel conjugate of f(x) = 0.5 x^3 - 2x.

    For a point of tangency x0 with slope y = f'(x0), the tangent line is
    t(x) = y*x - f*(y) with f*(y) = y*x0 - f(x0), so it meets the vertical
    axis at (0, -f*(y)).  Sweeping x0 moves the tangent line, that
    intercept, the line x -> x*y, and the point (y, f*(y)) on the
    conjugate curve together.
    """

    def construct(self):
        # One horizontal extent shared by the axes and every plotted curve so
        # nothing is drawn past the ends of the x-axis.
        x_min, x_max = -2, 2

        axes = Axes(
            x_range=[x_min, x_max, 1],
            y_range=[-4, 4, 1],
            axis_config={"color": WHITE},
        )
        axes_labels = axes.get_axis_labels(x_label="x", y_label="")

        def f(x):
            """Return f(x) = 0.5 x^3 - 2x."""
            return 0.5 * x**3 - 2 * x

        def f_prime(x):
            """Return f'(x) = 1.5 x^2 - 2, the slope of the tangent at x."""
            return 1.5 * x**2 - 2

        def conjugate_at(x0):
            """Return (y, y*x0 - f(x0)) for the tangent at x0, where y = f'(x0).

            The second value is f*(y) evaluated at the point of tangency; the
            tangent line crosses the vertical axis at (0, -f*(y)).
            """
            slope = f_prime(x0)
            return slope, slope * x0 - f(x0)

        def fenchel_conjugate(y):
            """Return f*(y) on the branch maximised by x* = sqrt(2(y + 2)/3) >= 0.

            f is a cubic, so sup_x {xy - f(x)} over all x is unbounded; this is
            the value of xy - f(x) at its local maximiser x* (where
            f'(x*) = y and x* >= 0), which simplifies to (2/3)(y + 2) x*.

            Raises:
                ValueError: if y < -2, where f'(x) = y has no real solution.
            """
            shifted = y + 2
            if shifted < 0:
                raise ValueError(
                    f"fenchel_conjugate is only defined for y >= -2, got {y}"
                )
            x_star = math.sqrt(2 * shifted / 3)
            return 2 * shifted * x_star / 3

        # Position of the point of tangency; sweeping it drives every moving object.
        x_tracker = ValueTracker(1.0)

        def tangent_point():
            """Scene coordinates of (x0, f(x0))."""
            x0 = x_tracker.get_value()
            return axes.c2p(x0, f(x0))

        def intercept_point():
            """Scene coordinates of (0, -f*(y)), where the tangent meets the vertical axis."""
            _, f_star = conjugate_at(x_tracker.get_value())
            return axes.c2p(0, -f_star)

        def conjugate_curve_point():
            """Scene coordinates of (y, f*(y)) on the plotted conjugate curve."""
            y = f_prime(x_tracker.get_value())
            return axes.c2p(y, fenchel_conjugate(y))

        def tangent_line_function():
            """Return the current tangent line x -> y*x - f*(y)."""
            slope, f_star = conjugate_at(x_tracker.get_value())
            return lambda x: slope * x - f_star

        def slope_line_function():
            """Return the current line x -> x*y through the origin."""
            slope = f_prime(x_tracker.get_value())
            return lambda x: slope * x

        # Static curves: they do not depend on the tracker, so plot them once
        # instead of re-plotting every frame with always_redraw.
        function_graph = axes.plot(f, color=BLUE, x_range=[x_min, x_max])
        # fenchel_conjugate is undefined left of y = -2, so never start before it.
        fenchel_graph = axes.plot(
            fenchel_conjugate, color=PURPLE, x_range=[max(x_min, -2), x_max]
        )

        # Tracker-dependent geometry, rebuilt each frame.
        tangent_line = always_redraw(
            lambda: axes.plot(
                tangent_line_function(), color=YELLOW, x_range=[x_min, x_max]
            )
        )
        slope_line = always_redraw(
            lambda: axes.plot(
                slope_line_function(), color=GREEN, x_range=[x_min, x_max]
            )
        )
        # Dashed piece of the tangent line from the axis intercept to the point
        # of tangency (it lies along the tangent, so it is not vertical).
        tangent_segment = always_redraw(
            lambda: DashedLine(
                start=intercept_point(), end=tangent_point(), color=ORANGE
            )
        )
        tangent_dot = always_redraw(lambda: Dot(tangent_point(), color=RED))
        conjugate_point = always_redraw(lambda: Dot(intercept_point(), color=GREEN))
        fenchel_dot = always_redraw(lambda: Dot(conjugate_curve_point(), color=PURPLE))

        # Labels keep the same TeX and only move, so build each once and
        # reposition it with an updater instead of rebuilding it every frame.
        # Each label is added to the scene after the object it follows.
        function_label = MathTex(r"f(x)").add_updater(
            lambda m: m.next_to(tangent_dot, UP, buff=0.3), call_updater=True
        )
        conjugate_label = MathTex(r"(0, -f^*(y))").add_updater(
            lambda m: m.next_to(conjugate_point, LEFT), call_updater=True
        )
        fenchel_dot_label = MathTex(r"f^*(y)").add_updater(
            lambda m: m.next_to(fenchel_dot, UP), call_updater=True
        )
        xy_label = MathTex(r"x \mapsto xy").add_updater(
            lambda m: m.next_to(slope_line.get_end(), LEFT, buff=0.1),
            call_updater=True,
        )
        fenchel_conjugate_label = MathTex(
            r"f^*(y) = \sup_x \{xy - f(x)\}"
        ).next_to(axes, DOWN, buff=0.1)

        # Draw order: axes, then curves, then moving lines, dots and labels on top.
        self.add(axes, axes_labels, function_graph, fenchel_graph)
        self.add(slope_line, tangent_line, tangent_segment)
        self.add(tangent_dot, conjugate_point, fenchel_dot)
        self.add(function_label, conjugate_label, fenchel_dot_label, xy_label)
        self.add(fenchel_conjugate_label)

        # Sweep the point of tangency from x0 = 1 to 1.5 and back. x0 stays
        # positive, i.e. on the branch fenchel_conjugate evaluates.
        self.play(
            x_tracker.animate.set_value(1.5),
            run_time=6,
            rate_func=there_and_back,
        )

        self.wait(3)
