# How to calculate loss for binary images?

Hi,

I am trying to use an encoder-decoder architecture to predict binary images. What is the appropriate loss function for this? The input is a binary image, the output is a binary image, and the ground truth is a binary image.

Hi Vishal!

I assume that you want to use your loss function to train your network.
I also assume that “output” is the output of your network, that is, the
prediction that your loss function compares to the ground truth.

If output is, indeed, a binary image, e.g., an image where the pixel values
are either 0.0 or 1.0 (or, say, 0 and 255), then you won’t be able to
construct a suitable loss function.

This is because – in order to train your network with a gradient-based
optimization algorithm – your loss needs to be usefully differentiable.
Because your output is a binary image, it can’t be usefully differentiable.
Even if you define it in a way that is differentiable, its gradients will be
zero almost everywhere.
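A minimal sketch of why hard binarization blocks gradients (the tensor here is just an illustrative stand-in for pre-threshold activations):

```python
import torch

x = torch.randn(5, requires_grad=True)  # stand-in for pre-threshold activations
y = (x > 0).float()                     # hard binarization to 0.0 / 1.0

# The comparison is not differentiable, so autograd drops it from the graph:
print(y.requires_grad)  # False: no gradient can flow back through the threshold
```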

I would recommend that you use the output of your final Linear layer,
without any non-linear activation or anything that binarizes it, as your
prediction, and then use BCEWithLogitsLoss as your loss criterion
(your ground truth is binary).
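As a sketch of this approach (all names and shapes below are illustrative, not from the original post):

```python
import torch
import torch.nn as nn

# Illustrative stand-ins for the network output and the ground truth:
logits = torch.randn(4, 1, 64, 64, requires_grad=True)  # raw output, no activation
target = torch.randint(0, 2, (4, 1, 64, 64)).float()    # binary ground-truth image

criterion = nn.BCEWithLogitsLoss()  # applies sigmoid internally, in a stable way
loss = criterion(logits, target)
loss.backward()  # gradients are non-zero almost everywhere
```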

Best.

K. Frank

I was able to backpropagate through the “thresholding” layer by defining a custom autograd function. The output of the model has a pixel value of either 0 or 1. I trained this model using L1 loss, MSE loss, and BCEWithLogitsLoss, but with none of these losses was I able to get the desired output.

A better way would be to use a linear layer followed by a sigmoid output, and then train the model using BCE Loss. The sigmoid activation would make sure that the output stays between 0 and 1, and as the model trains, the output would move towards binary outputs. You can later use binary thresholding to get the final output.
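A sketch of that suggestion (shapes are illustrative; `pre` stands in for the output of the final layer):

```python
import torch
import torch.nn as nn

pre = torch.randn(4, 1, 64, 64, requires_grad=True)  # output of the final layer
prob = torch.sigmoid(pre)                            # keeps the output in (0, 1)
target = torch.randint(0, 2, (4, 1, 64, 64)).float() # binary ground truth

loss = nn.BCELoss()(prob, target)  # train against the binary ground truth
loss.backward()

# At inference time, threshold the probabilities to get a binary image:
binary = (prob > 0.5).long()
```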

```python
class Decoder(nn.Module):

    def __init__(self, input_dim=256, output_dim=3, n_ups=2):
        super(Decoder, self).__init__()

        model = []

        ch = input_dim
        for i in range(n_ups):
            model += [nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)]
            model += [nn.Conv2d(ch, ch // 2, 3, 1, 1)]
            model += [get_norm_layer(ch // 2)]  # user-defined normalization layer
            model += [nn.ReLU()]
            ch = ch // 2

        out_channels = 1
        model += [nn.Conv2d(ch, out_channels, 3, 1, 1)]
        # model += [nn.Tanh()]
        # model += [nn.Sigmoid()]
        model += [nn.Linear(512, 512)]
        model += [nn.Sigmoid()]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        out = self.model(x)
        return out
```

This is the decoder I have, and I am calculating BCELoss with mean reduction. The output images are not at all close to the ground truth.

For example, the ground truth looks like this -

I believe that no loss function can fit this kind of ground truth.

Hi Vishal!

If your predictions – the output of your model – are binary (presumably
because they are binarized by your “thresholding” layer), you will get
zero gradients. You then try to fix this up by using a custom autograd
function that gives you some sort of non-zero gradients.

After thresholding, the damage has largely been done, and trying to
repair that damage with your choice of loss function is not likely to work.

The standard approach – that works quite well in many use cases – is
to train with the output of your final Linear layer as probabilistic logit
predictions (with no “thresholding” layer or other non-linear activation),
and use BCEWithLogitsLoss to compute the loss between your logit
predictions and your binary ground truth.

Then, at inference time, if you want to visualize your predictions, you
can obtain binary predicted images by thresholding the logits against zero, e.g.,

binary_predicted_image = (predicted_image > 0).long()

Try this first, and see whether you get good results (or at least better
results than you get from your other schemes).
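A short inference-time sketch of the thresholding step (the model output here is a random placeholder):

```python
import torch

# Placeholder for the trained model's raw-logit output on one image:
predicted_image = torch.randn(1, 1, 64, 64)

# logit > 0 is equivalent to sigmoid(logit) > 0.5, so no sigmoid is needed:
binary_predicted_image = (predicted_image > 0).long()

# If soft probabilities are wanted (e.g. for visualization), apply sigmoid:
probabilities = torch.sigmoid(predicted_image)
```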

Best.

K. Frank

Hi Mohammed!

Sigmoid feeding into BCELoss is mathematically the same as
BCEWithLogitsLoss, but is numerically less stable.

BCELoss really shouldn’t be used, and feeding the output of the final
Linear layer directly into BCEWithLogitsLoss (without any Sigmoid
or other intervening non-linear activation) is the way to go.
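A small numerical illustration of the instability (the logit value 20 is just a convenient example):

```python
import torch
import torch.nn.functional as F

logit = torch.tensor([20.0])
target = torch.tensor([0.0])

# Stable: works directly in log space and recovers the true loss of ~20.
stable = F.binary_cross_entropy_with_logits(logit, target)

# Unstable: sigmoid(20) rounds to exactly 1.0 in float32, so log(1 - p)
# becomes log(0); BCELoss clamps the result instead of computing it.
unstable = F.binary_cross_entropy(torch.sigmoid(logit), target)

print(stable.item())    # ~20.0 (correct)
print(unstable.item())  # 100.0 (the clamp value, far from the true loss)
```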

Best.

K. Frank