Loss function gives NaN

My loss function gives NaN. What can be the cause of this?

from sklearn import metrics
from sklearn.metrics import f1_score

# --- Run configuration ---------------------------------------------------
best_optimizer = 'RMSprop'  # optimizer class name, resolved on `optim` via getattr
BATCHSIZE = 128             # batch size shared by the train and test loaders
epochs = 30                 # number of passes over the training loader


# Run on the GPU when CUDA is available, otherwise stay on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Enable cuDNN's benchmark mode.
torch.backends.cudnn.benchmark = True


# --- Class weights -------------------------------------------------------
# Only the labels are needed here, so the dataset is opened without transforms.
train_dataset = ImageFolder(train_dir)
targets = train_dataset.targets

class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(targets),
    y=targets,
)
# Convert the NumPy weights to a float32 tensor on the training device so
# they can be handed to the loss function.
class_weights_tensor = torch.from_numpy(class_weights).to(torch.float32).to(DEVICE)


def define_model():
  """Build the classification network.

  Returns:
    The model created by ``timm.create_model`` for the
    ``tf_efficientnetv2_b0.in1k`` architecture, configured for 3-channel
    input and 12 output classes. No pretrained weights are loaded.
  """
  return timm.create_model(
      'tf_efficientnetv2_b0.in1k',
      pretrained=False,
      in_chans=3,
      num_classes=12,
  )

def get_dataset():
  """Create the training and test data loaders.

  Both datasets are read with ``ImageFolder`` from ``train_dir`` and
  ``test_dir``. Training images are resized, randomly augmented (flip,
  rotation, affine), converted to tensors and normalised; test images are
  only resized, converted and normalised.

  Returns:
    A ``(train_loader, test_loader)`` tuple. The training loader shuffles,
    the test loader does not; both use ``BATCHSIZE`` and two workers.
  """
  import torchvision.transforms as transforms
  from torchvision.datasets import ImageFolder

  image_size = (224, 224)
  channel_mean = [0.485, 0.456, 0.406]
  channel_std = [0.229, 0.224, 0.225]

  # Augmentations are applied to the image before tensor conversion.
  train_pipeline = transforms.Compose([
      transforms.Resize(image_size),
      transforms.RandomHorizontalFlip(),
      transforms.RandomRotation(20),
      transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=(0.8, 1.2), shear=0.2),
      transforms.ToTensor(),
      transforms.Normalize(mean=channel_mean, std=channel_std),
  ])

  # Deterministic preprocessing only, so evaluation is repeatable.
  eval_pipeline = transforms.Compose([
      transforms.Resize(image_size),
      transforms.ToTensor(),
      transforms.Normalize(mean=channel_mean, std=channel_std),
  ])

  def make_loader(dataset, shuffle):
    # Both loaders differ only in the dataset and whether they shuffle.
    return torch.utils.data.DataLoader(
        dataset, batch_size=BATCHSIZE, shuffle=shuffle, num_workers=2)

  train_images = ImageFolder(train_dir, transform=train_pipeline)
  test_images = ImageFolder(test_dir, transform=eval_pipeline)

  return make_loader(train_images, True), make_loader(test_images, False)


# Build the network and move it to the selected device.
model = define_model().to(DEVICE)

# Look up the optimizer class by name and construct it with its default
# hyperparameters (only the model parameters are passed).
optimizer_name = best_optimizer
optimizer_class = getattr(optim, optimizer_name)
optimizer = optimizer_class(model.parameters())

# Cross-entropy loss weighted per class with the precomputed class weights.
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)

# Data loaders for the training and test image folders.
train_loader, test_loader = get_dataset()

from torch.cuda.amp import autocast, GradScaler

# Gradient scaler used by the mixed-precision training steps.
scaler = GradScaler()

# Train for `epochs` epochs; after each epoch evaluate on the test loader and
# print loss, accuracy and macro F1.
for epoch in range(epochs):
    # Validation below puts the model into eval mode, so training mode has to
    # be re-enabled at the start of every epoch. Without this, every epoch
    # after the first would train with batchnorm using its running statistics
    # and with dropout disabled.
    model.train()

    running_loss = 0.0
    correct_train = 0
    total_train = 0

    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()

        # Forward pass and loss under autocast (mixed precision).
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        # Scale the loss before backward, then let the scaler drive the
        # optimizer step and update its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

        # Compute the train accuracy
        predicted = outputs.argmax(dim=1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    # Compute the train loss (mean over batches) and accuracy (percent)
    train_loss = running_loss / len(train_loader)
    train_accuracy = 100.0 * correct_train / total_train

    # Evaluate the model on the validation set
    model.eval()
    correct_valid = 0
    total_valid = 0
    y_true, y_pred = [], []
    valid_loss = 0.0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)

            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            valid_loss += loss.item()

            predicted = outputs.argmax(dim=1)

            total_valid += labels.size(0)
            correct_valid += (predicted == labels).sum().item()

            # Collect labels and predictions on the CPU for the F1 computation.
            y_true += labels.cpu().tolist()
            y_pred += predicted.cpu().tolist()

    # Compute the validation loss (mean over batches) and accuracy (percent)
    valid_loss /= len(test_loader)
    valid_accuracy = 100.0 * correct_valid / total_valid

    # Calculate the macro F1 score. The result gets its own name so that it
    # does not shadow the imported `f1_score` function.
    macro_f1 = metrics.f1_score(y_true, y_pred, average='macro')

    # Print the results for this epoch
    print(f"Epoch {epoch}/{epochs} - "
          f"Train Loss: {train_loss:.4f}, Train_Val Accuracy: {train_accuracy:.2f}% - "
          f"Test Loss: {valid_loss:.4f}, Test Accuracy: {valid_accuracy:.2f}%, Macro F1 score: {macro_f1:.4f}")
Epoch 0/30 - Train Loss: 4.1348, Train_Val Accuracy: 10.13% - Test Loss: 5.6673, Test Accuracy: 8.05%, Macro F1 score: 0.0348
Epoch 1/30 - Train Loss: nan, Train_Val Accuracy: 8.37% - Test Loss: nan, Test Accuracy: 8.61%, Macro F1 score: 0.0132

Check where exactly the NaN values are created, e.g. by registering forward hooks that print information about the intermediate outputs.
For example, during evaluation a batchnorm layer could produce these invalid values if it had already received an invalid training batch containing NaNs or Infs. In that case the running stats are updated with these invalid values and thus cause NaN outputs during evaluation.

Thank you for your answer. I figured out that the problem is caused by the class weights being applied. Maybe I will figure out later exactly what the cause was, but because I am in a hurry I removed the class weights.