NaN values in Output when training randomly initialized transformer

Hello everyone, I am new to Pytorch and definitely not good, but I have to do this for class and am stuck at this problem. This is my code I am using to train a randomly initialized transformer.

import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import numpy as np

train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=16)
validation_dataloader = DataLoader(validation_dataset, sampler=SequentialSampler(validation_dataset), batch_size=16)

class CustomTransformerEncoder(nn.Module):
def init(self, num_tokens, dim_model, num_heads, num_encoder_layers, num_classes):
super(CustomTransformerEncoder, self).init()
self.embedding = nn.Embedding(num_tokens, dim_model)
encoder_layer = TransformerEncoderLayer(d_model=dim_model, nhead=num_heads, batch_first=True)
self.transformer_encoder = TransformerEncoder(encoder_layer=encoder_layer, num_layers=num_encoder_layers)
self.final_layer = nn.Linear(dim_model, num_classes)

def forward(self, src, mask=None):
    embedded_src = self.embedding(src)
    output = self.transformer_encoder(embedded_src, src_key_padding_mask=mask)
    output = self.final_layer(output[:, 0, :])
    return output

def init_weights(m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)

model = CustomTransformerEncoder(
num_tokens=30522,
dim_model=512,
num_heads=8,
num_encoder_layers=6,
num_classes=2,
)

model.apply(init_weights)

device = torch.device(“cuda” if torch.cuda.is_available() else “cpu”)
model.to(device)
optimizer = AdamW(model.parameters(), lr=1e-4, eps=1e-8)
epochs = 1
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

loss_function = nn.CrossEntropyLoss()

def flat_accuracy(preds, labels):
pred_flat = preds.argmax(dim=1).flatten()
labels_flat = labels.flatten()
return (pred_flat == labels_flat).type(torch.float).mean().item()

training_loss_values =
validation_loss_values =
validation_accuracy_values =

for epoch in range(epochs):
print(‘======== Epoch {:} / {:} ========’.format(epoch + 1, epochs))
total_train_loss = 0
model.train()

for step, batch in enumerate(train_dataloader):
    b_input_ids, b_input_mask, b_labels = tuple(t.to(device) for t in batch)

    optimizer.zero_grad()

    outputs = model(b_input_ids, b_input_mask.bool())

    loss = loss_function(outputs, b_labels)

    if torch.isnan(loss):
        print("NaN detected in loss.")
        break

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()

    total_train_loss += loss.item()

avg_train_loss = total_train_loss / len(train_dataloader)
print("  Average training loss: {0:.2f}".format(avg_train_loss))
training_loss_values.append(avg_train_loss)

print("Running Validation...")
model.eval()
total_eval_accuracy = 0
total_eval_loss = 0

for batch in validation_dataloader:
    b_input_ids, b_input_mask, b_labels = tuple(t.to(device) for t in batch)

    with torch.no_grad():
        outputs = model(b_input_ids, b_input_mask.bool())

        loss = loss_function(outputs, b_labels)
    total_eval_loss += loss.item()
    total_eval_accuracy += flat_accuracy(outputs, b_labels)

avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
print("  Validation Accuracy: {0:.2f}".format(avg_val_accuracy))
validation_accuracy_values.append(avg_val_accuracy)

avg_val_loss = total_eval_loss / len(validation_dataloader)
print("  Validation loss: {0:.2f}".format(avg_val_loss))
validation_loss_values.append(avg_val_loss)

and this is the output i get

======== Epoch 1 / 1 ========
NaN detected in loss.
Average training loss: 0.00
Running Validation…
Validation Accuracy: 0.50
Validation loss: nan

how can I handle these NaN values?
Thanks in advance!