from manim import *

class PlotFunctionWithGradient(ThreeDScene):
    """Plot f(x, y) = sin(x)·cos(y) as a surface shaded by |∇f|, then draw gradient arrows."""

    # Largest gradient magnitude of f = sin(x)·cos(y) on the plotted domain:
    # |∇f|² = cos²x·cos²y + sin²x·sin²y never exceeds 1 and equals 1 at the
    # origin, so normalising by 1.0 uses the whole BLUE→RED colour range.
    MAX_GRADIENT_MAGNITUDE = 1.0

    def construct(self):
        """Build the scene: camera, coloured surface, gradient arrows, final pause."""
        self.set_camera_orientation(phi=75 * DEGREES, theta=-45 * DEGREES)

        def function(x, y):
            """Height of the surface at (x, y): sin(x)·cos(y)."""
            return np.sin(x) * np.cos(y)

        def gradient(x, y):
            """Analytic gradient (df/dx, df/dy) of ``function`` at (x, y)."""
            df_dx = np.cos(x) * np.cos(y)
            df_dy = -np.sin(x) * np.sin(y)
            return np.array([df_dx, df_dy])

        surface = Surface(
            lambda u, v: np.array([u, v, function(u, v)]),
            u_range=[-3, 3],
            v_range=[-3, 3],
            resolution=(50, 50),
            fill_opacity=1,
        )

        # Colour each face patch of the surface by the gradient magnitude at its
        # centre. Mobject offers no per-point Python colouring callback (the
        # OpenGL ``set_color_by_xyz_func`` expects a GLSL snippet, not a lambda),
        # so the faces are filled one by one.
        for face in surface:
            x, y, _ = face.get_center()
            face.set_fill(
                self.get_gradient_color(
                    x, y, gradient, max_magnitude=self.MAX_GRADIENT_MAGNITUDE
                )
            )

        self.add(surface)

        # Reuse the same height function so the arrows sit on the plotted surface.
        self.draw_gradients(gradient, function=function)

        self.wait(2)

    def get_gradient_color(self, x, y, gradient_func, max_magnitude=2.0):
        """Return a colour between BLUE (flat) and RED (steep) for the point (x, y).

        Args:
            x, y: Point in the function's domain.
            gradient_func: Callable returning the gradient (df/dx, df/dy) at (x, y).
            max_magnitude: Gradient magnitude that maps to pure RED. Larger
                magnitudes are clamped to RED. Defaults to 2.0, the previously
                hard-coded divisor.

        Returns:
            The colour interpolated from BLUE to RED by the normalised magnitude.

        Raises:
            ValueError: If ``max_magnitude`` is not positive.
        """
        if max_magnitude <= 0:
            raise ValueError("max_magnitude must be positive")
        magnitude = np.linalg.norm(gradient_func(x, y))
        # Normalise into [0, 1] so it can be used as an interpolation factor.
        color_value = float(np.clip(magnitude / max_magnitude, 0.0, 1.0))
        return interpolate_color(BLUE, RED, color_value)

    def draw_gradients(self, gradient_func, function=None, points=None):
        """Animate one yellow arrow per sample point showing the gradient there.

        Each arrow starts on the surface at (x, y, f(x, y)). It ends on the
        surface above (x, y) + ∇f(x, y).

        Args:
            gradient_func: Callable returning (df/dx, df/dy) at (x, y).
            function: Surface height callable f(x, y). Defaults to
                sin(x)·cos(y), the surface this method originally hard-coded.
            points: Iterable of (x, y) sample points. Defaults to
                (-2, -2), (0, 0) and (2, 2).
        """
        if function is None:
            surface_func = lambda x, y: np.sin(x) * np.cos(y)  # noqa: E731
        else:
            surface_func = function
        if points is None:
            points = [(-2, -2), (0, 0), (2, 2)]

        for x, y in points:
            grad = np.asarray(gradient_func(x, y), dtype=float)
            # At a critical point the gradient is zero and start == end, which
            # would make a degenerate zero-length arrow; skip such points.
            if np.allclose(grad, 0.0):
                continue
            end_x, end_y = x + grad[0], y + grad[1]
            grad_vector = Arrow3D(
                start=np.array([x, y, surface_func(x, y)]),
                end=np.array([end_x, end_y, surface_func(end_x, end_y)]),
                color=YELLOW,
            )
            self.play(Create(grad_vector))
