RuntimeError: Found dtype Half but expected Float in backward()

I’m trying to train my model in PyTorch, but I get the following error:

    RuntimeError                              Traceback (most recent call last)
    <ipython-input-47-4c3349f66d9b> in <module>
         27             # Backpropagate loss
    ---> 28             loss.backward()
         30             # Steps

    ~/anaconda3/lib/python3.8/site-packages/torch/ in backward(self, gradient, retain_graph, create_graph)
        219                 retain_graph=retain_graph,
        220                 create_graph=create_graph)
    --> 221         torch.autograd.backward(self, gradient, retain_graph, create_graph)
        223     def register_hook(self, hook):

    ~/anaconda3/lib/python3.8/site-packages/torch/autograd/ in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
        128         retain_graph = create_graph
    --> 130     Variable._execution_engine.run_backward(
        131         tensors, grad_tensors_, retain_graph, create_graph,
        132         allow_unreachable=True)  # allow_unreachable flag

    RuntimeError: Found dtype Half but expected Float

My code:

    criterion_KL = torch.nn.KLDivLoss()

    logits_per_image, logits_per_text = model(, # both NxN matrix
    labels = labels_matrix(tax) # returns a NxN matrix, len(tax) = N

    loss_i = criterion_KL(logits_per_image, labels)
    loss_j = criterion_KL(logits_per_text, labels)
    loss = (loss_i + loss_j)/2

    print(loss_i, loss_j, loss)
    # tensor(-0.6089, device='cuda:0', grad_fn=<KlDivBackward>) tensor(-0.6089, device='cuda:0', grad_fn=<KlDivBackward>) tensor(-0.6089, device='cuda:0', grad_fn=<DivBackward0>)

    print(loss_i.dtype, loss_j.dtype, loss.dtype)
    # torch.float32 torch.float32 torch.float32

I have no idea why this error is happening. I tried casting the labels tensor to float, but it didn’t help.
Any ideas?
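Separately from the dtype error, the negative loss value printed above (-0.6089) hints that the inputs to `KLDivLoss` may not be in the form it expects. According to the PyTorch documentation, `input` should contain log-probabilities and `target` should contain probabilities; with valid distributions on both sides the KL divergence should not be negative. Below is a minimal sketch. It assumes that `labels_matrix` returns rows that each sum to 1 and that the logits should be normalized along `dim=1`. Both are assumptions, and the toy tensors stand in for the real model outputs and labels.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

N = 4
logits = torch.randn(N, N)  # stand-in for logits_per_image (N x N)
labels = torch.eye(N)       # stand-in target: one-hot rows, each summing to 1

# KLDivLoss expects `input` as log-probabilities and `target` as probabilities.
# reduction="batchmean" divides the summed loss by the batch size (N here),
# which matches the mathematical definition of KL divergence per sample.
criterion_KL = torch.nn.KLDivLoss(reduction="batchmean")
loss = criterion_KL(F.log_softmax(logits, dim=1), labels)

print(loss.item() >= 0)  # True: KL between valid distributions is non-negative
```

If the real labels are not already normalized, dividing each row by its sum (`labels / labels.sum(dim=1, keepdim=True)`) would be one way to turn them into probabilities first.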

Are you manually casting the model to half() or any parts of it?
Also, could you update to the latest PyTorch release (1.9.0) if you are using an older one? If you are still hitting this issue in the latest release, could you post an executable code snippet that reproduces it?
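One way to answer the question about `half()` is to scan the model for floating-point parameters and buffers that are not `float32`. Note that the question only prints the dtypes of the losses, so `logits_per_image` and `logits_per_text` could still be `float16`. Below is a minimal sketch. It uses a hypothetical toy `nn.Sequential` in place of the real model and a hypothetical helper, `find_non_float32`.

```python
import torch
import torch.nn as nn


def find_non_float32(model: nn.Module):
    """Return (name, dtype) for every floating-point parameter/buffer that is not float32."""
    found = []
    for name, t in list(model.named_parameters()) + list(model.named_buffers()):
        if t.is_floating_point() and t.dtype != torch.float32:
            found.append((name, t.dtype))
    return found


# Hypothetical stand-in for the real model: the second layer is cast to half.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
model[1].half()
print(find_non_float32(model))
# [('1.weight', torch.float16), ('1.bias', torch.float16)]

# Casting the whole model back to float32 removes the mixed dtypes.
model.float()
print(find_non_float32(model))  # []

# Check the model outputs too, not only the losses.
out = model(torch.randn(3, 4))
print(out.dtype)  # torch.float32
```

If the real model reports `float16` weights, calling `model.float()` before training is one option. Some loaders keep weights in half precision on the GPU; as far as I know, OpenAI’s CLIP `clip.load` does this on CUDA devices. Another option is mixed-precision training via `torch.cuda.amp`.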