Hi @KFrank
First of all, thank you for your explanations and contributions. I would be glad to have the knowledge I have gained from this thread confirmed, and to ask for a few clarifications.
- The difference between single-label and multi-label multi-class segmentation is the presence of overlap, i.e. a case where a given pixel can be annotated as belonging to more than one target class.
- The go-to loss function for single-label multi-class is `CrossEntropyLoss`, while for multi-label multi-class it is `BCEWithLogitsLoss`. In either task, the logits (the network's final-layer output) are fed to the loss function together with the stacked target, both of dimension `[nBatch, nClasses, height, width]`. A combination with dice loss can also be experimented with.
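To make the shapes concrete, here is a minimal sketch of how the two losses differ in the target they expect (the sizes and variable names such as `n_batch` and `n_classes` are illustrative assumptions on my part, not something from the thread):

```python
import torch
import torch.nn as nn

n_batch, n_classes, height, width = 4, 3, 8, 8

# Raw, unnormalized scores (logits) from the network's final layer.
logits = torch.randn(n_batch, n_classes, height, width)

# Single-label multi-class: CrossEntropyLoss is typically given integer
# class indices of shape [nBatch, height, width] (recent PyTorch versions
# also accept per-class probabilities of the same shape as the logits).
index_target = torch.randint(0, n_classes, (n_batch, height, width))
ce = nn.CrossEntropyLoss()(logits, index_target)

# Multi-label multi-class: BCEWithLogitsLoss expects a multi-hot float
# target of the same shape as the logits, [nBatch, nClasses, height, width].
multi_hot_target = torch.randint(0, 2, (n_batch, n_classes, height, width)).float()
bce = nn.BCEWithLogitsLoss()(logits, multi_hot_target)
```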
A typical instance of such a stacked target for three classes over a 3x3 image:

```python
stacked_one_hot_encoded_targets = torch.tensor([[[1., 0., 0.],
                                                 [1., 0., 0.],
                                                 [0., 0., 0.]],
                                                [[0., 1., 1.],
                                                 [0., 1., 0.],
                                                 [1., 1., 1.]],
                                                [[0., 0., 1.],
                                                 [0., 1., 1.],
                                                 [1., 0., 1.]]])
```
Is this what you referred to as multi-hot encoded?
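Reusing the `stacked_one_hot_encoded_targets` tensor above as the multi-hot target for a single sample, the "combination with dice loss" I mentioned could look roughly like the sketch below (the soft-dice formulation and the equal weighting of the two terms are my own assumptions, not something confirmed in the thread):

```python
import torch
import torch.nn as nn

def soft_dice_loss(logits, targets, eps=1e-6):
    # Per-class soft dice computed from sigmoid probabilities.
    probs = torch.sigmoid(logits)
    dims = (0, 2, 3)                      # sum over batch and spatial dims
    intersection = (probs * targets).sum(dims)
    cardinality = probs.sum(dims) + targets.sum(dims)
    dice = (2.0 * intersection + eps) / (cardinality + eps)
    return 1.0 - dice.mean()

# Add a batch dimension to the 3-class, 3x3 multi-hot target shown above.
target = stacked_one_hot_encoded_targets.unsqueeze(0)   # [1, 3, 3, 3]
logits = torch.randn_like(target)                       # stand-in network output

loss = nn.BCEWithLogitsLoss()(logits, target) + soft_dice_loss(logits, target)
```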
In addressing class imbalance, my approach is to take the per-class sum of the 1's as the class weight to be passed into my loss function, i.e. `[2, 6, 5]` in the case of `stacked_one_hot_encoded_targets` above.
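As a sanity check of those counts, and to show one way they could be turned into weights, here is a sketch (the negatives-over-positives `pos_weight` scheme below is a common choice for `BCEWithLogitsLoss`, but it is my assumption rather than necessarily what was described above):

```python
import torch
import torch.nn as nn

# Per-class count of 1's in the stacked target above: tensor([2., 6., 5.])
pos_counts = stacked_one_hot_encoded_targets.sum(dim=(1, 2))

# One common weighting scheme: pos_weight = (#negative pixels) / (#positive pixels)
# per class, which BCEWithLogitsLoss applies to the positive term of each class,
# boosting rare classes such as class 0 here.
pixels_per_class = stacked_one_hot_encoded_targets[0].numel()      # 9
pos_weight = (pixels_per_class - pos_counts) / pos_counts          # [3.5, 0.5, 0.8]

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.view(-1, 1, 1))
```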