RuntimeError: result type Float can't be cast to the desired output type Long when using log_prob

Suppose you have token logits from a Transformer model, with shape (batch_size, sequence_length); in the example below, that is (4, 512).

I would like to turn every token logit into an independent Bernoulli distribution (I'm implementing a research paper). The loss is then defined as the log probability of the ground-truth labels under these distributions.
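For reference, with a token logit $z$ and a label $y \in \{0, 1\}$, the per-token Bernoulli log-likelihood can be written directly in terms of the logit. This also explains why the traceback passes through `binary_cross_entropy_with_logits`: the log probability equals the negative element-wise binary cross-entropy computed on logits.

$$
\log p(y \mid z) = y \log \sigma(z) + (1 - y) \log\bigl(1 - \sigma(z)\bigr) = y \log \sigma(z) + (1 - y) \log \sigma(-z) = -\mathrm{BCE}_{\text{logits}}(z, y),
\qquad \sigma(z) = \frac{1}{1 + e^{-z}}
$$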

import torch

logits = torch.randn((4,512))

dist_per_token = torch.distributions.Bernoulli(logits=logits)  # every token is now a Bernoulli distribution

labels = torch.randint(0, 2, (4, 512))  # upper bound is exclusive, so 2 yields labels in {0, 1}
loss_per_token = dist_per_token.log_prob(labels)

However, when labels is a LongTensor (the default dtype returned by torch.randint), I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-196b06f84a85> in <module>()
      1 labels = torch.randint(0, 1, (4,512))
      2 
----> 3 dist_per_token.log_prob(labels)

1 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
   2824         raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
   2825 
-> 2826     return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
   2827 
   2828 

RuntimeError: result type Float can't be cast to the desired output type Long

Shouldn't this work, since the Bernoulli is a discrete distribution? Or should one always pass FloatTensors to log_prob? It only works when I call dist_per_token.log_prob(labels.float()).
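A minimal sketch of the working version, assuming the goal is a per-token log probability of shape (4, 512). It casts the integer labels to the logits' own floating dtype with `labels.to(logits.dtype)`, a slightly more general form of `labels.float()` that also matches float64 or float16 logits. It then cross-checks the result against the closed-form log-likelihood built from `logsigmoid`. Using the negative mean as the scalar loss is an assumption for illustration, not necessarily what the paper prescribes.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli

torch.manual_seed(0)

# Token logits for a batch of 4 sequences of 512 tokens
logits = torch.randn(4, 512)
dist_per_token = Bernoulli(logits=logits)  # one independent Bernoulli per token

# randint's upper bound is exclusive: (0, 2) gives 0/1 labels as int64
labels = torch.randint(0, 2, (4, 512))

# Cast the 0/1 labels to the logits' floating dtype before calling log_prob
log_prob_per_token = dist_per_token.log_prob(labels.to(logits.dtype))

# Closed form: y * log(sigmoid(z)) + (1 - y) * log(sigmoid(-z))
manual = labels * F.logsigmoid(logits) + (1 - labels) * F.logsigmoid(-logits)

# Assumed scalar loss: negative mean log-likelihood over all tokens
loss = -log_prob_per_token.mean()
```

Since the labels only take the values 0 and 1, the cast is lossless; it just gives the underlying `binary_cross_entropy_with_logits` call a target with the same floating dtype as its input.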