Suppose that you have token logits from a Transformer model. The logits have shape `(batch_size, sequence_length)`; in the example below, `(4, 512)`.
I would like to turn every token logit into an independent Bernoulli distribution (as I’m implementing a research paper). The loss is then defined as the log probability of the ground-truth labels.
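For reference, this is the standard closed form of the per-token log probability of a label $y \in \{0, 1\}$ under a Bernoulli with logit $\ell$, where $\sigma$ is the logistic sigmoid. A training loss would usually be its negative, summed or averaged over tokens. Mathematically, $y$ only has to be 0 or 1; the formula says nothing about how it is stored.

$$
\begin{aligned}
\log p(y \mid \ell) &= y \log \sigma(\ell) + (1 - y) \log\bigl(1 - \sigma(\ell)\bigr) \\
&= y\,\ell - \log\bigl(1 + e^{\ell}\bigr) \\
&= -\mathrm{BCEWithLogits}(\ell, y)
\end{aligned}
$$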
```python
import torch

logits = torch.randn((4, 512))
# every token is now an independent Bernoulli distribution
dist_per_token = torch.distributions.Bernoulli(logits=logits)
# randint's upper bound is exclusive, so 2 is needed to draw labels from {0, 1}
labels = torch.randint(0, 2, (4, 512))
loss_per_token = dist_per_token.log_prob(labels)
```
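As a sketch under the same shapes as above, the per-token log probability can also be computed directly from the closed form `y * logit - softplus(logit)`. `manual_log_prob` is a hypothetical helper, not a PyTorch API. Because PyTorch's type promotion turns `int64 * float32` into `float32`, this version accepts the integer labels as they are:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn((4, 512))
labels = torch.randint(0, 2, (4, 512))  # int64 labels in {0, 1}


def manual_log_prob(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: log p(y | logit) = y * logit - softplus(logit).

    The multiplication promotes integer labels to the logits' float dtype.
    """
    return labels * logits - F.softplus(logits)


manual = manual_log_prob(logits, labels)
# reference value from torch.distributions, using float labels
reference = torch.distributions.Bernoulli(logits=logits).log_prob(labels.float())

print(manual.dtype)  # torch.float32
print(torch.allclose(manual, reference, atol=1e-5))
```

The small `atol` only absorbs float32 rounding differences between the two formulations.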
However, when `labels` is a LongTensor, I get the following error:
```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-5-196b06f84a85> in <module>()
1 labels = torch.randint(0, 1, (4,512))
2
----> 3 dist_per_token.log_prob(labels)
1 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
2824 raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
2825
-> 2826 return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
2827
2828
RuntimeError: result type Float can't be cast to the desired output type Long
```
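The traceback shows that `log_prob` ends up in `binary_cross_entropy_with_logits`. A minimal sketch checking that, for 0/1 float labels, the Bernoulli log probability matches the negated element-wise binary cross-entropy with logits (`reduction="none"` keeps one value per token):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn((4, 512))
# float labels in {0., 1.}, with the same dtype as the logits
labels = torch.randint(0, 2, (4, 512)).to(logits.dtype)

log_prob = torch.distributions.Bernoulli(logits=logits).log_prob(labels)
neg_bce = -F.binary_cross_entropy_with_logits(logits, labels, reduction="none")

print(torch.allclose(log_prob, neg_bce, atol=1e-6))
```

This is consistent with the error above: the failure is raised inside the cross-entropy call, which in the PyTorch version used here rejected a LongTensor target.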
Shouldn’t this work, since the Bernoulli is a discrete distribution? Or should one always pass FloatTensors to `log_prob`? It only works when I call `dist_per_token.log_prob(labels.float())`.
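Along the lines of the `labels.float()` workaround, here is a minimal sketch that casts the labels to the logits' own dtype, so it also follows the logits if they are float64, and reduces the per-token log probabilities to a scalar negative log-likelihood. The helper name `bernoulli_nll` is hypothetical:

```python
import torch


def bernoulli_nll(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mean negative log-likelihood of 0/1 labels
    under independent per-token Bernoulli distributions.

    `labels` may be an integer tensor; it is cast to the logits' dtype
    before calling log_prob.
    """
    dist_per_token = torch.distributions.Bernoulli(logits=logits)
    log_prob_per_token = dist_per_token.log_prob(labels.to(logits.dtype))
    return -log_prob_per_token.mean()


torch.manual_seed(0)
logits = torch.randn((4, 512), requires_grad=True)
labels = torch.randint(0, 2, (4, 512))  # LongTensor, as in the question

loss = bernoulli_nll(logits, labels)
loss.backward()  # gradients flow back to the logits
print(loss.item(), logits.grad.shape)
```

Whether to use `.mean()` or `.sum()` over tokens depends on how the paper defines its loss.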