Input data for Multi Class Segmentation

I am building a multi-class image segmentation model. Each image could contain up to 11 classes.

In the mask image, each pixel is marked with the number of the class it belongs to. So,
Image shape - [3, 224, 224]
Mask shape - ?

How should masks be shaped?

For a multi-class segmentation you could use nn.CrossEntropyLoss with a model output in the shape [batch_size, nb_classes, height, width] and a target containing class indices in the range [0, nb_classes-1] in the shape [batch_size, height, width]:

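import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
batch_size, nb_classes, height, width = 2, 10, 24, 24
# Model output (logits): [batch_size, nb_classes, height, width]
output = torch.randn(batch_size, nb_classes, height, width, requires_grad=True)
# Target (class indices): [batch_size, height, width]
target = torch.randint(0, nb_classes, (batch_size, height, width))

loss = criterion(output, target)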

Thanks for your immediate response @ptrblck.

So, my target has this shape,
Target = [nb_classes, height, width]

Batch size gets added later by the Dataloader.

Is this ok?

No, check my code which shows that the target does not have a class dimension. You can copy/paste my code as it’s also executable.
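If a mask were stored one-hot encoded in the shape [nb_classes, height, width], it could be converted to class indices before computing the loss. The following is only a minimal sketch: the one-hot mask is generated randomly here, and the class dimension is assumed to be dim 0.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

nb_classes, height, width = 10, 24, 24

# Hypothetical one-hot mask in the shape [nb_classes, height, width]
class_indices = torch.randint(0, nb_classes, (height, width))
one_hot_target = F.one_hot(class_indices, nb_classes).permute(2, 0, 1)

# argmax over the class dimension recovers the index mask [height, width]
target = one_hot_target.argmax(dim=0)

# The batch dimension is added manually here; a DataLoader would add it for you
output = torch.randn(1, nb_classes, height, width, requires_grad=True)
loss = nn.CrossEntropyLoss()(output, target.unsqueeze(0))
```
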

Sorry for not explaining completely.

My target image is already a single-channel image containing pixel values based on the class each pixel belongs to.
Each pixel can belong to only 1 class.
So, target is,

  • grayscale
  • shape : [224, 224]

There are 10 possible classes including background

In this case the target would have the unbatched shape [height, width] and the batched shape [batch_size, height, width], which is correct.
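To illustrate how the batch dimension gets added, here is a small sketch with a hypothetical dataset (the class name and the random data are assumptions, not the poster's actual code). Each sample returns an image in [3, 224, 224] and an index mask in [224, 224], and the DataLoader stacks them into batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class RandomSegmentationDataset(Dataset):
    """Hypothetical dataset returning random images and class-index masks."""

    def __init__(self, nb_samples=4, nb_classes=10, size=224):
        self.nb_samples = nb_samples
        self.nb_classes = nb_classes
        self.size = size

    def __len__(self):
        return self.nb_samples

    def __getitem__(self, idx):
        image = torch.randn(3, self.size, self.size)
        # Unbatched mask [height, width] holding class indices as torch.long
        mask = torch.randint(0, self.nb_classes, (self.size, self.size))
        return image, mask


loader = DataLoader(RandomSegmentationDataset(), batch_size=2)
images, masks = next(iter(loader))
print(images.shape)  # torch.Size([2, 3, 224, 224])
print(masks.shape)   # torch.Size([2, 224, 224])
```

When loading real grayscale masks, keep the raw integer pixel values and cast them to torch.long. Note that torchvision's ToTensor scales uint8 images to the range [0, 1], which would destroy the class indices.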

The model is training now. Thanks

I am also trying to create a ViT with the same dataset. This model is not returning the correct output shape, which should be [2, 10, 224, 224], i.e. [B, C, H, W].
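The thread does not show what the ViT actually returns, so the following is only a sketch under assumptions: the encoder is assumed to output patch tokens in the shape [B, N, D] without a class token, and the patch size of 16 and embedding size of 64 are made-up values. In that case the tokens could be reshaped into a feature map, projected to the class dimension, and upsampled to the input resolution.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, nb_classes, img_size = 2, 10, 224
patch_size, embed_dim = 16, 64          # assumed values
grid = img_size // patch_size           # patches per side

# Stand-in for the encoder output: [B, N, D] with N = grid * grid
tokens = torch.randn(batch_size, grid * grid, embed_dim)

# [B, N, D] -> [B, D, grid, grid]
feature_map = tokens.transpose(1, 2).reshape(batch_size, embed_dim, grid, grid)

# 1x1 convolution maps the embedding dimension to class logits
head = nn.Conv2d(embed_dim, nb_classes, kernel_size=1)
logits = head(feature_map)

# Upsample to the input resolution: [B, C, H, W]
logits = F.interpolate(logits, size=(img_size, img_size),
                       mode="bilinear", align_corners=False)
print(logits.shape)  # torch.Size([2, 10, 224, 224])
```

The resulting logits can be passed to nn.CrossEntropyLoss together with a target in the shape [B, H, W].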