Hi, I want to use BCELoss to calculate the prediction loss. But at the beginning of training, the predictions are very close to 1, and BCELoss then raises an error. Looking forward to your help!
For a toy example:
import torch
import torch.nn as nn
a = torch.randn(512, 4)
leakyrelu = nn.LeakyReLU(0.2)
att = nn.Softmax(dim=-1)(-leakyrelu(a))     # each row of att sums to 1
b = torch.ones(512, 14541, 4)
score = torch.einsum('bnk,bk->bn', b, att)  # contract att with the ones in b
bce = nn.BCELoss()
label = torch.full((512, 14541), 3.141592)
bce(score, label)                           # raises RuntimeError
The error is the following:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-18-d7e1c60c672a> in <module>()
8 bce = nn.BCELoss()
9 label = torch.full((512, 14541), 3.141592)
---> 10 bce(score, label)
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in binary_cross_entropy(input, target, weight, size_average, reduce, reduction)
2760 weight = weight.expand(new_size)
2761
-> 2762 return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
2763
2764
RuntimeError: all elements of input should be between 0 and 1
Actually, in my real experiments the error is Reduce failed to synchronize: device-side assert triggered. I then made the toy experiment above, and it also seems strange.
Because you compute Softmax() over the last dimension, the last dimension of att will sum to one, up to floating-point round-off error.
You then "contract" the last dimension of att with torch.ones(). That is, you multiply the elements of att by 1.0 and sum along the last dimension. Therefore the elements of score are all 1.0, again up to floating-point round-off error.
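You can check this directly on your toy tensors (a minimal, self-contained sketch; the exact residual values depend on the random seed and hardware):

import torch
import torch.nn as nn
torch.manual_seed(0)
a = torch.randn(512, 4)
att = nn.Softmax(dim=-1)(-nn.LeakyReLU(0.2)(a))
b = torch.ones(512, 14541, 4)
score = torch.einsum('bnk,bk->bn', b, att)
# contracting with ones is the same as summing att over its last dimension
print(torch.allclose(score[:, 0], att.sum(dim=-1)))  # True
print((att.sum(dim=-1) - 1.0).abs().max())           # tiny, e.g. tensor(1.1921e-07)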
By default, pytorch prints tensor values to four decimal places, so when you print out score.max() you only see 1.0000. Print out score.max() - 1.0, and you will see a small number that is slightly greater than zero.
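Continuing the snippet above (again, the exact residual is seed- and hardware-dependent):

print(score.max())        # tensor(1.0000) -- looks like exactly 1
print(score.max() - 1.0)  # a tiny positive residual, e.g. tensor(1.1921e-07)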
Here, you pass into BCELoss values that are (slightly, due to round-off error) greater than one, so BCELoss raises the exception you posted, "all elements of input should be between 0 and 1".
This could also be an appropriate use case for torch.clamp().
Note, however, that you should only do this if the result you are forcing to be <= 1.0 is already mathematically no greater than 1.0 and only exceeds 1.0 because of round-off error (as in your toy example). If, in your real use case, the values you feed into BCELoss could be greater than 1.0 even absent round-off error, then you should fix that underlying issue. Simply clamping them to be no greater than 1.0 would just be sweeping the real problem under the rug.
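Continuing the snippet above, a minimal sketch of the clamp approach. (Note that BCELoss also expects its target to lie in [0, 1], so the 3.141592 fill value in your toy example would trip a similar check on recent pytorch versions; the 0.5 below is just an arbitrary valid target.)

score_clamped = torch.clamp(score, min=0.0, max=1.0)  # back into [0, 1]
label = torch.full((512, 14541), 0.5)                 # a valid target in [0, 1]
loss = nn.BCELoss()(score_clamped, label)
print(loss)                                           # a finite loss, no RuntimeError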