Dataset class with multiple input/output images

It’s a segmentation problem with multiple classes: 1 to 3 images as input, depending on the experiment, and up to 10 classes to segment.
We could simplify the problem to single-class segmentation, but we would also like to try to detect them all at once.

Does the number of input images change for every sample?
I.e. does one sample have only 1 input image while another has 3?

Usually all classes of a segmentation are saved in one segmentation map (similar to an image), so that we could possibly return just one segmentation target containing all 10 classes.
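For the non-overlapping case, a minimal sketch of such a combined target could look like this (sizes are placeholders); each pixel stores a single class index, which is the format e.g. nn.CrossEntropyLoss expects:

```python
import torch

# Placeholder sizes; adjust to your data
height, width, nb_classes = 224, 224, 10

# Single multi-class segmentation map: each pixel holds one class index in [0, nb_classes - 1]
target = torch.randint(0, nb_classes, (height, width))  # dtype torch.int64
print(target.shape)     # torch.Size([224, 224])
print(target.unique())  # the class indices present in this map
```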

  • No, the number of input images will be fixed, so we’ll develop one model per number of inputs.
  • We would like to keep them separated. Is that doable? I used to work with Keras, and it’s doable there.

What would be the advantage of having 10 different targets?
Could you explain how you’ve processed the targets in Keras?

  • Some outputs might have overlaps
  • Keras considers multiple inputs/outputs as a list, so it’s easy to develop such models.

If the outputs have overlaps, you’re dealing with a multi-label classification use case, i.e. each pixel might belong to more than one class?
If so, you could still use a single segmentation map and just return multi-label predictions.

I still don’t understand how Keras deals with this problem.
For example, if one pixel belongs to class1 and class2 at the same time, would you force your model to learn it as class1 in one iteration and class2 in another?
Would your model be able to predict more than one class for each pixel then?

I’m digging a bit deeper, as your use case is most likely possible to implement; I just have my doubts about the correctness.

What happens if the classes overlap, i.e. multi-label for each pixel? How do you prepare the mask for that? I have 4 separate masks corresponding to the same input image. How should I combine them, as each pixel can belong to 1 or more classes?

In this case, each channel would correspond to a specific class.
A 1 at a particular pixel position would represent the occurrence of the corresponding class, while a 0 would be used for its absence.
Also, for this use case you could use nn.BCE(WithLogits)Loss as your criterion.
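A minimal sketch of this setup, using placeholder shapes (4 classes here, matching the masks mentioned above), could look as follows; the random tensors only stand in for real model outputs and annotations:

```python
import torch
import torch.nn as nn

batch_size, nb_classes, height, width = 2, 4, 224, 224

# Raw model outputs (logits) with one channel per class
logits = torch.randn(batch_size, nb_classes, height, width, requires_grad=True)

# Multi-label target: a 1 marks the presence of a class at a pixel, a 0 its absence;
# a pixel may be active in several channels at the same time
target = torch.randint(0, 2, (batch_size, nb_classes, height, width)).float()

criterion = nn.BCEWithLogitsLoss()
loss = criterion(logits, target)
loss.backward()
```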


That’s exactly what I did. So my labels now have the shape [batch, 4, 224, 224]. The training is fine, but I’m not able to calculate the IoU having 4 separate masks. Any help on how that can be achieved so that I can get a per-class IoU?

This post with @tom’s implementation might be useful.

Thanks, I’ll go through that. However, will I need to calculate the IoU per class by splitting the labels into separate classes?

You would most likely have to process each class channel separately with its corresponding output channel.
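This is not @tom’s implementation from the linked post, but one possible sketch of a per-class IoU that treats each channel independently (the threshold and eps values are arbitrary choices):

```python
import torch

def per_class_iou(logits, target, threshold=0.5, eps=1e-6):
    # logits and target have the shape [batch_size, nb_classes, height, width]
    preds = (torch.sigmoid(logits) > threshold).float()
    dims = (0, 2, 3)  # reduce over batch and spatial dims, keep the class dim
    intersection = (preds * target).sum(dim=dims)
    union = preds.sum(dim=dims) + target.sum(dim=dims) - intersection
    return (intersection + eps) / (union + eps)  # shape: [nb_classes]

logits = torch.randn(2, 4, 224, 224)
target = torch.randint(0, 2, (2, 4, 224, 224)).float()
print(per_class_iou(logits, target))  # one IoU value per class
```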

Does that mean my predictions will also have a shape like [batch, 4, 224, 224]? How would I turn my logits into predictions? If I use argmax along the channel dimension, I will only get 1 prediction per pixel. Do I run the logits through a sigmoid layer and then threshold the probabilities using pred = probs > 0.5?

Yes, the output shape of your model would be the same as the targets in this use case.
To get the predictions you could apply preds = logits > 0. or alternatively preds = torch.sigmoid(logits) > 0.5, which would yield the same result.
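A quick sketch to convince yourself of the equivalence (sigmoid(0) == 0.5, and sigmoid is monotonic):

```python
import torch

logits = torch.randn(2, 4, 224, 224)
preds_a = logits > 0.
preds_b = torch.sigmoid(logits) > 0.5
print(torch.equal(preds_a, preds_b))  # True
```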


Hi @ptrblck, thanks for your explanation above. I am also working on training an image segmentation model with 5 classes, where parts of 1-2 classes overlap with each other. See this image for example:
Here, the left is the original grayscale image and the right is the segmentation map. The green portion remains the same across different structures, but in some places the green segment might overlap a bit with the blue one. For example, see the middle bottom segmentation. In terms of annotation, the green annotation is done for an area first and then the blue annotation is done on top of that, i.e. the pixels at that location belong to two classes.
So, how do I represent the dataset and train the model for this scenario? Any suggestions? I am able to stack the segmentations belonging to different classes across channels at the input, but how do I deal with the overlap during training? Will the model be able to learn that overlap?
Thanks :slight_smile:

You can stack the one-hot encoded segmentation maps to create a target in the shape [batch_size, nb_classes, height, width], where each class channel indicates whether the corresponding class is active or not by using ones and zeros, respectively.
E.g. assuming the blue class represents class0 while the green represents class1, the target would have ones in channel0 and channel1 for these pixel locations.

To train such a multi-label segmentation you could use nn.BCEWithLogitsLoss as the criterion.
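A minimal sketch of building such a target from separate binary masks could look like this; the mask creation is a placeholder for however the 5 per-class annotations are actually loaded:

```python
import torch

height, width, nb_classes = 224, 224, 5

# Placeholder: one binary mask per class, e.g. loaded from separate annotation files
masks = [torch.randint(0, 2, (height, width)) for _ in range(nb_classes)]

# Stack along a new class dimension -> [nb_classes, height, width]
target = torch.stack(masks, dim=0).float()

# A pixel annotated as both "blue" (class0) and "green" (class1) simply has a 1
# in channel0 and channel1; the DataLoader later adds the batch dimension,
# yielding [batch_size, nb_classes, height, width]
print(target.shape)         # torch.Size([5, 224, 224])
print(target[:, 100, 100])  # multi-hot vector for a single pixel
```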


Thanks a lot @ptrblck. How does this work when two encodings overlap over a small area? Will the model be able to differentiate between those overlapping labels if they are one-hot encoded? Btw, I am using a DeepLabV3 model for this task and am also looking at UNet. Thanks :slight_smile:

Since each class channel contains a separate one-hot encoded target map, these particular pixels will have multiple active classes in the class dimension.
The complete target is multi-hot encoded.


Hi @ptrblck, thanks for your guidance. I got the mask one-hot encoded, i.e. a binary image with each label across channels, so currently my mask shape looks like [height, width, n_labels]. Training the model should work fine now, but how do I get the output with the color coding? Since the input is binary, i.e. 1’s and 0’s, the output will have n_labels channels but will still be binary as well, i.e. 1’s and 0’s.
So, how do I get the color-coded output? Do I need to use a bitwise_and or bitwise_or with the channels and colors on the output to get the mask shown above?
Thanks :slight_smile:

Your target should have the class dimension in dim1, such that its shape is [batch_size, nb_classes, height, width] for a multi-label segmentation.
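If the mask is currently stored channels-last as described, a small sketch of the permute could look like this:

```python
import torch

height, width, n_labels = 224, 224, 4  # placeholder sizes
mask = torch.randint(0, 2, (height, width, n_labels)).float()  # [H, W, C]

# Move the class dimension to the front -> [n_labels, height, width];
# the DataLoader then prepends the batch dimension
mask = mask.permute(2, 0, 1).contiguous()
print(mask.shape)  # torch.Size([4, 224, 224])
```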

To create the color-coded output, you could map the predicted classes back to color codes. However, it depends on your use case how you would like to visualize overlapping classes; e.g. you could plot the “highest/lowest” class or even mix the colors.
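As one example of the color-mixing approach, a sketch could look like the following; the color table, shapes, and the simple averaging of active class colors are all placeholder choices:

```python
import torch

# Placeholder color table: one RGB color per class
colors = torch.tensor([
    [0, 0, 255],    # class0: blue
    [0, 255, 0],    # class1: green
    [255, 0, 0],    # class2: red
    [255, 255, 0],  # class3: yellow
], dtype=torch.float32)

# Placeholder multi-label predictions in the shape [batch_size, nb_classes, height, width]
preds = (torch.sigmoid(torch.randn(1, 4, 224, 224)) > 0.5).float()

# Average the colors of all active classes per pixel; pixels without any active
# class stay black (their weight row is all zeros)
weights = preds / preds.sum(dim=1, keepdim=True).clamp(min=1)
color_img = torch.einsum('bchw,cd->bdhw', weights, colors)  # [batch_size, 3, height, width]
print(color_img.shape)  # torch.Size([1, 3, 224, 224])
```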