Multiclass Segmentation

I assume you have already found suitable code snippets for a binary segmentation use case?
If so, you could use them as a baseline and make a few changes for a multiclass segmentation use case:

  • use nn.CrossEntropyLoss as your criterion
  • your model should output a tensor of raw logits with the shape [batch_size, nb_classes, height, width] (no softmax needed, since nn.CrossEntropyLoss applies log-softmax internally)
  • the target should be a LongTensor with the shape [batch_size, height, width] and contain the class indices for each pixel location in the range [0, nb_classes-1]
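As a rough sketch of how these three points fit together, here is a small example with random tensors. The sizes are arbitrary, and `output` is just a stand-in for whatever your model returns, so treat it as an illustration of the expected shapes and dtypes rather than a full training loop:

```python
import torch
import torch.nn as nn

# Illustrative sizes only (assumptions, not taken from any real dataset)
batch_size, nb_classes, height, width = 4, 3, 32, 32

# Stand-in for the model output: raw logits, no softmax applied
output = torch.randn(batch_size, nb_classes, height, width, requires_grad=True)

# Target: one class index per pixel in [0, nb_classes-1], dtype torch.long
target = torch.randint(0, nb_classes, (batch_size, height, width))

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)  # scalar, averaged over all pixels by default
loss.backward()

# Predicted class per pixel, shape [batch_size, height, width]
preds = output.argmax(dim=1)
```

Note that the target has no channel dimension. If your masks are loaded as [batch_size, 1, height, width], you would need to squeeze dim 1 before passing them to the criterion.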

Depending on the format of your segmentation mask images, you might need to create a mapping, e.g. between color codes and the corresponding class indices.
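One possible way to build such a mapping is sketched below. The palette in `color_to_class` and the helper name `mask_to_class_indices` are made up for illustration, so you would have to replace the colors with the ones actually used in your masks:

```python
import torch

# Hypothetical palette: RGB color -> class index (adapt to your dataset)
color_to_class = {
    (0, 0, 0): 0,      # e.g. background
    (255, 0, 0): 1,
    (0, 255, 0): 2,
}

def mask_to_class_indices(mask_rgb, mapping):
    """Convert a uint8 RGB mask of shape [H, W, 3] to a LongTensor [H, W]."""
    # Pixels whose color is not in the mapping keep index 0 (an assumption;
    # you might prefer a dedicated "ignore" index instead).
    target = torch.zeros(mask_rgb.shape[:2], dtype=torch.long)
    for color, idx in mapping.items():
        color_t = torch.tensor(color, dtype=mask_rgb.dtype)
        # True where all three channels match this color
        matches = (mask_rgb == color_t).all(dim=-1)
        target[matches] = idx
    return target

# Tiny 2x2 example mask
mask = torch.tensor(
    [[[0, 0, 0], [255, 0, 0]],
     [[0, 255, 0], [255, 0, 0]]],
    dtype=torch.uint8,
)
target = mask_to_class_indices(mask, color_to_class)
print(target)  # tensor([[0, 1], [2, 1]])
```

This expects the mask in [H, W, 3] layout, as you would get from converting a PIL image via NumPy. If your mask is stored channels-first, permute it before calling the helper.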

Let us know if and where you get stuck.
