Problem with BERT predicting masked tokens

I have implemented a BERT-style model for masked token prediction using the x-transformers library. I am getting results where it also changes tokens that are unmasked, and I don't know why this is happening. Can someone point out the problem for me?

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from x_transformers import TransformerWrapper, Encoder
import torch.nn.functional as F
from collections import Counter
import json
import re
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam

class Tokenizer:
    def __init__(self):
        self.vocab = None
        self.rev_vocab = None
        self.special_tokens = ["<start>", "<end>", "<pad>", "<unk>", "<mask>"]

    def build_vocab(self, sequences):
        vocab_count = Counter()
        for sequence in sequences:
            vocab_count.update(sequence.split())
        for token in self.special_tokens:
            vocab_count[token] = 1
        self.vocab = {word: idx for idx, (word, _) in enumerate(vocab_count.most_common())}
        self.rev_vocab = {idx: word for word, idx in self.vocab.items()}

    def save(self, save_path):
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(self.vocab, f)

    def load(self, load_path):
        with open(load_path, 'r', encoding='utf-8') as f:
            self.vocab = json.load(f)
        self.rev_vocab = {int(idx): word for word, idx in self.vocab.items()}

    def encode(self, sequence):
        return [self.vocab.get(word, self.vocab.get("<unk>")) for word in sequence.split()]

    def decode(self, tokens):
        return ' '.join([self.rev_vocab.get(idx, "<unk>") for idx in tokens])

def read_and_preprocess_sequences(fname, tokenizer):
    with open(fname, 'r', encoding='utf-8') as f:
        sequences = []
        for line in f:
            sequence = re.sub(r'\n+', ' ', line.strip())
            sequence = re.sub(r'([+-])(\w)', r'\1 \2', sequence)
            sequences.append(sequence)
    tokenizer.build_vocab(sequences)
    return sequences
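
# Quick illustration of the sign-splitting regex above (example input of my own,
# not taken from the data file): a "+" or "-" glued to the following character
# gets a space inserted after it.
print(re.sub(r'([+-])(\w)', r'\1 \2', '+2 -o'))  # prints: + 2 - o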

tokenizer = Tokenizer()
fname = '/scratch/harsha.vasamsetti/testing/sampled_2_train_sample.sli'  # Update this path
sequences = read_and_preprocess_sequences(fname, tokenizer)
tokenizer.save('tokenizer.json')
tokenizer.load('tokenizer.json')
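
# Sanity check added for clarity (not load-bearing): encoding then decoding a
# training sequence should reproduce its whitespace-normalised form, since every
# word of the training data is in the vocabulary.
assert tokenizer.decode(tokenizer.encode(sequences[0])) == ' '.join(sequences[0].split())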

sequence_list = [tokenizer.encode(sequence) for sequence in sequences]

max_seq_length = max(len(sequence) for sequence in sequence_list)
block_size = max_seq_length

data_padded = pad_sequence([torch.tensor(sequence, dtype=torch.long) for sequence in sequence_list], batch_first=True, padding_value=tokenizer.vocab['<pad>'])

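# NOTE: the definition of SequenceDataset was omitted from the snippet above.
# This is a minimal stand-in consistent with how the training loop unpacks
# (input, target) pairs: it returns each padded sequence as both input and
# target. The real dataset class may differ.
class SequenceDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data[idx]
        return sequence, sequence
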
n = int(0.9 * len(data_padded))
train_data, val_data = data_padded[:n], data_padded[n:]
train_dataset = SequenceDataset(train_data)
val_dataset = SequenceDataset(val_data)

class MoleculeModel(nn.Module):
    def __init__(self, vocab_size, max_seq_len=100):
        super(MoleculeModel, self).__init__()
        self.max_seq_len = max_seq_len
        self.transformer = TransformerWrapper(
            num_tokens=vocab_size,
            max_seq_len=self.max_seq_len,
            attn_layers=Encoder(
                dim=256,
                depth=4,
                heads=4,
                cross_attend=False,
            )
        )

    def forward(self, x, mask=None):
        return self.transformer(x, mask=mask)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MoleculeModel(vocab_size=len(tokenizer.vocab), max_seq_len=block_size + 1).to(device)
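
# Quick shape check (illustrative addition, not part of the training run):
# TransformerWrapper returns logits of shape (batch, seq_len, vocab_size),
# i.e. a prediction for every position in the sequence, masked or not.
with torch.no_grad():
    dummy = torch.randint(0, len(tokenizer.vocab), (2, block_size), device=device)
    assert model(dummy).shape == (2, block_size, len(tokenizer.vocab))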

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)

optimizer = Adam(model.parameters(), lr=3e-4)
epochs = 10

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch_idx, (input, target) in enumerate(train_loader):
        input, target = input.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(input)
        loss = F.cross_entropy(output.transpose(1, 2), target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        if batch_idx % 100 == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}")

    print(f"Epoch {epoch}, Average Loss: {total_loss / len(train_loader)}")

def mask_tokens(tokenizer, sequence, mask_token="<mask>", mask_prob=0.2):
    """
    Randomly masks tokens in a sequence with a given probability.
    Returns the masked sequence and a mask indicating the positions of the masked tokens.
    """
    tokens = tokenizer.encode(sequence)
    mask_index = tokenizer.vocab.get(mask_token, tokenizer.vocab["<unk>"])  # Fallback to "<unk>" if "<mask>" not found
    output_tokens = tokens[:]
    output_labels = [-100] * len(tokens)  # Using -100 to ignore indices during loss computation

    for i in range(len(tokens)):
        if torch.rand(1) < mask_prob:
            output_tokens[i] = mask_index
            output_labels[i] = tokens[i]  # Original token is used as the label for prediction

    return output_tokens, output_labels
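
# Hypothetical helper (added for illustration only; not called anywhere in this
# script): shows how the -100 labels produced above pair with cross_entropy's
# ignore_index, so that only masked positions would contribute to the loss.
def masked_lm_loss(model, output_tokens, output_labels, device='cpu'):
    """Illustrative sketch: masked-LM loss computed over masked positions only."""
    logits = model(torch.tensor([output_tokens], dtype=torch.long, device=device))
    labels = torch.tensor([output_labels], dtype=torch.long, device=device)
    return F.cross_entropy(logits.transpose(1, 2), labels, ignore_index=-100)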

def predict_masked_tokens(model, tokenizer, tokens, device='cpu'):
    """
    Predicts masked tokens in a given sequence of token IDs.
    """
    input_seq = torch.tensor([tokens], dtype=torch.long).to(device)
    
    model.eval()  # ensure deterministic inference (no dropout)
    with torch.no_grad():
        predictions = model(input_seq)
        predicted_indices = predictions.argmax(dim=-1).squeeze().tolist()

    predicted_tokens = [tokenizer.rev_vocab.get(idx, "<unk>") for idx in predicted_indices]
    return predicted_tokens

def decode_tokens(tokenizer, token_ids):
    """
    Decodes a list of token IDs back into their string representations.
    """
    return [tokenizer.rev_vocab.get(id, "<unk>") for id in token_ids]

def display_masked_and_predicted(tokenizer, masked_tokens, predicted_tokens):
    """
    Displays the masked tokens and their predicted replacements side by side.
    """
    # decode_tokens already maps the mask ID back to the "<mask>" string, so the
    # decoded sequence can be displayed directly (comparing the decoded strings
    # against an integer mask index would never match)
    masked_sequence_display = decode_tokens(tokenizer, masked_tokens)

    print("Masked Sequence: ", masked_sequence_display)
    print("Predicted Tokens:", predicted_tokens)

# Example usage of mask_tokens
sequence = "Tm Ni Tm Ni Ni Ni 0 5 - - o 3 0 o o + 1 4 o o - 1 3 o o o 2 3 o o o 2 5 - o o 5 4 o + o 0 3 - o - 2 5 o o - 2 1 o o + 2 3 o o o 2 0 + o o 4 0 o o + 4 0 + o + 0 2 o - o 0 1 o o o 5 2 + o o 0 5 o - - 1 0 o + o 4 0 o + + 5 1 o + o 2 5 o o o 5 2 + o o 3 1 o o o 3 4 o - o 3 0 + o o 1 0 + o o 4 0 + + o 3 2 o o o 1 2 o o - 4 1 o + o 0 5 o - - 4 2 + o o 2 3 - + o 2 4 o - o 3 1 o o o 1 2 o o o 5 4 o o o 4 3 + o o 5 0 + + +"
masked_tokens, output_labels = mask_tokens(tokenizer, sequence)

# Example usage of predict_masked_tokens
# Ensure you provide the sequence with actual mask tokens represented by their IDs
predicted_tokens = predict_masked_tokens(model, tokenizer, masked_tokens, device=device)

display_masked_and_predicted(tokenizer, masked_tokens, predicted_tokens)

output:

Masked Sequence:  ['Tm', 'Ni', 'Tm', 'Ni', '<mask>', 'Ni', '<mask>', '5', '-', '-', 'o', '3', '0', 'o', 'o', '+', '<mask>', '4', 'o', 'o', '-', '1', '3', '<mask>', 'o', 'o', '2', '3', 'o', 'o', 'o', '2', '5', '-', 'o', 'o', '5', '4', '<mask>', '+', 'o', '0', '3', '-', 'o', '-', '2', '5', 'o', 'o', '<mask>', '<mask>', '1', '<mask>', 'o', '+', '2', '3', 'o', 'o', 'o', '2', '0', '<mask>', 'o', 'o', '4', '0', '<mask>', 'o', '+', '4', '0', '+', 'o', '+', '<mask>', '2', 'o', '-', 'o', '0', '1', 'o', 'o', 'o', '5', '2', '+', 'o', 'o', '0', '5', 'o', '<mask>', '-', '1', '0', 'o', '+', '<mask>', '<mask>', '0', '<mask>', '<mask>', '+', '<mask>', '1', 'o', '<mask>', 'o', '2', '5', '<mask>', '<mask>', 'o', '5', '<mask>', '+', 'o', 'o', '3', '<mask>', 'o', '<mask>', 'o', '3', '<mask>', '<mask>', '-', 'o', '<mask>', '0', '+', 'o', '<mask>', '<mask>', '0', '+', 'o', 'o', '4', '0', '<mask>', '+', 'o', '3', '2', 'o', 'o', '<mask>', '<mask>', '2', 'o', 'o', '-', '4', '1', 'o', '+', 'o', '<mask>', '5', 'o', '-', '-', '4', '2', '+', 'o', 'o', '2', '3', '-', '+', 'o', '2', '4', 'o', '-', 'o', '3', '<mask>', 'o', 'o', '<mask>', '1', '2', 'o', '<mask>', 'o', '5', '4', 'o', 'o', '<mask>', '4', '<mask>', '+', 'o', 'o', '5', '<mask>', '<mask>', '<mask>', '+']
Predicted Tokens: ['5', '13', '-', '13', '9', '3', '5', '-', '-', 'o', '3', '0', 'o', 'o', '+', 'o', '2', 'o', 'o', '-', '1', '3', 'o', 'o', 'o', '2', '3', 'o', 'o', 'o', '2', '5', '-', 'o', 'o', '-', '4', 'o', 'o', 'o', '0', 'o', '-', 'o', '+', '3', '5', 'o', 'o', 'o', '0', '1', '+', '+', 'o', '4', '3', 'o', 'o', '+', '2', '0', '+', '+', 'o', '4', '0', 'o', '9', 'o', '4', '0', '+', 'o', 'o', '4', '2', 'o', '+', '0', '0', '+', '+', 'o', 'o', '3', 'o', '+', 'o', 'o', '0', '+', 'o', 'o', '+', '1', '0', 'o', '+', 'o', '9', '5', '+', '9', '5', '0', '3', 'o', 'o', '9', '2', 'o', 'o', 'o', 'o', '0', '0', '-', 'o', 'o', '3', 'o', 'o', 'o', '9', '3', '0', 'o', '9', 'o', '4', '1', '+', 'o', '+', '9', '0', 'o', 'o', 'o', '4', '0', 'o', '+', '+', '3', '2', 'o', 'o', 'o', '5', 'o', 'o', 'o', '+', '0', '0', 'o', '+', 'o', '0', '5', '+', 'o', 'o', '4', '2', '+', 'o', 'o', '5', '-', '-', '-', '+', '2', 'o', 'o', 'o', 'o', '3', '0', '5', '+', '+', '5', '2', 'o', 'o', '9', 'o', 'o', 'o', 'o', '3', '5', 'o', '1', 'o', 'o', '2', '-', 'o', '5', '+', '4']