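My implementation doesn’t work like BCEWithLogitsLoss, and I’m not sure you need that loss.
In my approach, the model outputs log probabilities normalized over all classes (a log_softmax along the class dimension), so exponentiating them gives probabilities that sum to 1 at each pixel: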
print(torch.exp(output[0, :, 0, 0]))
> tensor([ 0.2759, 0.3390, 0.0344, 0.3208, 0.0299])
print(torch.exp(output[0, :, 0, 0]).sum())
> tensor(1.0)
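Here is a minimal sketch of this single-label setup. It assumes the model's last step is `F.log_softmax(logits, dim=1)` on a segmentation-style output of shape `[batch, classes, height, width]`, and it pairs that with `F.nll_loss`, the usual companion loss for log probabilities. The random logits and targets are placeholders, not real model output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Placeholder logits shaped [batch, classes, height, width]
logits = torch.randn(1, 5, 4, 4)

# log_softmax over the class dimension (dim=1) normalizes across classes per pixel
log_probs = F.log_softmax(logits, dim=1)
probs = torch.exp(log_probs)

# The class probabilities at every pixel sum to 1
pixel_sums = probs.sum(dim=1)  # shape [1, 4, 4]
print(probs[0, :, 0, 0], pixel_sums[0, 0, 0])

# Single-label target: exactly one class index per pixel
target = torch.randint(0, 5, (1, 4, 4))
loss = F.nll_loss(log_probs, target)

# Prediction: the single most likely class at each pixel
pred = probs.argmax(dim=1)  # shape [1, 4, 4]
```

Because the classes compete for the same probability mass, this setup can only assign one class per pixel.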
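If you use BCEWithLogitsLoss, your model should output raw logits, because the loss applies the sigmoid internally. To get probabilities, apply a sigmoid (torch.sigmoid or nn.Sigmoid) to the output yourself; each class is then treated independently:
output = torch.sigmoid(x)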
print(output[0, :, 0, 0])
> tensor([ 0.7093, 0.7499, 0.2333, 0.7394, 0.2091])
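A minimal sketch of the multi-label alternative follows, again with placeholder random tensors. It assumes a multi-hot float target of the same shape as the logits (each class is 0 or 1 at each pixel) and shows that `nn.BCEWithLogitsLoss` on raw logits matches `nn.BCELoss` on sigmoid outputs. The logits version is preferred because it is more numerically stable.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 5, 4, 4)  # raw logits, no activation applied
# Multi-label target: each class is independently 0 or 1 at each pixel
target = torch.randint(0, 2, (1, 5, 4, 4)).float()

# BCEWithLogitsLoss applies the sigmoid internally
loss_logits = nn.BCEWithLogitsLoss()(x, target)
# Same value, computed less stably: sigmoid first, then BCELoss
loss_manual = nn.BCELoss()(torch.sigmoid(x), target)

probs = torch.sigmoid(x)
# Per-class probabilities are independent and need not sum to 1
print(probs[0, :, 0, 0], probs[0, :, 0, 0].sum())
```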
Do you need this kind of probability?
With this approach, you could predict that every class whose probability exceeds a threshold (e.g. 0.5) is present at that pixel position.
In my example, classes 0, 1 and 3 would be present at the first pixel position.
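Thresholding can be sketched like this, reusing the five first-pixel sigmoid values printed above. The 0.5 threshold is only an example value and would normally be tuned for the task; the second part applies the same idea to a placeholder output map of shape `[batch, classes, height, width]`.

```python
import torch

# Sigmoid outputs for the first pixel, copied from the printed tensor
pixel_probs = torch.tensor([0.7093, 0.7499, 0.2333, 0.7394, 0.2091])

threshold = 0.5  # example value; tune per task
present = pixel_probs > threshold
present_classes = present.nonzero(as_tuple=True)[0]
print(present_classes.tolist())  # [0, 1, 3]

# For a whole output map: one boolean mask per class and pixel
logits = torch.randn(1, 5, 4, 4)  # placeholder logits
masks = torch.sigmoid(logits) > threshold
```

Unlike the argmax of a softmax output, this can mark zero, one, or several classes at the same pixel.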