What loss function for binary UNet?

Hi, I’m trying to do semantic segmentation for a binary case (i.e. background = 0 and class-of-interest = 1) using UNet. My last deconvolution layer looks as follows, with num_classes = 1:

self.final2 = nn.Conv2d(num_filters, num_classes, kernel_size=1)

My output from the model has one channel (i.e. output shape (N, 1, 256, 256)), where a higher probability means class 1 and a lower one means background. I was wondering:

1- What loss function should I use?
2- Is my way of calculating the confusion matrix below correct?

from sklearn.metrics import confusion_matrix

probability_class = torch.exp(output)
prediction = torch.argmax(probability_class, 1)
tn, fp, fn, tp = confusion_matrix(gt.view(-1), prediction.view(-1)).ravel()

Since your output has only one channel, using probability_class = torch.sigmoid(output) to normalize it to [0, 1] and then prediction = probability_class > threshold may be a better choice. As for the loss function, Sigmoid + MSELoss is OK.

Note that output has only one channel, so probability_class will also have only one channel, which means your torch.argmax(probability_class, 1) will always return zero.
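
A minimal sketch of that single-channel setup, assuming dummy shapes and a 0.5 threshold (neither is specified in the thread):

import torch
import torch.nn as nn

# Dummy stand-ins for the model output (raw logits) and the 0/1 ground-truth mask.
output = torch.randn(4, 1, 256, 256)        # (N, 1, H, W) logits
gt = torch.randint(0, 2, (4, 1, 256, 256))  # (N, 1, H, W) labels in {0, 1}

criterion = nn.MSELoss()  # as suggested above; nn.BCEWithLogitsLoss (applied to the raw logits) is a common alternative

probability_class = torch.sigmoid(output)        # normalize logits to [0, 1]
loss = criterion(probability_class, gt.float())  # compare probabilities against the 0/1 mask

prediction = (probability_class > 0.5).long()    # threshold instead of argmax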

Or you can set num_classes = 2, i.e. self.final2 = nn.Conv2d(num_filters, num_classes, kernel_size=1), and use nn.CrossEntropyLoss. In this case, a direct argmax (in eval mode) is OK.
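
And a sketch of the two-channel variant (again with assumed shapes):

import torch
import torch.nn as nn

output = torch.randn(4, 2, 256, 256)     # (N, 2, H, W) logits, one channel per class
gt = torch.randint(0, 2, (4, 256, 256))  # (N, H, W) long labels in {0, 1}, no channel dim

criterion = nn.CrossEntropyLoss()  # applies log-softmax over the channel dim internally
loss = criterion(output, gt)

prediction = torch.argmax(output, dim=1)  # (N, H, W) tensor of 0s and 1s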


Thanks @Eta_C. If I choose num_classes = 2 and use nn.CrossEntropyLoss, does argmax in eval mode give me indices 0 and 1? I’m asking because my ground-truth data is 0 and 1, and I want the prediction output to match.

Yes, it will.

Also, since this is a semantic segmentation task, there are other losses you can try, for example Focal Loss, Dice Loss, IoU Loss, Tversky Loss, Lovász Loss, and so on.
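
For reference, a soft Dice loss for the single-channel case could be sketched like this (a generic formulation, not code from this thread; the smooth constant is an assumption to avoid division by zero):

import torch

def dice_loss(logits, target, smooth=1.0):
    # logits: raw model output, shape (N, 1, H, W); target: 0/1 mask of the same shape.
    prob = torch.sigmoid(logits)
    prob = prob.view(prob.size(0), -1)
    target = target.view(target.size(0), -1).float()
    intersection = (prob * target).sum(dim=1)
    dice = (2.0 * intersection + smooth) / (prob.sum(dim=1) + target.sum(dim=1) + smooth)
    return 1.0 - dice.mean()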
