Weird behaviour when mapping masks

Hello there. I’m new to PyTorch and I’m trying to use the ENet network (from GitHub) to classify underwater images from the SUIM dataset and evaluate its performance. Following this post I tried to implement the class-index mapping of the masks, but when I start training the network, some specific masks cause a CUDA crash. I’ve managed to isolate some masks that cause this strange behaviour, as described below.

Here are the main snippets, with code adapted from the previously mentioned post:

import torch.nn.functional as F
from collections import OrderedDict
from train import Train
from test import Test
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as tvtransforms
import transforms as ext_transforms
import torch.optim.lr_scheduler as lr_scheduler
import os
import numpy as np
import matplotlib.pyplot as plt
import glob
from PIL import Image
from enet import ENet
from iou import IoU
import utils
from torchvision import models
class CustomDataset(Dataset):
    def __init__(self, image_paths, target_paths, transform=None, train=True):
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.transform = transform
        self.img = None
        self.msk = None
        self.mapping = {
            (0, 0, 0) : 0, # 'Background'
            (0, 0, 255) : 1, # 'Human Divers'
            (0, 255, 0) : 2, # 'Aquatic Plants and Sea-Grass'
            (0, 255, 255) : 3, # 'Wrecks and Ruins'
            (255, 0, 0): 4, # 'Robots'
            (255, 0, 255) : 5, # 'Reefs and Invertebrates'
            (255, 255, 0) : 6, # 'Fish and Vertebrates'
            (255, 255, 255) : 7 # 'Sea-Floor and Rocks'
        }
        
    # Map an RGB mask (H, W, 3) to a (H, W) tensor of class indices using self.mapping
    def mask_to_class_rgb(self, mask):
        mask = np.array(mask)
        mask = mask[...,:3]
        mask = torch.from_numpy(mask)
        mask = torch.squeeze(mask)

        print('Unique values in rgb: ', torch.unique(mask))
        class_mask = mask
        class_mask = class_mask.permute(2, 0, 1).contiguous()
        h, w = class_mask.shape[1], class_mask.shape[2]
        mask_out = torch.empty(h, w, dtype=torch.long)
        
        for i in self.mapping:
            idx = (class_mask == torch.tensor(i, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
            validx = (idx.sum(0) == 3)
            mask_out[validx] = torch.tensor(self.mapping[i], dtype=torch.long)
            
        print('Unique values mapped: ', torch.unique(mask_out))
        
        return mask_out
    

    def __getitem__(self, index):
        img = Image.open(self.image_paths[index]).convert("RGB")
        msk = Image.open(self.target_paths[index])
                
        img_new = tvtransforms.Resize((std_size, std_size))(img)
        mask_new = tvtransforms.Resize((std_size, std_size),  Image.NEAREST)(msk)

        if self.transform is not None:    
            
            img_new = self.transform(img_new)
            
            img_new = img_new.float()
        else:
            img_new = img_new.float()
            
                
        norm = tvtransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        img_new = norm(img_new)
        
        mask_new = self.mask_to_class_rgb(mask_new)
        
        return img_new, mask_new
        
    
    def __len__(self):
        return len(self.image_paths)
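
For reference, this is roughly how I wire the dataset into the loaders; the glob patterns, batch size and std_size below are placeholders for my actual setup:

# Placeholder wiring of the dataset into a DataLoader (paths and std_size
# stand in for my actual configuration).
std_size = 256

train_images = sorted(glob.glob('SUIM/train_val/images/*'))
train_masks = sorted(glob.glob('SUIM/train_val/masks/*'))

train_set = CustomDataset(train_images, train_masks,
                          transform=tvtransforms.ToTensor(), train=True)
train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)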

This is the training snippet:

best_miou = 0
start_epoch = start_epoch if 'start_epoch' in globals() else 0  # keep an externally defined start_epoch, otherwise start at 0

for epoch in range(start_epoch, num_epochs):
        print(">>>> [Epoch: {0:d}] Training".format(epoch))
        
        my_lr_scheduler.step()

        epoch_loss, (iou, miou) = train.run_epoch(True) # Decides whether to print loss at each step

        print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
              format(epoch, epoch_loss, miou))

        if (epoch + 1) % 10 == 0 or epoch + 1 == num_epochs:
            print(">>>> [Epoch: {0:d}] Validation".format(epoch))

            loss, (iou, miou) = val.run_epoch(True)

            print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
                  format(epoch, loss, miou))

            # Print per class IoU on last epoch or if best iou
            if epoch + 1 == num_epochs or miou > best_miou:
                for key, class_iou in zip(class_encoding.keys(), iou):
                    print("{0}: {1:.4f}".format(key, class_iou))

            # Save the model if it's the best thus far
            if miou > best_miou:
                print("\nBest model thus far. Saving...\n")
                best_miou = miou
                utils.save_checkpoint_mod(net, optimizer, epoch + 1, best_miou, 'teste', 'save')

Here is the output:

>>>> [Epoch: 0] Training
Unique values in rgb:  tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  55,  56,
         59,  60,  61,  62,  63,  64,  65,  66,  67,  69,  70,  72,  73,  74,
         75,  76,  80,  81,  83,  84,  86,  87,  91,  93,  94,  95,  97, 119,
        161, 162, 166, 167, 170, 171, 172, 175, 176, 177, 178, 179, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214,
        215, 216, 217, 218, 219, 221, 222, 223, 224, 225, 226, 227, 228, 229,
        230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243,
        244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255],
       dtype=torch.uint8)
Unique values mapped:  tensor([                0, 42915137138293276, 43204295812155417,
         ..., 66953201512118048, 67517152192928824,
        67519561669712197])
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-59-80f00525c548> in <module>()
     10         my_lr_scheduler.step()
     11 
---> 12         epoch_loss, (iou, miou) = train.run_epoch(True) # Decides whether to print loss at each step
     13 
     14         print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".

4 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   2218         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2219     elif dim == 4:
-> 2220         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2221     else:
   2222         # dim == 3 or dim > 4

IndexError: Target 52516180046732819 is out of bounds.

I’m attaching the isolated mask that causes this behaviour. Other, similar masks produce the same outcome. Also, I edited this mask to be entirely black, and training with the all-black version did not crash.

Any help would be greatly appreciated!

Since mask_out is created as an empty tensor here: torch.empty(h, w, dtype=torch.long), it will contain uninitialized values at the beginning.
The loop over self.mapping should assign a valid class index to every pixel of the mask, which doesn’t seem to be happening, so some values remain uninitialized.
I would recommend rechecking the mapping and making sure that only the RGB values in self.mapping actually appear in the mask.
You should be able to get all unique color values via:

mask = torch.zeros(3, 24, 24)
mask[:, 0, 0] = torch.tensor([1., 2., 3.])
mask[:, 1, 1] = torch.tensor([4., 5., 6.])
mask[:, 2, 2] = torch.tensor([7., 8., 9.])

print(mask.view(3, -1).unique(dim=1))
> tensor([[0., 1., 4., 7.],
          [0., 2., 5., 8.],
          [0., 3., 6., 9.]])

Dear ptrblck, thanks for the quick response.

I’m still confused about how to insert your suggested snippet, with the appropriate changes, into the code itself. Would you mind elaborating a little further? I’ve had some experience with Python before, but I’m still learning and improving with PyTorch.

By the way, just for testing, changing torch.empty to torch.zeros didn’t crash CUDA, confirming what you proposed. But I still can’t understand why, for an image with only black and red pixels, the unique RGB values read from the mask still range from 0 to 255 (with some intermediate values missing as well, e.g., 88, 89, …). I’m still trying to debug why the code is reading those values out of the mask.

Any further clarifications would be greatly appreciated and once again thanks for the support.

You can use the print(mask.view(3, -1).unique(dim=1)) operation to print all unique RGB values in your masks and make sure that the dict used to map RGB values to class indices indeed contains all unique colors.

Currently I guess your dict is missing some colors and is thus not setting any class index for the corresponding pixels in the mask, which results in undefined values (from the empty initialization).
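
Something like this (untested; mask_paths is a placeholder for your list of mask files) could be used to check every mask against the dict before training:

# Untested sketch: list all colors in each mask that are missing from the mapping
import numpy as np
import torch
from PIL import Image

mapping = {
    (0, 0, 0): 0, (0, 0, 255): 1, (0, 255, 0): 2, (0, 255, 255): 3,
    (255, 0, 0): 4, (255, 0, 255): 5, (255, 255, 0): 6, (255, 255, 255): 7
}

for path in mask_paths:
    mask = torch.from_numpy(np.array(Image.open(path).convert('RGB')))
    colors = mask.reshape(-1, 3).unique(dim=0)  # unique RGB triplets, shape [N, 3]
    unknown = [tuple(c.tolist()) for c in colors if tuple(c.tolist()) not in mapping]
    if unknown:
        print(path, 'contains colors not in the mapping:', unknown)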

Dear ptrblck, thanks again for the response.

So sorry for being absent these past weeks, haven’t had time to work on it properly.

I tried using your code adapted to my needs but no success so far. Even though the image only appears to contain black and red, the code still returns values across the whole [0, 255] range.

To probe for problems, I applied this snippet:

testmask = Image.open(mask_path[0])
testmask = tvtransforms.Resize((std_size, std_size))(testmask)
pix_val = list(testmask.getdata())
print(pix_val)

pix_val already contains values other than (0, 0, 0) and (255, 0, 0), which are the only values I expect from a mask containing nothing but red and black pixels (recall that the dataset was built with the encoding in the dict we discussed previously).

So is this an indication that the mask itself could be the problem? Or am I missing something?

Dataset I am currently using: SUIM Dataset.

I kindly appreciate the support so far!

In your code snippet you’ve dropped the NEAREST interpolation method for Resize again.
By default BILINEAR interpolation is used, which will create colors “between” black and red.
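
i.e. for your test snippet it should be something like:

# Nearest-neighbor resize so that no new colors are interpolated into the mask
testmask = Image.open(mask_path[0])
testmask = tvtransforms.Resize((std_size, std_size), Image.NEAREST)(testmask)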


My thesis advisor suggested that I open these images in an image editor and zoom in to see if there was anything strange about them, and that nailed it. Apparently some images in the dataset went through some kind of processing/resizing that distorted some of the pixels, as shown in the images below (these were extracted from the raw dataset, before any processing on my side).

Sharp edges:

Smoothed edges:

So I guess I’ll have to modify the code snippet to apply some kind of threshold when assigning class indices to pixels, as sketched below.
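
Something along these lines is what I have in mind (untested sketch; it maps every pixel to the nearest color in the dict, so slightly distorted pixels still get a valid class index):

# Untested sketch of a drop-in replacement for mask_to_class_rgb:
# assign each pixel the class of its nearest reference color.
def mask_to_class_nearest(self, mask):
    mask = torch.from_numpy(np.array(mask))[..., :3].float()              # [H, W, 3]
    h, w = mask.shape[0], mask.shape[1]
    colors = torch.tensor(list(self.mapping.keys()), dtype=torch.float)   # [C, 3]
    classes = torch.tensor(list(self.mapping.values()), dtype=torch.long) # [C]
    # squared distance of every pixel to every reference color
    dist = ((mask.reshape(-1, 1, 3) - colors.reshape(1, -1, 3)) ** 2).sum(-1)  # [H*W, C]
    nearest = dist.argmin(dim=1)                                           # [H*W]
    return classes[nearest].reshape(h, w)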

Thanks for the support; the last reply shed some light on where to investigate the problem and pointed me toward the necessary adjustments.


Funny, but I ran into the exact same problem.

Using ImageMagick to resize my masks resulted in mask images that contained interpolated RGB colour values that didn’t match any entry of the dict. ImageMagick seems to use bilinear interpolation as the default method here.

So as @ptrblck and @robsoncs already stated, it is essential that operations on the masks do not create interpolated values. For resizing, use an interpolation method like nearest neighbour.
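
In PIL the equivalent would be something like this (file names are just placeholders):

# Resize a mask with nearest-neighbor sampling so no interpolated colors appear
from PIL import Image

mask = Image.open('mask.bmp')
mask = mask.resize((256, 256), resample=Image.NEAREST)
mask.save('mask_resized.bmp')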