Image segmentation related question

I am training an encoder-decoder network for semantic segmentation from a single RGB image.

The network is being trained to detect 40 classes for semantic segmentation.
The ground truth I have has dimensions [batch_size, 256, 256], where each pixel in the 256x256 matrix corresponds to a class ID between 0 and 39.
The network outputs a [batch_size, 40, 256, 256] tensor.

I am using CrossEntropyLoss to train it.

import torch

# ground_truth must be a LongTensor of class IDs (not one-hot encoded)
semantic_loss = torch.nn.CrossEntropyLoss()
loss = semantic_loss(output, ground_truth)
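For context, here is a minimal, self-contained sketch of the setup with random tensors of the stated shapes (the tensor names and sizes are illustrative, matching the dimensions described above):

```python
import torch

batch_size, num_classes, H, W = 4, 40, 256, 256

# Raw logits from the network: [batch_size, 40, 256, 256]
output = torch.randn(batch_size, num_classes, H, W)

# Ground truth class IDs: [batch_size, 256, 256], dtype long, values 0-39
ground_truth = torch.randint(0, num_classes, (batch_size, H, W))

semantic_loss = torch.nn.CrossEntropyLoss()
loss = semantic_loss(output, ground_truth)  # scalar loss, averaged over all pixels
```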

The code is working fine, and the loss is going down as the network trains.
I wanted to know two things -

  1. Is the loss calculation approach correct?
  2. How can I visualise the output results (which have shape [bs, 40, 256, 256]) ?
  1. Yes, everything looks good as long as your model outputs logits (i.e. no non-linearity as the last layer).

  2. You could get the predictions using pred = torch.argmax(output, dim=1) and visualize them e.g. with matplotlib.
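A minimal sketch of that visualization step, assuming random logits in place of the real network output (the palette and tensor names are illustrative):

```python
import numpy as np
import torch

# Stand-in for the network output: [batch_size, 40, 256, 256] logits
output = torch.randn(2, 40, 256, 256)

# Per-pixel class prediction: [batch_size, 256, 256], values in 0-39
pred = torch.argmax(output, dim=1)

# Map each class ID to an RGB color via a random per-class palette
palette = np.random.randint(0, 255, size=(40, 3), dtype=np.uint8)
color_img = palette[pred[0].numpy()]  # [256, 256, 3] uint8 image

# Then e.g.:
# import matplotlib.pyplot as plt
# plt.imshow(color_img)
# plt.show()
```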
