Cross Entropy Loss giving nan only on Validation Data

Hi @ptrblck ,
So I am using the segmentation_models_pytorch library for a multi-class classification task where each pixel gets a prediction for the population living in it, based on an input that consists of an RGB image and corresponding height values.
I am trying to use cross-entropy loss for this task.
This is the model I use:

MULTICLASS_MODE: str = "multiclass"
ENCODER = 'efficientnet-b2'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = None
DEVICE = 'cuda'
CLASSES = [0, 1, 2, 3, 4, 5, 6, 7]


model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    encoder_depth=5,
    in_channels=4,
    classes=8,
    activation=ACTIVATION,
    decoder_use_batchnorm=True,
)

My input has shape [6, 8, 512, 512] and my target has shape [6, 1, 512, 512], containing the class index. I squeeze the target at dimension 1 so it fits the requirements from the documentation, with the resulting target shape being [6, 512, 512].
During training the loss seems to decrease, but during validation it prints out nan as the loss value from the first iteration on.
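For reference, here is a minimal sketch of that shape handling with random stand-in tensors (the shapes match my setup, the values don't):

import torch
import torch.nn.functional as F

# Random stand-ins with the shapes described above.
y_pred = torch.randn(6, 8, 512, 512)             # [N, C, H, W] raw logits
y_true = torch.randint(0, 8, (6, 1, 512, 512))   # [N, 1, H, W] class indices

y_true = y_true.squeeze(1)                       # -> [N, H, W], as the docs require
print(F.cross_entropy(y_pred, y_true))           # finite scalar loss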

max_score = 50
for i in range(0, 750):
    print(optimizer.param_groups[0]['lr'])
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    scheduler.step(valid_logs['cross_entropy'])

    if max_score > valid_logs['cross_entropy']:
        max_score = valid_logs['cross_entropy']
        torch.save(model, './best_model.pth')
        print('Model saved!')

This is what the train and validation epochs look like:


import sys
import torch
from tqdm import tqdm as tqdm
from .meter import AverageValueMeter


class Epoch:

    def __init__(self, model, loss, metrics, stage_name, device='cpu', verbose=True):
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.stage_name = stage_name
        self.verbose = verbose
        self.device = device

        self._to_device()

    def _to_device(self):
        self.model.to(self.device)
        self.loss.to(self.device)
        for metric in self.metrics:
            metric.to(self.device)

    def _format_logs(self, logs):
        str_logs = ['{} - {:.4}'.format(k, v) for k, v in logs.items()]
        s = ', '.join(str_logs)
        return s

    def batch_update(self, x, y):
        raise NotImplementedError

    def on_epoch_start(self):
        pass

    def run(self, dataloader):

        self.on_epoch_start()

        logs = {}
        loss_meter = AverageValueMeter()
        metrics_meters = {metric.__name__: AverageValueMeter() for metric in self.metrics}

        with tqdm(dataloader, desc=self.stage_name, file=sys.stdout, disable=not (self.verbose)) as iterator:
            for x, y, _, _ in iterator:  # the loader yields two extra items that aren't used here
                x, y = x.to(self.device), y.to(self.device)
                loss, y_pred = self.batch_update(x, y)

                # update loss logs
                loss_value = loss.cpu().detach().numpy()
                loss_meter.add(loss_value)
                loss_logs = {self.loss.__name__: loss_meter.mean}
                logs.update(loss_logs)

                # update metrics logs
                for metric_fn in self.metrics:
                    metric_value = metric_fn(y_pred, y).cpu().detach().numpy()
                    metrics_meters[metric_fn.__name__].add(metric_value)
                metrics_logs = {k: v.mean for k, v in metrics_meters.items()}
                logs.update(metrics_logs)

                if self.verbose:
                    s = self._format_logs(logs)
                    iterator.set_postfix_str(s)

        return logs


class TrainEpoch(Epoch):

    def __init__(self, model, loss, metrics, optimizer, device='cpu', verbose=True):
        super().__init__(
            model=model,
            loss=loss,
            metrics=metrics,
            stage_name='train',
            device=device,
            verbose=verbose,
        )
        self.optimizer = optimizer

    def on_epoch_start(self):
        self.model.train()

    def batch_update(self, x, y):
        self.optimizer.zero_grad()
        prediction = self.model(x)
        loss = self.loss(prediction, y)
        loss.backward()
        self.optimizer.step()
        return loss, prediction


class ValidEpoch(Epoch):

    def __init__(self, model, loss, metrics, device='cpu', verbose=True):
        super().__init__(
            model=model,
            loss=loss,
            metrics=metrics,
            stage_name='valid',
            device=device,
            verbose=verbose,
        )

    def on_epoch_start(self):
        self.model.eval()

    def batch_update(self, x, y):
        with torch.no_grad():
            prediction = self.model(x)
            loss = self.loss(prediction, y)
        return loss, prediction
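For completeness, the epoch runners and scheduler used above are constructed roughly like this; the exact metrics and hyperparameters are assumptions, since that part of my script isn't shown:

loss = CrossEntropyLoss()   # the custom loss shown below
metrics = []                # e.g. smp.utils.metrics.IoU() instances could go here
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')

train_epoch = TrainEpoch(model, loss=loss, metrics=metrics,
                         optimizer=optimizer, device=DEVICE, verbose=True)
valid_epoch = ValidEpoch(model, loss=loss, metrics=metrics,
                         device=DEVICE, verbose=True)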

And this is what my loss looks like:

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn import _reduction as _Reduction
from typing import Optional

__all__ = ['CrossEntropyLoss']


class _Loss(nn.Module):
    reduction: str

    def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:
        super(_Loss, self).__init__()
        if size_average is not None or reduce is not None:
            self.reduction: str = _Reduction.legacy_get_string(
                size_average, reduce)
        else:
            self.reduction = reduction


class _WeightedLoss(_Loss):
    def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
        super(_WeightedLoss, self).__init__(size_average, reduce, reduction)
        self.register_buffer('weight', weight)
        self.weight: Optional[Tensor]


class CrossEntropyLoss(_WeightedLoss):
    __constants__ = ['ignore_index', 'reduction', 'label_smoothing']
    ignore_index: int
    label_smoothing: float

    def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100,
                 reduce=None, reduction: str = 'mean', label_smoothing: float = 0.0) -> None:
        super(CrossEntropyLoss, self).__init__(
            weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        # Epoch.run logs the loss under self.loss.__name__ and the training
        # loop reads valid_logs['cross_entropy'], so expose a matching name.
        self.__name__ = 'cross_entropy'

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> Tensor:
        # y_pred: [N, C, H, W] raw logits; y_true: [N, 1, H, W] class indices.
        y_true = y_true.squeeze(1)  # -> [N, H, W], as F.cross_entropy expects

        loss = F.cross_entropy(y_pred, y_true, weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction,
                               label_smoothing=self.label_smoothing)
        return loss

Please let me know if you have any idea why my val_loss is nan; I also tested other losses and they work just fine.

Hi jhans!

Could you find a specific batch, or better still, a specific sample, from your validation dataset that produces a cross-entropy loss of nan?

Then print out and/or store the prediction and y for which the loss is nan, and use these two tensors to reproduce the issue. Note that you won't need your model, dataloader, or optimizer to reproduce the issue; just these two stored tensors should suffice.
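For example, something along these lines, dropped into your ValidEpoch.batch_update, would capture the first offending batch (the file name is just an illustration):

    def batch_update(self, x, y):
        with torch.no_grad():
            prediction = self.model(x)
            loss = self.loss(prediction, y)
            # Save the first batch that yields a nan loss so it can be
            # replayed later without the model, dataloader, or optimizer.
            if torch.isnan(loss):
                torch.save({'prediction': prediction.cpu(), 'y': y.cpu()},
                           'nan_batch.pt')
        return loss, prediction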

Please then post a fully self-contained, runnable script that reproduces the issue, together with the output you get. (Just hard-code the prediction and y tensors into the script.)

That will be the best way to start tracking down a possible cause.

Best.

K. Frank

Hi KFrank, thanks for your time, I got it solved. It was actually the ignore_index that caused the problem.
I had set it to 0, which was producing the nan values, and changing that fixed the loss.
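It makes sense in hindsight: with ignore_index=0, a batch in which every pixel is labelled 0 has no valid elements, so the 'mean' reduction divides by a zero weight sum and yields nan; presumably some of my validation tiles contain only class 0. A minimal sketch of the effect with random stand-in tensors:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 8, 4, 4)                  # [N, C, H, W] random logits
target = torch.zeros(2, 4, 4, dtype=torch.long)   # every pixel is class 0

# With ignore_index=0 every element is ignored, so the mean is 0/0 -> nan.
print(F.cross_entropy(logits, target, ignore_index=0))     # tensor(nan)
print(F.cross_entropy(logits, target, ignore_index=-100))  # finite loss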
Best
Johannes