import math

import matplotlib.pyplot as plt

def _draw_edge(ax, start, end, gap, style):
    """Draw an arrow from ``start`` to ``end``, trimmed by ``gap`` at both ends."""
    (x1, y1), (x2, y2) = start, end
    dx, dy = x2 - x1, y2 - y1
    length = math.hypot(dx, dy)
    if length <= 2 * gap:
        return  # the two markers overlap; there is no visible gap to span
    # Trim along the arrow's own direction so diagonal edges stop at the
    # marker boundary too, not just horizontal ones.
    ux, uy = dx / length, dy / length
    ax.arrow(
        x1 + ux * gap,
        y1 + uy * gap,
        dx - 2 * gap * ux,
        dy - 2 * gap * uy,
        **style,
    )


def draw_mlp(layers, figsize=(10, 6), *, save_path="drawing_mlp.png", show=True):
    """Draw a fully connected multilayer-perceptron diagram.

    Layer ``i`` is drawn as a column of neurons at ``x = 2 * i``, with its
    neurons centred in equal slices of the unit interval ``[0, 1]`` (top to
    bottom). Every hidden layer (all but the first and last) gets a small
    "ReLU" square half a unit to the right of each neuron. Each hidden neuron
    feeds its own square with one arrow, and the square fans out to every
    neuron of the next layer. Input-layer neurons connect directly to every
    neuron of the next layer.

    Args:
        layers: Neuron count per layer, input layer first.
        figsize: Figure size in inches, forwarded to ``plt.subplots``.
        save_path: File the figure is saved to (300 dpi, transparent
            background). Pass ``None`` to skip saving.
        show: Whether to call ``plt.show()`` once the figure is drawn.

    Returns:
        The ``(fig, ax)`` pair, so callers can customise or save further.
    """
    radius = 0.05       # neuron radius, and half the side of a ReLU square
    relu_offset = 0.5   # horizontal gap from a hidden neuron to its ReLU square
    arrow_style = {
        "head_width": 0.02,
        "head_length": 0.02,
        "fc": "black",
        "ec": "black",
        # Keep the arrow tip on the target's boundary instead of overshooting.
        "length_includes_head": True,
    }

    fig, ax = plt.subplots(figsize=figsize)

    # Neuron centres, one list per layer.
    positions = [
        [(2 * i, 1 - (2 * j + 1) / (2 * n)) for j in range(n)]
        for i, n in enumerate(layers)
    ]
    last = len(positions) - 1

    # Nodes and ReLU squares; zorder > arrows so edges pass underneath them.
    for i, layer in enumerate(positions):
        hidden = 0 < i < last
        for x, y in layer:
            ax.add_patch(plt.Circle((x, y), radius, fill=True, color="blue", zorder=3))
            if hidden:
                ax.add_patch(
                    plt.Rectangle(
                        (x + relu_offset - radius, y - radius),
                        2 * radius,
                        2 * radius,
                        fill=True,
                        color="blue",
                        zorder=3,
                    )
                )

    # Edges between consecutive layers. The source index i never equals
    # `last`, so `i > 0` is exactly "source is a hidden layer".
    for i, (src_layer, dst_layer) in enumerate(zip(positions, positions[1:])):
        for x1, y1 in src_layer:
            start = (x1, y1)
            if i > 0:
                # Route through this neuron's ReLU square: one arrow in (drawn
                # once, not once per target), then fan out from the square.
                box = (x1 + relu_offset, y1)
                _draw_edge(ax, start, box, radius, arrow_style)
                start = box
            for target in dst_layer:
                _draw_edge(ax, start, target, radius, arrow_style)

    # Pad by one neuron diameter so edge nodes are not clipped; clamp so an
    # empty `layers` does not produce an inverted x-range.
    pad = 2 * radius
    ax.set_xlim(-pad, max(len(layers) - 1, 0) * 2 + pad)
    ax.set_ylim(-pad, 1 + pad)
    ax.set_aspect("equal")
    ax.axis("off")
    fig.tight_layout()

    if save_path is not None:
        fig.savefig(save_path, dpi=300, transparent=True)

    if show:
        plt.show()

    return fig, ax

if __name__ == "__main__":
    # Example: a 5-layer MLP with 3 input neurons, hidden layers of 6, 6 and
    # 4 neurons, and 2 output neurons.
    draw_mlp([3, 6, 6, 4, 2], figsize=(12, 8))