CrossEntropyLoss only works with an all-zero (background-only) mask

I am training a segmentation model (DeepLabV3+), created with the PyTorch Segmentation Models library, to identify objects in RGB images. I have three classes (plus background), which are represented in the masks as red, green, blue, and black (for the background). The RGB masks are converted into a 2D tensor of class indices using a colour mapping (thanks to an answer given by @ptrblck!). I use a cross-entropy loss function.

This is my custom dataset class:

class SegmentationDataset(VisionDataset):
    """Segmentation dataset pairing image files with colour-coded mask files.

    ``__getitem__`` returns ``{"image": ..., "mask": ...}``. In "rgb" mask mode
    the mask is a 2D ``torch.long`` tensor of class indices built from
    ``self.mapping``.

    NOTE(review): this listing is truncated in several places and will not run
    exactly as shown:
    - the mapping dict and the ``raise ValueError(`` calls are never closed;
    - ``image =`` / ``mask =`` have no right-hand side;
    - some ``else:`` branches appear to be missing.
    """
    def __init__(self,
                 root: str,
                 image_folder: str,
                 mask_folder: str,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 seed: Optional[int] = None,
                 fraction: Optional[float] = None,
                 subset: Optional[str] = None,
                 image_color_mode: str = "rgb",
                 mask_color_mode: str = "rgb") -> None:
        """Validate the folders and colour modes, then build the file lists.

        Args:
            root: Base directory containing ``image_folder`` and ``mask_folder``.
            image_folder: Sub-directory of ``root`` holding the input images.
            mask_folder: Sub-directory of ``root`` holding the masks.
            transform: Optional callable applied to the image only.
            target_transform: Stored on the instance; not applied anywhere in
                the visible code.
            seed: When truthy, enables the re-ordering step (currently a no-op,
                see the note below).
            fraction: When truthy, "Train" keeps the first
                ``ceil(n * (1 - fraction))`` files.
            subset: Which split to build, "Train" or "Test".
            image_color_mode: "rgb" or "grayscale".
            mask_color_mode: "rgb" or "grayscale".

        Raises:
            OSError: If the image or mask folder does not exist.
            ValueError: On an invalid colour mode or ``subset`` value.
        """


        # Creating color mapping: RGB tuple (0-255 per channel) -> class index.
        # Pixels whose colour is not listed here are never assigned a class below.
        self.mapping = {
                    (0,0,0): 0,
                    (255,0,0): 1,
                    (0,255,0): 2,
                    (0,0,255): 3

        # Setting transform and target_transform
        self.transform = transform
        self.target_transform = target_transform

        # Creating paths to image and mask folders
        # NOTE(review): self.root is read here but never assigned in the visible
        # __init__; presumably a super().__init__(root, ...) call is missing — confirm.
        image_folder_path = Path(self.root) / image_folder
        mask_folder_path = Path(self.root) / mask_folder

        # Raising errors if paths do not exist
        if not image_folder_path.exists():
            raise OSError(f"{image_folder_path} does not exist.")
        if not mask_folder_path.exists():
            raise OSError(f"{mask_folder_path} does not exist.")

        # Raising errors if selected color mode is not possible
        if image_color_mode not in ["rgb", "grayscale"]:
            raise ValueError(
                f"{image_color_mode} is an invalid choice. Please enter from rgb grayscale."
        if mask_color_mode not in ["rgb", "grayscale"]:
            raise ValueError(
                f"{mask_color_mode} is an invalid choice. Please enter from rgb grayscale."

        # Initiating color modes for image and mask
        self.image_color_mode = image_color_mode
        self.mask_color_mode = mask_color_mode

        if not fraction:
            # No fraction: use every file; images and masks are paired by
            # position in their sorted file lists.
            self.image_names = sorted(image_folder_path.glob("*"))
            self.mask_names = sorted(mask_folder_path.glob("*"))
            # NOTE(review): from here on the code looks like it belongs to a
            # missing ``else:`` (fraction given) branch; as shown it runs only
            # when ``fraction`` is falsy.
            if subset not in ["Train", "Test"]:
                raise (ValueError(
                    f"{subset} is not a valid input. Acceptable values are Train and Test."
            self.fraction = fraction
            self.image_list = np.array(sorted(image_folder_path.glob("*")))
            self.mask_list = np.array(sorted(mask_folder_path.glob("*")))
            if seed:
                # NOTE(review): ``indices`` is np.arange, i.e. the identity order,
                # so ``seed`` currently has no effect. A seeded shuffle of
                # ``indices`` appears to be missing.
                indices = np.arange(len(self.image_list))
                self.image_list = self.image_list[indices]
                self.mask_list = self.mask_list[indices]
            if subset == "Train":
                # Train split: the first ceil(n * (1 - fraction)) files.
                self.image_names = self.image_list[:int(
                    np.ceil(len(self.image_list) * (1 - self.fraction)))]
                self.mask_names = self.mask_list[:int(
                    np.ceil(len(self.mask_list) * (1 - self.fraction)))]
                # NOTE(review): as shown, these lines immediately overwrite the Train
                # split with the remaining files; presumably an ``else:`` (Test
                # subset) line is missing just above them.
                self.image_names = self.image_list[
                    int(np.ceil(len(self.image_list) * (1 - self.fraction))):]
                self.mask_names = self.mask_list[
                    int(np.ceil(len(self.mask_list) * (1 - self.fraction))):]

    def __len__(self) -> int:
        """Return the number of image files in the selected split."""
        return len(self.image_names)

    def __getitem__(self, index: int) -> Any:
        """Load the image/mask pair at ``index`` and return them as a dict."""
        image_path = self.image_names[index]
        mask_path = self.mask_names[index]

        with open(image_path, "rb") as image_file, open(mask_path, "rb") as mask_file:
            # NOTE(review): right-hand side missing; presumably Image.open(image_file) — confirm.
            image =

            if self.image_color_mode == "rgb":
                image = image.convert("RGB")
            elif self.image_color_mode == "grayscale":
                image = image.convert("L")

            # NOTE(review): right-hand side missing; presumably Image.open(mask_file) — confirm.
            mask =
            if self.mask_color_mode == "rgb":
                mask = mask.convert("RGB")

                # Transforming mask to tensor
                # NOTE(review): ToTensor rescales uint8 PIL data to float in
                # [0.0, 1.0] (per the torchvision docs). A red pixel therefore
                # becomes (1.0, 0.0, 0.0), not (255, 0, 0).
                mask = transforms.ToTensor()(mask)

                # Creating empty 2d mask
                # NOTE(review): torch.empty returns uninitialised memory, and the
                # size is hard-coded to 512x512. Any pixel not matched by the loop
                # below keeps an arbitrary value. CrossEntropyLoss then sees it as
                # an out-of-range class index, which surfaces on CUDA as a
                # device-side assert. torch.zeros(mask.shape[1:], dtype=torch.long)
                # (or filling with the loss's ignore_index) avoids the garbage.
                empty_mask = torch.empty(512, 512, dtype=torch.long)

                # Loop through all colours
                # NOTE(review): the keys are 0-255 values, but ``mask`` was rescaled
                # to [0, 1] above. Only (0, 0, 0) can ever compare equal, which is
                # why only background-only masks work. Comparing against the raw
                # uint8 pixels instead fixes this, e.g.
                # torch.from_numpy(np.array(mask)).permute(2, 0, 1).
                for k in self.mapping:
                  # Get all indices for current class
                  # (3,) colour -> (3, 1, 1) so it broadcasts against the (C, H, W) mask
                  idx = (mask==torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
                  validx = (idx.sum(0) == 3)  # Check that all channels match
                  empty_mask[validx] = torch.tensor(self.mapping[k], dtype=torch.long)

            elif self.mask_color_mode == "grayscale":
                # NOTE(review): ``empty_mask`` is never assigned on this path, so
                # building ``sample`` below raises NameError in grayscale mode.
                mask = mask.convert("L")

            sample = {"image": image, "mask": empty_mask}
            # Only the image is transformed; self.target_transform is not used here.
            if self.transform:
                sample["image"] = self.transform(sample["image"])

            return sample

I use the following function to train my model:

def train_model(model, criterion, dataloaders, optimizer, metrics, bpath, num_epochs):
    """Run ``num_epochs`` epochs of Train/Test phases and log results to ``bpath/log.csv``.

    Args:
        model: Network called as ``model(inputs)``.
        criterion: Loss called with the model outputs and ``long`` masks.
        dataloaders: Mapping with "Train" and "Test" loaders yielding dicts with
            "image" and "mask" entries.
        optimizer: Optimizer for ``model`` (its steps are not visible in this listing).
        metrics: Mapping of metric name -> callable(y_true, y_pred, ...).
        bpath: Directory where ``log.csv`` is written.
        num_epochs: Number of epochs, counted from 1.

    Returns:
        The model object.

    NOTE(review): this listing is truncated. The ``criterion(`` call and the
    metric appends are cut off. The ``else:`` before ``model.eval()``, the
    zero_grad/backward/step calls, the CSV header/row writes and the
    best-weight reload are missing.
    """
    # Getting current time
    since = time.time()

    # Making a copy of the model's learnable parameters
    best_model_wts = copy.deepcopy(model.state_dict())

    # Creating variable for best loss (and setting it high)
    best_loss = 1e10

    # Using GPU if available
    # NOTE(review): model.to(device) is not visible here; presumably done by the caller.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Initializing the log file for training and testing loss and metrics
    # Column order: epoch, the two loss columns, then one column per metric and phase.
    fieldnames = ['epoch', 'Train_loss', 'Test_loss'] + \
        [f'Train_{m}' for m in metrics.keys()] + \
        [f'Test_{m}' for m in metrics.keys()]

    # Opening in 'w' mode truncates any previous log; a writer.writeheader()
    # call seems to be missing from this listing.
    with open(os.path.join(bpath, 'log.csv'), 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

    # Looping through epochs
    for epoch in range(1, num_epochs + 1):
        # Printing current epoch (of total epochs)
        print('Epoch {}/{}'.format(epoch, num_epochs))

        # Printing dashed line
        print('-' * 10)

        # Each epoch has a training and validation phase
        # Initializing batch summary
        batchsummary = {a: [0] for a in fieldnames}

        for phase in ['Train', 'Test']:
            # NOTE(review): an ``else:`` line seems to be missing between these two
            # calls; as shown, the model always ends up in eval mode.
            if phase == 'Train':
                model.train()  # Set model to training mode
                model.eval()  # Set model to evaluate mode

            # Iterating over batch
            for sample in tqdm(iter(dataloaders[phase]), desc="Training..."):
                inputs = sample['image'].to(device)
                masks = sample['mask'].to(device)

                # CrossEntropyLoss needs integer class-index targets.
                masks = masks.long()

                # zero the parameter gradients
                # NOTE(review): the optimizer.zero_grad() call is missing from this listing.

                # Tracking history if only in train
                with torch.set_grad_enabled(phase == 'Train'):
                    # Calculate outputs
                    outputs = model(inputs)

                    # DEBUGGING
                    print("SUM of output:", torch.sum(outputs))
                    print("SUM of mask:", torch.sum(masks))   
                    # Set loss
                    # For nn.CrossEntropyLoss, every mask value must be a valid class
                    # index in [0, C - 1], where C is the number of output channels.
                    # Out-of-range (e.g. uninitialised) values are a typical cause of
                    # the CUDA "device-side assert triggered" error.
                    loss = criterion(

                    # Per-pixel predicted class (argmax over the channel dim), flattened.
                    y_pred = torch.argmax(outputs, dim=1).long().cpu().numpy().ravel()
                    y_true = masks.cpu().numpy().ravel()

                    for name, metric in metrics.items():
                        if name == "jaccard_score":
                                metric(y_true, y_pred, average="micro"))
                                metric(y_true.astype('uint8'), y_pred))

                    # backward + optimize only if in training phase
                    if phase == 'Train':
            batchsummary['epoch'] = epoch
            # NOTE(review): this is the last batch's loss, not an average over the epoch.
            epoch_loss = loss
            batchsummary[f'{phase}_loss'] = epoch_loss.item()
            print('{} Loss: {:.4f}'.format(phase, loss))
        # fieldnames[3:] are the per-phase metric columns; average their batch values.
        for field in fieldnames[3:]:
            batchsummary[field] = np.mean(batchsummary[field])
        with open(os.path.join(bpath, 'log.csv'), 'a', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            # deep copy the model
            # ``phase`` and ``loss`` still hold the values from the last (Test) phase here.
            if phase == 'Test' and loss < best_loss:
                best_loss = loss
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # NOTE(review): '{:4f}' sets a minimum width of 4, not 4 decimals; '{:.4f}' was likely intended.
    print('Lowest Loss: {:4f}'.format(best_loss))

    # load best model weights
    # NOTE(review): model.load_state_dict(best_model_wts) is missing, so the
    # final-epoch weights are returned rather than the best ones.
    return model

And finally, to instantiate the model and train it on my dataset, I run the following:

import segmentation_models_pytorch as smp

# NOTE(review): this snippet is truncated; closing brackets and the arguments
# of several calls and dict literals are missing.
model = smp.DeepLabV3Plus(
  in_channels=3, # for RGB
  classes=4      # 3 object classes (palm, erosion barriers, path) + background -> 4 output channels

# Specify the loss function
# With classes=4, every target pixel must hold a class index in 0..3.
criterion = torch.nn.CrossEntropyLoss()
# Specify the optimizer with a lower learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Specify the evaluation metrics
# Presumably sklearn.metrics.jaccard_score (called as metric(y_true, y_pred, ...)) — confirm.
metrics = {
    'jaccard_score': jaccard_score

# Creating the dataloader
# One dataset per split, keyed by the same 'Train' / 'Test' names train_model() iterates.
image_datasets = {
        x: SegmentationDataset(

        for x in ['Train', 'Test']

my_dataloaders = {
        x: DataLoader(
        for x in ['Train', 'Test']

# Training the model
_ = train_model(

This code works only as long as the sum of the mask is 0, which I print in the train_model() function:

# Debug prints from train_model(). For non-negative class indices, a mask sum
# of 0 means every pixel is background (class 0).
print("SUM of output:", torch.sum(outputs))
print("SUM of mask:", torch.sum(masks))   

Usually, however, this is not the case (although I do have some 'background-only' training pairs with no objects of interest). I am confused as to why the model only works when my mask consists solely of background, and I cannot interpret the traceback. Any help or leads are much appreciated, thank you!

Error traceback

RuntimeError                              Traceback (most recent call last)

<ipython-input-7-a35552333777> in <module>()
     78     bpath=log_dir,
     79     metrics=metrics,
---> 80     num_epochs=epochs
     81 )

4 frames

<ipython-input-6-ee5bb6f99dd3> in train_model(model, criterion, dataloaders, optimizer, metrics, bpath, num_epochs)
     78                     loss = criterion(
     79                         input=outputs,
---> 80                         target=masks
     81                       )

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/ in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/ in forward(self, input, target)
   1046         assert self.weight is None or isinstance(self.weight, Tensor)
   1047         return F.cross_entropy(input, target, weight=self.weight,
-> 1048                                ignore_index=self.ignore_index, reduction=self.reduction)

/usr/local/lib/python3.7/dist-packages/torch/nn/ in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2691     if size_average is not None or reduce is not None:
   2692         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2693     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)

/usr/local/lib/python3.7/dist-packages/torch/nn/ in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   2388         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2389     elif dim == 4:
-> 2390         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2391     else:
   2392         # dim == 3 or dim > 4

RuntimeError: cuda runtime error (710) : device-side assert triggered at /pytorch/aten/src/THCUNN/generic/

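Could you print the shapes of your model's output and of the targets?
Based on the error description, I guess that your output might have only a single channel. In that case nn.CrossEntropyLoss would only accept the class index 0 as a valid target. That would explain why all-zero masks work, but it is not very useful in a multi-class use case :wink:. More generally, this assert fires whenever a target value lies outside [0, number_of_output_channels - 1].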