import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from transformers import AutoTokenizer, AutoModel
import torch

# Load the pretrained tokenizer for the checkpoint and encode the example sentence.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
text = "Attention-based architectures are currently dominant."
# Calling the tokenizer directly adds BERT's special tokens ([CLS] ... [SEP]).
# BERT is pre-trained on inputs framed this way; feeding bare sub-word tokens
# without them yields attention patterns that are not representative.
token_ids = tokenizer(text)["input_ids"]
# Derive the display labels from the ids so they stay aligned one-to-one with
# the model inputs (special tokens included).
tokens = tokenizer.convert_ids_to_tokens(token_ids)

import torch.nn.functional as F

# Load the pretrained encoder for the same checkpoint as the tokenizer.
model = AutoModel.from_pretrained("bert-base-uncased")
# Inference only: make sure dropout is disabled (explicit even though
# Hugging Face documents from_pretrained as returning the model in eval mode).
model.eval()

# Which attention map to visualise (both 0-based).
LAYER_INDEX = 0
HEAD_INDEX = 0

# Wrap the single sequence in a batch of one: shape (1, seq_len).
input_ids = torch.tensor([token_ids])

# Forward pass without building an autograd graph, asking the model to also
# return its per-layer attention probabilities.
with torch.no_grad():
    outputs = model(input_ids, output_attentions=True)
# Per the Hugging Face docs, `attentions` is a tuple with one tensor per layer,
# each shaped (batch, num_heads, seq_len, seq_len).
attention = outputs.attentions

# Select layer LAYER_INDEX, batch item 0, head HEAD_INDEX -> (seq_len, seq_len)
# array; rows are query positions, columns are key positions. `.cpu()` keeps
# `.numpy()` working if the model is ever moved to an accelerator.
attention_weights = attention[LAYER_INDEX][0, HEAD_INDEX].cpu().numpy()

# Plot the selected attention map as a heatmap: rows are query tokens,
# columns are key tokens.
SAVE_FIGURE = True
FIGURE_PATH = "single-head-attention-heatmap.png"

plt.figure(figsize=(10, 8))
ax = sns.heatmap(
    attention_weights,
    xticklabels=tokens,
    yticklabels=tokens,
    cmap="viridis",
    cbar=True,
    square=True,
    annot=False,  # cell values are not printed, so no `fmt` is needed
)

# Larger tick labels for readability in print.
ax.tick_params(axis="both", which="major", labelsize=18)
# Rotated labels are right-aligned and anchored at their tick so each label
# ends under its own column instead of drifting toward the next one.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
plt.setp(ax.get_yticklabels(), rotation=0)
plt.tight_layout()

if SAVE_FIGURE:
    plt.savefig(FIGURE_PATH, dpi=300)

plt.show()
