Training Semantic Segmentation

Hi,

I am trying to reproduce PSPNet using PyTorch, and this is my first time creating a semantic segmentation model. I understand that for an image classification model, we have an RGB input = [h, w, 3] and a label or ground truth = [n_classes] (e.g. one-hot encoded). We then use the model to create the output and compute the loss. For example, output = model(input); loss = criterion(output, label).

However, in semantic segmentation (I am using the ADE20K dataset), we have input = [h, w, 3] and label = [h, w, 3], and we then encode the label to [h, w, 1]. ADE20K has a total of 19 classes, so our model will output [h, w, 19]. I am confused about how we can then compute the loss, as the dimensions of the label and the output are clearly different.

Any help or guidance on this will be greatly appreciated!


Since PSPNet uses convolutions, you should pass your input as [batch_size, channels, height, width] (channels-first).
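
For example, a minimal sketch of this layout change (the image size here is just a made-up placeholder):

import torch

# Single image loaded channels-last, e.g. from PIL/NumPy
img = torch.randn(224, 224, 3)           # [height, width, channels]
img = img.permute(2, 0, 1).unsqueeze(0)  # -> [1, 3, 224, 224], i.e. NCHW
print(img.shape)                         # torch.Size([1, 3, 224, 224])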

It looks like your targets are RGB images, where each color encodes a specific class.
If that’s the case, you should map the colors to class indices.
I’m not familiar with the ADE20K dataset, but you might find a mapping between the colors and class indices somewhere online.
If not, you can just create your own mapping, e.g. using a dict, and transform the targets accordingly.

I mapped the target RGB into single-channel uint16 images, where the pixel values indicate the classes. The formula is ObjectClassMasks = (uint16(R)/10)*256 + uint16(G), where R is the red channel and G is the green channel. I don’t think there is a way to convert that into an image of shape [n_classes, height, width].
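
For reference, here is a rough sketch of that decoding in NumPy (assuming the label was loaded as a uint8 RGB array of shape [h, w, 3]; note that the original formula comes from MATLAB, where integer division rounds, while // floors in Python):

import numpy as np

def decode_ade20k(label_rgb):
    # label_rgb: uint8 array of shape [h, w, 3]
    r = label_rgb[..., 0].astype(np.uint16)
    g = label_rgb[..., 1].astype(np.uint16)
    # ObjectClassMasks = (uint16(R)/10)*256 + uint16(G)
    return (r // 10) * 256 + g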

Also, can you provide more information on how to create my own mapping?

Thank you very much

Is the formula used for the color-to-class mapping?

Here is an example of how to create your own mapping:

import torch

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Create dummy target image
nb_classes = 19 - 1  # 18 color-coded classes + background (class 0)
idx = np.linspace(0., 1., nb_classes)
cmap = matplotlib.cm.get_cmap('viridis')
rgb = cmap(idx, bytes=True)[:, :3]  # Remove alpha value

h, w = 190, 100
rgb = rgb.repeat(1000, 0)  # 1000 pixels per class
target = np.zeros((h*w, 3), dtype=np.uint8)
target[:rgb.shape[0]] = rgb
target = target.reshape(h, w, 3)

plt.imshow(target)  # Each class occupies 10 rows

# Create mapping
# Get color codes for dataset (maybe you would have to use more than a single
# image, if it doesn't contain all classes)
target = torch.from_numpy(target)
colors = torch.unique(target.view(-1, target.size(2)), dim=0).numpy()
target = target.permute(2, 0, 1).contiguous()

mapping = {tuple(c): t for c, t in zip(colors.tolist(), range(len(colors)))}

mask = torch.empty(h, w, dtype=torch.long)
for k in mapping:
    # Get all indices for current class
    idx = (target==torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
    validx = (idx.sum(0) == 3)  # Check that all channels match
    mask[validx] = torch.tensor(mapping[k], dtype=torch.long)

Hi,
I am trying really hard to convert the tensor I obtained after training the model into the mask image, as mentioned in this question.
But before that, I am finding the code below hard to understand:

for k in mapping:
    # Get all indices for current class
    idx = (target==torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
    validx = (idx.sum(0) == 3)  # Check that all channels match
    mask[validx] = torch.tensor(mapping[k], dtype=torch.long)

I really don’t understand what’s happening here. Could you please help me out?


This dummy code maps some color codes to class indices.
E.g. the color blue, represented as [0, 0, 255] in RGB, could be mapped to class index 0.
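
To make the broadcasting in that loop explicit, here is a tiny worked example for a single color (the 2x2 image is made up):

import torch

# target has shape [3, h, w]; here a dummy 2x2 image with one blue pixel
target = torch.zeros(3, 2, 2, dtype=torch.uint8)
target[2, 0, 0] = 255  # pixel (0, 0) is blue: [0, 0, 255]

k = (0, 0, 255)  # color we are looking for
# [3] -> [3, 1, 1], so the comparison broadcasts over the spatial dims
idx = target == torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2)
validx = idx.sum(0) == 3  # True only where all three channels match
print(validx)
# tensor([[ True, False],
#         [False, False]])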


I ran this code, but the size of mask is [190, 100]. Should I get a size of [18, 190, 100]? What should I do?

Thanks a lot for all your answers; they are always a great help.
I’m trying to do the same here. I have RGB images as my labels and need to create the color-class mapping, but I was wondering: how can I know the exact number of classes?
I’m working with satellite images, and the labels are masks for vegetation index values.

This line of code should return all unique colors:

colors = torch.unique(target.view(-1, target.size(2)), dim=0).numpy()

and the length of this tensor gives you the number of classes for this target tensor.
Note that you would have to use multiple targets if this particular target doesn’t contain all classes.


In case we have multiple targets, does this mean the custom dataset class should return different targets, or do we only need to use different targets for the mapping?

Sorry for not being clear in the description. I meant that, to create the mapping, you would have to make sure to use a target tensor which contains all possible classes.
E.g. if you are only checking the unique colors of a single sample, this target tensor could contain only a subset of all available classes in the dataset.
In that case, you can stack the target tensors to create a “dataset target” tensor and check the unique colors on it.
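
Something like this sketch (assuming each target is a [h, w, 3] uint8 tensor and that dataset[i] returns an (image, target) pair; both are assumptions about your setup):

import torch

# Stack all targets and collect the unique colors over the whole dataset
all_targets = torch.stack([dataset[i][1] for i in range(len(dataset))])
colors = torch.unique(all_targets.view(-1, 3), dim=0)
print(colors.shape)  # [nb_colors, 3]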


@ptrblck
Hello, I have mapped my RGB mask of 12 classes to indices. My target size is (batch, classes, h, w). How can I use it with cross entropy loss? Thank you.

For the posted model output shape your target should have the shape [batch_size, h, w] and contain class indices in [0, nb_classes-1] to be able to use nn.CrossEntropyLoss for a multi-class segmentation.
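
Here is a small shape example (the sizes are just placeholders):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

batch_size, nb_classes, h, w = 16, 12, 256, 256
output = torch.randn(batch_size, nb_classes, h, w)         # raw logits
target = torch.randint(0, nb_classes, (batch_size, h, w))  # class indices

loss = criterion(output, target)
print(loss)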


How can I convert my data? My target is (12, 256, 256) in the custom Dataset class and (16, 12, 256, 256) when calling the PyTorch DataLoader. What changes should I make in my custom Dataset class? I have already one-hot encoded the 12 classes.

One-hot encoded targets are wrong if you are using nn.CrossEntropyLoss, as the class indices are expected. In case you have already created the one-hot encoded targets, use target = torch.argmax(target, dim=1) to create the expected target.
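
A quick shape check of this conversion (using the shapes from your post; the dummy data is made up):

import torch
import torch.nn.functional as F

# Build a dummy one-hot target with shape [16, 12, 256, 256]
indices = torch.randint(0, 12, (16, 256, 256))
one_hot = F.one_hot(indices, num_classes=12).permute(0, 3, 1, 2)

# Recover the class indices expected by nn.CrossEntropyLoss
target = torch.argmax(one_hot, dim=1)
print(target.shape)                  # torch.Size([16, 256, 256])
print(torch.equal(target, indices))  # True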


Thank you. Now my targets are (16, 256, 256) and the predictions are (16, 12, 256, 256). Can I apply the cross entropy loss function to these two tensors? Is there no need for the same dimensions?
Edit: Is there no need to one-hot encode the labels? Can I directly feed the single-channel mask (converted from color to grayscale) to the cross entropy loss function?

Yes, these shapes should work.

No, nn.CrossEntropyLoss expects the target to contain class indices as described in the docs, not one-hot encoded tensors.

No, you should feed the target containing class indices to this loss function. A grayscale mask would have a channel dimension of size 1, which would not be accepted by nn.CrossEntropyLoss.
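
If your mask comes as [batch_size, 1, h, w] and its values already are class indices (an assumption; gray levels produced by a color-to-grayscale conversion are generally not class indices), you could squeeze the channel dimension:

import torch

# Dummy mask whose values are assumed to already be class indices
mask = torch.randint(0, 12, (16, 1, 256, 256))
target = mask.squeeze(1).long()  # -> [16, 256, 256], dtype torch.int64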



Thank you very much for your support.

Hello again,

My validation accuracy, IoU score, and Dice score are much lower than my training metrics. Why is this? I even tried shuffling the data.
My results:

training loss: 0.11447336673736572, iou score: 0.7451344195415179, dice score: 0.8161399212012007
validation loss: 3.3190221786499023, iou score: 0.09056307559778348, dice score: 0.11742234666033746

Edit: my code is:

for batch_idx, (data, labels) in enumerate(train_loader):
    # Convert the one-hot target [32, 32, 256, 256] to class indices
    # [32, 256, 256] (batch size 32), since nn.CrossEntropyLoss expects
    # class indices rather than one-hot encoded targets
    labels = torch.argmax(labels, dim=1)
    data = data.to(device=device)
    labels = labels.to(device=device)

    # Forward pass; predictions has shape [32, 32, 256, 256] (batch size 32, 32 classes)
    predictions = model(data)
    loss = loss_fn(predictions, labels)

Is there no need to use softmax before applying cross entropy loss? Please check whether my training is correct.
Thank you.


Please help me understand why my validation scores are so low. I am using the same function to evaluate my train and test datasets, but the validation scores are much lower. Is this a training problem?

My testing code is as follows:

# Evaluate on the training set; confusion_matrix comes from sklearn.metrics,
# get_mean_iou is a custom helper
train_running_loss = 0
iou_running_score = 0
dice_running_score = 0

model.eval()
with torch.no_grad():
  for batch_idx, (data, labels) in enumerate(train_loader):
    labels = torch.argmax(labels, dim=1)  # one-hot -> class indices, torch.int64
    inputs1 = data.to(device=device)
    labels1 = labels.to(device=device)


    # Predict segmentations for the training inputs
    outputs1 = model(inputs1)  # [32, 32, 256, 256]
    #print(outputs1.shape, labels1.shape)
    # Compute CE loss and aggregate it
    loss1 = loss_fn(outputs1, labels1)
    train_running_loss += loss1

    # Reshaping prediction segmentations and actual segmentations for iou and dice score
    preds = torch.argmax(outputs1, dim=1).detach().cpu().numpy()
    gt = labels1.detach().cpu().numpy()
    print(preds.shape, gt.shape)  # (32, 256, 256) each, i.e. [batch, h, w]

    # Compute confusion matrix
    conf_mat = confusion_matrix(y_pred=preds.flatten(), y_true=gt.flatten(), labels=list(range(21)))

    # Computing iou and dice scores and aggregating them
    iou_score = get_mean_iou(conf_mat=conf_mat)
    iou_running_score += iou_score
    dice_score = get_mean_iou(conf_mat=conf_mat, multiplier=2.0)
    dice_running_score += dice_score

  # Averaging loss and scores
  avg_train_loss = float(train_running_loss)/(batch_idx+1)
  avg_iou_score = float(iou_running_score)/(batch_idx+1)
  avg_dice_score = float(dice_running_score)/(batch_idx+1)

  # Print batch-averaged metrics
  print('training loss: {}, iou score: {}, dice score: {}'.format(avg_train_loss, avg_iou_score, avg_dice_score))

best_loss = 1000000000
val_running_loss = 0
iou_running_score = 0
dice_running_score = 0

model.eval()
with torch.no_grad():
  for batch_idx, (data, labels) in enumerate(test_loader):
    labels = torch.argmax(labels, dim=1)  # one-hot -> class indices, torch.int64
    inputs1 = data.to(device=device)
    labels1 = labels.to(device=device)


    # Predicting segmentation for val inputs
    outputs1 = model(inputs1)
    #print(outputs1.shape, labels1.shape)
    # Compute CE loss and aggregate it
    loss1 = loss_fn(outputs1, labels1)
    val_running_loss += loss1

    # Reshaping prediction segmentations and actual segmentations for iou and dice score
    preds = torch.argmax(outputs1, dim=1).detach().cpu().numpy()
    gt = labels1.detach().cpu().numpy()
    #print(preds.shape, gt.shape)

    # Compute confusion matrix
    conf_mat = confusion_matrix(y_pred=preds.flatten(), y_true=gt.flatten(), labels=list(range(21)))

    # Computing iou and dice scores and aggregating them
    iou_score = get_mean_iou(conf_mat=conf_mat)
    iou_running_score += iou_score
    dice_score = get_mean_iou(conf_mat=conf_mat, multiplier=2.0)
    dice_running_score += dice_score

  # Averaging loss and scores
  avg_val_loss = float(val_running_loss)/(batch_idx+1)
  avg_iou_score = float(iou_running_score)/(batch_idx+1)
  avg_dice_score = float(dice_running_score)/(batch_idx+1)

  # Print batch-averaged metrics
  print('validation loss: {}, iou score: {}, dice score: {}'.format(avg_val_loss, avg_iou_score, avg_dice_score))