I have a question regarding binary segmentation, as I’m a little confused.
I’m trying to implement binary segmentation using semi-supervised learning. Each batch has an image tensor of shape [64,2,64,64] and a label tensor of shape [64,1,64,64], which is binary if the sample is labeled or filled with -1s if it is unlabeled.
For the loss I’m using BCEWithLogitsLoss, and I want the pixels labeled -1 to be ignored. My question is: when passing the prediction to the loss function, should I argmax it, apply a sigmoid with a threshold, slice it to [:,1,:,:], or pass it directly to the loss function along with the labels?
Sorry, I forgot to clarify: I’m trying to predict the segmentation mask, which I label as 1, with everything else labeled 0.
Passing the logits directly to the loss function is the expected behavior: BCEWithLogitsLoss applies the sigmoid internally, and an argmax or threshold before the loss would not be differentiable. Note, however, that if you are using BCEWithLogitsLoss anyway, you should just resize the output layer to a single “channel.” Otherwise, if you cannot change the output shape, you should use CrossEntropyLoss, since you have a “channel” for each class.
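A minimal sketch of the single-channel setup, assuming a hypothetical 1x1 conv head on 16 feature channels (`head`, `features`, and `masked_bce_with_logits` are illustrative names, not from the thread). Note that `nn.BCEWithLogitsLoss` has no `ignore_index` argument (that exists on losses such as `CrossEntropyLoss` and `NLLLoss`), so here the -1 pixels are masked out by hand before the loss is computed:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical final layer: 16 feature channels -> 1 logit channel per pixel.
head = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1)

def masked_bce_with_logits(logits, labels):
    """Mean BCE-with-logits loss over labeled pixels; pixels labeled -1 are skipped."""
    mask = labels != -1                      # True where a real 0/1 label exists
    if not mask.any():                       # batch is entirely unlabeled
        return logits.sum() * 0.0            # zero loss that still supports backward()
    return F.binary_cross_entropy_with_logits(logits[mask], labels[mask].float())

features = torch.randn(4, 16, 64, 64)                  # stand-in for decoder features
labels = torch.randint(0, 2, (4, 1, 64, 64)).float()   # binary masks, shape [N, 1, H, W]
labels[2:] = -1                                         # last two samples are unlabeled

logits = head(features)        # raw logits [4, 1, 64, 64]: no sigmoid, argmax, or threshold
loss = masked_bce_with_logits(logits, labels)
loss.backward()

# Only at inference time is the sigmoid + threshold applied to get a hard mask.
pred_mask = torch.sigmoid(logits.detach()) > 0.5
```

Boolean indexing averages over labeled pixels only; multiplying an unreduced loss (`reduction="none"`) by the mask and dividing by `mask.sum()` would be an equivalent alternative.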
Either way I would not expect an accuracy difference between the approaches.
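One way to see why no difference is expected: with two classes and logits a (class 0) and b (class 1), the softmax probability of class 1 is e^b / (e^a + e^b) = σ(b − a). So two channels + CrossEntropyLoss compute the same loss as BCEWithLogitsLoss on the logit difference. A quick numerical check on random tensors (shapes chosen arbitrarily), which also shows that argmax over the two channels matches thresholding σ(b − a) at 0.5:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

two_ch = torch.randn(4, 2, 8, 8)            # logits for class 0 and class 1
target = torch.randint(0, 2, (4, 8, 8))     # class indices, shape [N, H, W]

# Option A: two channels + cross-entropy
ce = F.cross_entropy(two_ch, target)

# Option B: one logit per pixel, b - a, + BCE (softmax([a, b])[1] == sigmoid(b - a))
diff = two_ch[:, 1] - two_ch[:, 0]
bce = F.binary_cross_entropy_with_logits(diff, target.float())

# Predictions: argmax over the 2 channels vs. sigmoid(diff) > 0.5, i.e. diff > 0
argmax_pred = two_ch.argmax(dim=1)
sigmoid_pred = (diff > 0).long()
```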
What I did was slice the prediction output, taking the second channel preds[:,1,:,:], and pass it to BCEWithLogitsLoss. However, I’m not sure that is correct; I assumed the second channel corresponds to the second class, which is the one I’m predicting.
If you do that during training, then you are effectively throwing out half of the outputs of the final layer, although the model should be able to “learn around” this quirk and ultimately produce results similar to using all of the channels + CrossEntropyLoss. If you do that on a pretrained model without training, it would be incorrect, as you would be considering the logits of one class without any relation to the other class. Since you are asking about loss functions, I assume you are training, which means the model would “work,” although with wasted computation done in a nonstandard way.
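To make the “throwing out half” point concrete, a small sketch with a hypothetical two-channel 1x1 conv head: when only preds[:, 1] enters BCEWithLogitsLoss, the weights that produce channel 0 receive exactly zero gradient from the loss, so that part of the layer is computed but never trained.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical two-channel final layer, trained with BCE on channel 1 only.
head = nn.Conv2d(in_channels=16, out_channels=2, kernel_size=1)
features = torch.randn(2, 16, 32, 32)
labels = torch.randint(0, 2, (2, 32, 32)).float()

preds = head(features)                                      # [2, 2, 32, 32]
loss = F.binary_cross_entropy_with_logits(preds[:, 1], labels)
loss.backward()

# Gradient slices for the filter producing each output channel.
grad_ch0 = head.weight.grad[0]   # channel 0 never touches the loss
grad_ch1 = head.weight.grad[1]
```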
- Indexing one out of two channels + BCE (strange, but “works” if trained this way)
- Last layer with a single output channel + BCE (expected standard practice)
- Last layer with two output channels + CE (should be equivalent to the previous option, but somewhat nonstandard)
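For the two-channel + CE option with the label format from the question, `nn.CrossEntropyLoss` does accept `ignore_index`, so the -1 pixels can be skipped directly. It expects integer class indices of shape [N, H, W], so the singleton channel dimension of the [N, 1, H, W] labels is squeezed away. A sketch on random logits (the tensors here are stand-ins, not a real model):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

criterion = nn.CrossEntropyLoss(ignore_index=-1)        # -1 pixels contribute no loss

logits = torch.randn(4, 2, 64, 64, requires_grad=True)  # two-channel output [N, 2, H, W]
labels = torch.randint(0, 2, (4, 1, 64, 64))            # labels shaped as in the question
labels[2:] = -1                                         # unlabeled samples

target = labels.squeeze(1).long()                       # class indices, shape [N, H, W]
loss = criterion(logits, target)
loss.backward()
```

One caveat, as far as I know: if every pixel in a batch is -1, the default mean reduction divides by zero and returns nan, so fully unlabeled batches should be skipped or handled by a separate unsupervised loss.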