Semantic Segmentation Loss Function & Data Format Help

Hi there, I was wondering if somebody could help me with semantic segmentation.

I am using the segmentation_models_pytorch library to train a Unet on the VOC2012 dataset.

I have not trained semantic segmentation models before, so I am not sure what form my data should be in. Specifically, I am not sure what loss function to use, and what format my data needs to be in to go into that loss function.

So far:
The input to my network is a bunch of images in the form:
[B, C, H, W]

This is currently [12, 3, 512, 512]

The output from my network is in the form:
[B, NUMCLASSES, H, W]

This is currently [12, 20, 512, 512]

I am loading in my normal images and segmentation truths using PIL.
So my input images have 3 channels, and my segmentation images have 1 channel (in PIL’s P mode).

So my input image goes into the network and outputs a shape of [12, 20, 512, 512]
My ground truth images are in the shape of [12, 1, 512, 512]
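
In case it matters, my loading step looks roughly like the sketch below (placeholder paths, and I have left out the resize to 512x512 and any augmentation), with the mask kept in P mode so its pixel values are already class/palette indices:

from PIL import Image
import numpy as np
import torch
import torchvision.transforms.functional as TF

# placeholder paths, not my real file layout
img = Image.open("JPEGImages/some_image.jpg").convert("RGB")
mask = Image.open("SegmentationClass/some_image.png")  # stays in P mode

img_t = TF.to_tensor(img)                                      # [3, H, W], float in [0, 1]
mask_t = torch.from_numpy(np.array(mask)).long().unsqueeze(0)  # [1, H, W], class indices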

My question:
What loss function should I use, and what format / shapes should my data be in?
Is it possible to just shove my outputs into a loss function as they are and calculate a loss, or do I need to reshape them in some way?
Do I have to calculate an output prediction with output.argmax(1) before passing it to the loss function?

Thanks for your time.

The shapes look almost right. For a multi-class segmentation use case you could use nn.CrossEntropyLoss as the criterion, which expects the model output to contain logits in the shape [batch_size, nb_classes, height, width]. The target should have the shape [batch_size, height, width] (remove dim1 in your script via target = target.squeeze(1)) and should contain the class indices in the range [0, nb_classes-1].

Assuming you are dealing with 20 classes, here is a small code example:

import torch
import torch.nn as nn

output = torch.randn(2, 20, 24, 24, requires_grad=True)  # logits: [batch_size, nb_classes, H, W]
target = torch.randint(0, 20, (2, 24, 24))                # class indices: [batch_size, H, W]
criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)

No, nn.CrossEntropyLoss expects the raw logits for each class, so don't apply argmax before the loss. You could use torch.argmax(output, dim=1) to compute the predictions, where each pixel would contain the predicted class index.
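
For example, reusing the output and target tensors from the snippet above, a small sketch of how you could compute the per-pixel predictions and a quick pixel accuracy for evaluation (not for the loss):

pred = torch.argmax(output, dim=1)           # [2, 24, 24], predicted class index per pixel
pixel_acc = (pred == target).float().mean()  # simple pixel accuracy as a sanity check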