In multi-class segmentation, should the number of output channels be C or C+1?

For example, in a dataset which contains cats and dogs:

In the C==1 case, the requirement is:
just segmenting out cats, treating everything else as background (of course, the target will just label cat pixels as 1 and set all other pixels to 0).
In this case, should the number of output channels be 1 or 2?
In my opinion, it can be either 1 or 2: if I set the output channels to 1, I can just use nn.BCEWithLogitsLoss, and if I set them to 2, I must one-hot encode the targets first and then use nn.BCEWithLogitsLoss.
Am I right about this first case?

In the C>1 case, the requirement is:
segmenting out both cats and dogs.
In this case, should the number of output channels be 2 or 2+1 (background)? In other words, should it be C or C+1 (background) if C>1?
In my opinion, it must be 3 (C+1), since the background must be treated as a class. Am I right about this second case?
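
For concreteness, here is a minimal sketch of the setups I have in mind (the shapes and random data are just placeholders, and I'm assuming nn.CrossEntropyLoss with class-index targets for the second case):

```python
import torch
import torch.nn as nn

N, H, W = 2, 4, 4

# Case 1, option A: 1 output channel, float mask (cats = 1, everything else = 0)
logits = torch.randn(N, 1, H, W)
target = torch.randint(0, 2, (N, 1, H, W)).float()
loss = nn.BCEWithLogitsLoss()(logits, target)

# Case 1, option B: 2 output channels, one-hot encoded target
logits = torch.randn(N, 2, H, W)
target_onehot = torch.cat([1.0 - target, target], dim=1)  # [background, cat]
loss = nn.BCEWithLogitsLoss()(logits, target_onehot)

# Case 2: C=2 classes + background -> 3 output channels, class-index target
logits = torch.randn(N, 3, H, W)
target_idx = torch.randint(0, 3, (N, H, W))  # 0=background, 1=cat, 2=dog
loss = nn.CrossEntropyLoss()(logits, target_idx)
```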

For a binary classification use case, you can use either one or two output channels.
For a single channel output, you could use nn.BCEWithLogitsLoss.
If you are using two output channels, you could treat your use case as a

  • multi-class classification (only one valid class per pixel) and use nn.CrossEntropyLoss
  • or as a multi-label classification (zero, one, or more classes valid per pixel), in which case you would again use nn.BCEWithLogitsLoss

The background would also represent a class as answered in your other topic.
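
To make the difference between the two two-channel interpretations concrete, here is a minimal sketch (shapes are arbitrary placeholders):

```python
import torch
import torch.nn as nn

N, H, W = 2, 4, 4
output = torch.randn(N, 2, H, W)  # two output channels (logits)

# Multi-class: exactly one class per pixel -> LongTensor of class indices
target_indices = torch.randint(0, 2, (N, H, W))
loss_ce = nn.CrossEntropyLoss()(output, target_indices)

# Multi-label: zero, one, or more classes per pixel -> float multi-hot mask
target_multihot = torch.randint(0, 2, (N, 2, H, W)).float()
loss_bce = nn.BCEWithLogitsLoss()(output, target_multihot)
```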

I got a binary segmentation problem working, but am now modifying it to be multi-class with labels [0, 1, 2 … 8]. I changed my loss function to nn.CrossEntropyLoss (previously nn.BCEWithLogitsLoss); however, I am now seeing this error:

File "/home/user/miniconda3/envs/3dunet/lib/python3.7/site-packages/torch/nn/functional.py", line 1848, in nll_loss
out_size, target.size()))
ValueError: Expected target size (2, 80, 160, 160), got torch.Size([2, 1, 80, 160, 160])

I am not sure where this extra dimension is coming from. I also had to change the output of my network to 9 channels (C+1), as it was only 1 for the binary segmentation.

Any ideas? Thanks,

Kyle

Based on the expected target shape, I assume you are dealing with volumetric outputs, where each voxel would be classified into one of the 9 classes?
If that’s correct, you would have to remove dim1 via target = target.squeeze(1).
I don’t know where it’s coming from, but feel free to post the Dataset code (including shape information) so that we could have a look at it.
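
Something like this sketch reproduces the error and the fix (using smaller placeholder shapes than in your error message):

```python
import torch
import torch.nn as nn

N, C, D, H, W = 2, 9, 8, 16, 16              # placeholder sizes
criterion = nn.CrossEntropyLoss()

output = torch.randn(N, C, D, H, W)          # model logits
target = torch.randint(0, C, (N, 1, D, H, W))

# criterion(output, target)                  # raises the ValueError above
loss = criterion(output, target.squeeze(1))  # target is now [N, D, H, W]
```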
