Training Semantic Segmentation


I am trying to reproduce PSPNet using PyTorch, and this is my first time creating a semantic segmentation model. I understand that for an image classification model, we have an RGB input = [h, w, 3] and a label or ground truth = [h, w, n_classes]. We then use the model to create an output and compute the loss, e.g. output = model(input); loss = criterion(output, label).

However, in semantic segmentation (I am using the ADE20K dataset), we have input = [h, w, 3] and label = [h, w, 3], and we then encode the label to [h, w, 1]. ADE20K has a total of 19 classes, so our model will output [h, w, 19]. I am confused about how we can then compute the loss, as the dimensions of the label and the output are clearly different.

Any help or guidance on this will be greatly appreciated!

Since PSPNet uses convolutions, you should pass your input as [batch_size, channels, height, width] (channels-first).
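A quick sketch of that layout conversion (the shapes are just illustrative, not from PSPNet):

```python
import torch

# Images usually load as [h, w, 3] (HWC); conv layers expect [batch, 3, h, w] (NCHW)
img_hwc = torch.randn(512, 512, 3)                # HWC image
img_nchw = img_hwc.permute(2, 0, 1).unsqueeze(0)  # -> [1, 3, 512, 512]
print(img_nchw.shape)  # torch.Size([1, 3, 512, 512])
```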

It looks like your targets are RGB images, where each color encodes a specific class.
If that’s the case, you should map the colors to class indices.
I’m not familiar with the ADE20K dataset, but you might find a mapping between the colors and class indices somewhere online.
If not, you can just create your own mapping, e.g. using a dict and transform the targets.
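Regarding the loss itself: once the target holds class indices, the shape mismatch you describe is expected. A minimal sketch of the shapes nn.CrossEntropyLoss accepts (random data, 19 classes assumed as in your post):

```python
import torch
import torch.nn as nn

n, c, h, w = 2, 19, 32, 32
logits = torch.randn(n, c, h, w)          # model output: [batch, n_classes, h, w]
target = torch.randint(0, c, (n, h, w))   # mapped label: [batch, h, w], dtype long

# The target deliberately has no class/channel dimension; it stores indices
loss = nn.CrossEntropyLoss()(logits, target)
print(loss)  # scalar tensor
```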

I mapped the target RGB into a single-channel uint16 image where the pixel values indicate the classes. The formula is ObjectClassMasks = (uint16(R)/10)*256 + uint16(G), where R is the red channel and G is the green channel. I don’t think there is a way to convert that into an image with shape [n_classes, height, width].
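That decoding step can be sketched in NumPy as follows (the dummy pixel values and variable names are mine):

```python
import numpy as np

# Dummy 4x4 RGB target: R and G channels jointly encode the class
rgb = np.zeros((4, 4, 3), dtype=np.uint8)
rgb[..., 0] = 30  # R channel
rgb[..., 1] = 5   # G channel

r = rgb[..., 0].astype(np.uint16)
g = rgb[..., 1].astype(np.uint16)
# ObjectClassMasks = (uint16(R)/10)*256 + uint16(G)
object_class_masks = (r // 10) * 256 + g
print(object_class_masks[0, 0])  # (30 // 10) * 256 + 5 = 773
```

Note that a single-channel index map like this is already the right target format for nn.CrossEntropyLoss, so there is no need for a [n_classes, height, width] representation.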

Also, can you provide more information on how to create my own mapping?

Thank you very much

Is that the formula used for the color-to-class mapping?

Here is an example of how to create your own mapping:

import torch

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Create dummy target image
nb_classes = 19 - 1 # 18 classes + background
idx = np.linspace(0., 1., nb_classes)
cmap = plt.get_cmap('viridis')
rgb = cmap(idx, bytes=True)[:, :3]  # Remove alpha value

h, w = 190, 100
rgb = rgb.repeat(1000, 0)  # Repeat each class color 1000 times along axis 0
target = np.zeros((h*w, 3), dtype=np.uint8)
target[:rgb.shape[0]] = rgb
target = target.reshape(h, w, 3)

plt.imshow(target) # Each class in 10 rows

# Create mapping
# Get color codes for dataset (maybe you would have to use more than a single
# image, if it doesn't contain all classes)
target = torch.from_numpy(target)
colors = torch.unique(target.view(-1, target.size(2)), dim=0).numpy()
target = target.permute(2, 0, 1).contiguous()

mapping = {tuple(c): t for c, t in zip(colors.tolist(), range(len(colors)))}

mask = torch.empty(h, w, dtype=torch.long)
for k in mapping:
    # Get all indices for current class
    idx = (target==torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
    validx = (idx.sum(0) == 3)  # Check that all channels match
    mask[validx] = torch.tensor(mapping[k], dtype=torch.long)
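To close the loop, here is a standalone end-to-end sketch of how such a mask would be used in training. The Conv2d stand-in for PSPNet and the random mask are my placeholders, not part of the snippet above:

```python
import torch
import torch.nn as nn

h, w, n_classes = 190, 100, 19
model = nn.Conv2d(3, n_classes, kernel_size=1)  # placeholder for PSPNet
x = torch.randn(1, 3, h, w)                     # channels-first input
mask = torch.randint(0, n_classes, (h, w))      # stand-in for the mapped target

output = model(x)                               # [1, n_classes, h, w]
loss = nn.CrossEntropyLoss()(output, mask.unsqueeze(0))  # add batch dim to target
loss.backward()
```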