BCE loss when prediction is nearly 1

Hi, I want to use BCELoss to calculate the prediction loss. But at the beginning of training, the predictions are very close to 1, and BCELoss then raises an error. Looking forward to your help!
Here is a toy example:

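```python
import torch
import torch.nn as nn

a = torch.randn(512, 4)
leakyrelu = nn.LeakyReLU(0.2)
# attention weights, normalized over the last dimension
att = nn.Softmax(dim=-1)(-leakyrelu(a))
b = torch.ones(512, 14541, 4)
# weighted sum of b over k, one score per (batch, n)
score = torch.einsum('bnk,bk->bn', b, att)
bce = nn.BCELoss()
# note: BCELoss expects targets in [0, 1]; 3.141592 is not a valid target
label = torch.full((512, 14541), 3.141592)
bce(score, label)
```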

The error is as follows:

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-18-d7e1c60c672a> in <module>()
      8 bce = nn.BCELoss()
      9 label = torch.full((512, 14541), 3.141592)
---> 10 bce(score, label)

2 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in binary_cross_entropy(input, target, weight, size_average, reduce, reduction)
   2760         weight = weight.expand(new_size)
   2761 
-> 2762     return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
   2763 
   2764 

RuntimeError: all elements of input should be between 0 and 1
```

Actually, in my real experiments the error is "Reduce failed to synchronize: device-side assert triggered". I then made the toy example above, and its behavior also seems strange.

Hi Wjk!

The short answer is that score.max() > 1.0.

Because you apply Softmax(), the values along the last dimension of att
sum to one, up to floating-point round-off error.

You “contract” the last dimension of att with torch.ones(). That is,
you multiply the elements of att by 1.0 and then sum along the
last dimension. Therefore the elements of score are all 1.0, up
to floating-point round-off error.
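A minimal sketch of this argument, using much smaller shapes than the original (8 rows and 10 columns instead of 512 and 14541, an arbitrary choice to keep it quick) and an arbitrary random seed: the einsum result matches a plain sum of att over its last dimension, and both equal 1.0 to within float32 tolerance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
a = torch.randn(8, 4)  # smaller than the 512 x 4 in the question
att = nn.Softmax(dim=-1)(-nn.LeakyReLU(0.2)(a))
b = torch.ones(8, 10, 4)

# Contracting the last dimension of att with ones ...
score = torch.einsum('bnk,bk->bn', b, att)
# ... is the same as summing att over its last dimension,
# repeated across the n dimension.
row_sums = att.sum(dim=-1, keepdim=True).expand(-1, 10)

print(torch.allclose(score, row_sums))                # True
print(torch.allclose(score, torch.ones_like(score)))  # True
```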

By default, PyTorch prints tensor values to 5 significant digits, so when
you print score.max() you only see 1.0000. Print
score.max() - 1.0 instead, and you will see a small number that is
slightly greater than zero.
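To see the hidden round-off yourself, something like the sketch below should work. The seed and the reduced middle dimension (3 instead of 14541; every column computes the same sum anyway) are assumptions for illustration. The sign and size of the deviation depend on the random values, so on a given run it may be zero, slightly positive, or slightly negative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed; the deviation below depends on it
att = nn.Softmax(dim=-1)(-nn.LeakyReLU(0.2)(torch.randn(512, 4)))
b = torch.ones(512, 3, 4)  # 3 columns instead of 14541
score = torch.einsum('bnk,bk->bn', b, att)

print(score.max())  # default precision: the round-off is not visible
deviation = (score.max() - 1.0).item()
print(deviation)    # tiny: zero, or on the order of 1e-7 for float32

torch.set_printoptions(precision=10)  # show more digits
print(score.max())
torch.set_printoptions(profile="default")  # restore the default settings
```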

Here, you pass BCELoss values that are (slightly, due to round-off
error) greater than one, so BCELoss raises the exception you posted.
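The range check is easy to trigger in isolation. The sketch below (CPU, float32) feeds BCELoss the smallest float32 value above 1.0, obtained with torch.nextafter, and also shows that an input of exactly 1.0 is accepted, since the BCELoss documentation states that it clamps its log outputs to be greater than or equal to -100. On CUDA, the same range check is likely implemented as a device-side assert, which would plausibly explain the "device-side assert triggered" message from the real experiment.

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()
# The smallest float32 value strictly greater than 1.0.
just_above_one = torch.nextafter(torch.tensor(1.0), torch.tensor(2.0))
print(just_above_one.item() - 1.0)  # about 1.19e-07 (float32 machine epsilon)

target = torch.tensor([1.0])  # a valid target in [0, 1]
raised = False
try:
    bce(just_above_one.reshape(1), target)
except RuntimeError as err:
    raised = True
    print("BCELoss rejected the input:", err)

# Exactly 1.0 is accepted: BCELoss clamps its log terms at -100,
# so log(1 - 1.0) does not turn the loss into inf or nan.
loss_at_one = bce(torch.tensor([1.0]), target)
print(loss_at_one)
```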

Best.

K. Frank

So I need to clamp it before BCELoss, right? For example:

```python
score[score > 1.0] = 1.0
```

Hi Wjk!

Yes, this is a sensible thing to do.

This could also be an appropriate use case for torch.clamp().
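A minimal comparison of the two approaches on a hand-made stand-in tensor (the values are invented for illustration, not taken from the toy example): the masked assignment and torch.clamp() give the same result, but clamp returns a new tensor instead of writing into score in place.

```python
import torch

# Stand-in scores: a few values slightly above 1.0, as round-off would produce.
score = torch.tensor([0.25, 1.0, 1.0000001, 1.0000002])

# Option 1: the masked assignment from the question (in place, so on a copy here).
masked = score.clone()
masked[masked > 1.0] = 1.0

# Option 2: torch.clamp, which leaves score unchanged.
clamped = torch.clamp(score, min=0.0, max=1.0)

print(torch.equal(masked, clamped))  # True
print(clamped.max().item())          # 1.0
```

In both cases, the entries that actually get clamped receive zero gradient with respect to score.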

Note, however, that you should only do this if the values you are forcing
to be <= 1.0 should already be mathematically no greater than 1.0, and
only exceed 1.0 because of round-off error (as in your toy example). If, in
your real use case, the values you pass as input to BCELoss could be
greater than 1.0 even without round-off error, then you should fix that issue.
Simply clamping them to be no greater than 1.0 would just be sweeping
the real problem under the rug.
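One way to follow this advice is to clamp only when the excursion outside [0, 1] is small enough to plausibly be round-off, and to fail loudly otherwise. The helper clamp_roundoff and its default tolerance of 1e-5 are hypothetical choices for illustration, not a PyTorch API; a suitable tolerance depends on the dtype and on how many terms were summed.

```python
import torch

def clamp_roundoff(p: torch.Tensor, tol: float = 1e-5) -> torch.Tensor:
    """Hypothetical helper: clamp p into [0, 1] only if it lies outside
    that range by at most `tol`; larger excursions raise instead."""
    excess = torch.clamp(p - 1.0, min=0.0).max().item()   # amount above 1
    deficit = torch.clamp(-p, min=0.0).max().item()       # amount below 0
    worst = max(excess, deficit)
    if worst > tol:
        raise ValueError(f"values outside [0, 1] by up to {worst:.3g}, "
                         "which is more than round-off")
    return torch.clamp(p, min=0.0, max=1.0)

ok = clamp_roundoff(torch.tensor([0.5, 1.0000001]))
print(ok)  # the round-off excess is clamped away

refused = False
try:
    clamp_roundoff(torch.tensor([0.5, 1.3]))
except ValueError as err:
    refused = True
    print("refused:", err)
```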

Best.

K. Frank


Thanks a lot. It's clear to me now.
Best,
Json