import numpy as np
import matplotlib.pyplot as plt

def get_positional_encoding(max_seq_len, d_model, base=10000.0):
    """Build the sinusoidal positional-encoding matrix ("Attention Is All You Need", 3.5).

    Column ``2k`` holds ``sin(pos / base**(2k / d_model))``. Column ``2k + 1``
    holds ``cos`` of the *same* angle, so every sin/cos pair shares one
    frequency.

    Args:
        max_seq_len: Number of positions (rows of the result).
        d_model: Encoding dimension (columns of the result). Odd values are
            allowed; the last column then holds only a sine term.
        base: Wavelength base of the geometric frequency progression
            (10000 in the original paper).

    Returns:
        A float64 ``np.ndarray`` of shape ``(max_seq_len, d_model)``.
    """
    positions = np.arange(max_seq_len, dtype=np.float64)[:, np.newaxis]
    # The even column indices 0, 2, 4, ... already ARE the paper's "2i", so they
    # are used directly as the exponent numerator rather than being doubled.
    even_dims = np.arange(0, d_model, 2, dtype=np.float64)
    # Broadcast to shape (max_seq_len, ceil(d_model / 2)): one angle per sin/cos pair.
    angles = positions / np.power(base, even_dims / d_model)

    positional_encoding = np.zeros((max_seq_len, d_model))
    positional_encoding[:, 0::2] = np.sin(angles)
    # Cosine reuses its sine partner's angle; an odd d_model has one fewer odd
    # column than even column, hence the slice to d_model // 2.
    positional_encoding[:, 1::2] = np.cos(angles[:, : d_model // 2])
    return positional_encoding

# Demo: build the encoding for 100 positions x 512 dimensions and visualize it.
max_seq_len = 100
d_model = 512
positional_encoding = get_positional_encoding(max_seq_len, d_model)

# Set to False to only display the figure without writing it to disk.
SAVE_FIGURE = True

plt.figure(figsize=(15, 5))
# Rows of the matrix are positions (y axis); columns are encoding dimensions (x axis).
plt.pcolormesh(positional_encoding, cmap='viridis')
plt.xlabel('Depth')
plt.xlim((0, d_model))
plt.ylabel('Position')
plt.ylim((0, max_seq_len))
plt.colorbar()
plt.title('Positional Encoding')

if SAVE_FIGURE:
    # Save before show(): once show() returns, the figure may already be closed.
    plt.savefig("positional-encodings.png", dpi=300)

plt.show()