PyTorch custom loss: loss.backward() raising "RuntimeError: Found dtype Double but expected Float"

Hi, I am creating a custom loss function that combines a kernel alignment loss and a reconstruction loss.

The kernel loss is the Frobenius distance between the normalized kernel of the latent-space code and the normalized prior kernel:

k_loss = torch.linalg.matrix_norm(torch.sub(code_K_norm, prior_K_norm), ord='fro', dim=(-2, -1))
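For readers unfamiliar with the kernel term, here is a minimal sketch of how such a Frobenius distance behaves. The post does not say how code_K_norm and prior_K_norm are built, so the linear Gram kernel scaled to unit Frobenius norm (the hypothetical normalized_linear_kernel below) and the tensor shapes are assumptions; only the torch.linalg.matrix_norm call mirrors the post.

```python
import torch

torch.manual_seed(0)
# Hypothetical latent codes and prior samples: 8 samples, 4 dimensions each.
z = torch.randn(8, 4, dtype=torch.float64, requires_grad=True)
prior = torch.randn(8, 4, dtype=torch.float64)

def normalized_linear_kernel(x):
    # Assumed kernel: linear Gram matrix x x^T scaled to unit Frobenius norm.
    k = x @ x.T
    return k / torch.linalg.matrix_norm(k, ord="fro")

code_K_norm = normalized_linear_kernel(z)
prior_K_norm = normalized_linear_kernel(prior)

# Frobenius distance over the last two dimensions, as in the post.
k_loss = torch.linalg.matrix_norm(torch.sub(code_K_norm, prior_K_norm), ord="fro", dim=(-2, -1))

# The same quantity written out: square root of the sum of squared entries.
manual = torch.sqrt(((code_K_norm - prior_K_norm) ** 2).sum())
print(k_loss.item(), manual.item())

k_loss.backward()  # gradients flow back to the latent codes z
```

Because both kernels are normalized, the loss compares only the "shape" of the two similarity structures, not their overall scale.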

The reconstruction loss is calculated as:
reconstruct_loss = torch.nn.MSELoss()(encoder_inputs, dec_out)

dec_out is the output of the encoder-decoder network.
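A small sketch of what nn.MSELoss computes, using random stand-in tensors (the shapes are hypothetical). The conventional call order is loss_fn(prediction, target); the post passes the inputs first, which gives the same value because squared error is symmetric.

```python
import torch

torch.manual_seed(0)
encoder_inputs = torch.randn(4, 10)               # hypothetical batch of inputs (float32)
dec_out = torch.randn(4, 10, requires_grad=True)  # stands in for the decoder output

# nn.MSELoss defaults to reduction="mean": the average of squared element-wise differences.
reconstruct_loss = torch.nn.MSELoss()(dec_out, encoder_inputs)

# Argument order as in the post (inputs first); the value is the same.
swapped = torch.nn.MSELoss()(encoder_inputs, dec_out)

# Equivalent manual form.
manual = torch.mean((dec_out - encoder_inputs) ** 2)
print(reconstruct_loss.item(), swapped.item(), manual.item())
```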

Now, I calculate the total loss as:

tot_loss = reconstruct_loss + args.w_reg * reg_loss + args.a_reg * k_loss
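The weighted sum itself does not change dtypes: multiplying a tensor by a Python float (like args.w_reg and args.a_reg) keeps the tensor's dtype. The sketch below assumes a tiny stand-in model, hypothetical weight values, a hypothetical L2 penalty for reg_loss (the post does not define it), and a constant placeholder for k_loss.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
model = torch.nn.Linear(6, 6)  # hypothetical stand-in for the encoder-decoder
x = torch.randn(5, 6)
out = model(x)

w_reg, a_reg = 1e-3, 0.1       # hypothetical values for args.w_reg and args.a_reg

reconstruct_loss = F.mse_loss(out, x)
# Hypothetical regularizer: sum of squared parameters (reg_loss is not defined in the post).
reg_loss = sum((p ** 2).sum() for p in model.parameters())
k_loss = torch.tensor(0.5)     # placeholder for the kernel alignment term

# Python floats scale the tensors without changing their dtype (float32 here).
tot_loss = reconstruct_loss + w_reg * reg_loss + a_reg * k_loss
tot_loss.backward()
print(tot_loss.dtype, model.weight.grad.dtype)
```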

I have printed the types and dtypes of all the terms, and the losses all appear to be torch.float64:

TYPE OF PARAMS: <class 'torch.Tensor'>
Total parameters: 243
RECONS LOSS: <class 'torch.Tensor'>
RECONS LOSS: tensor(0.9852, dtype=torch.float64, grad_fn=<MseLossBackward0>)
W_REG: <class 'float'>
REG LOSS: <class 'torch.Tensor'>
A_REG: <class 'float'>
K_LOSS: <class 'torch.Tensor'>
K_LOSS: tensor(1.0471, dtype=torch.float64, grad_fn=<CopyBackwards>)
TYPE tot_loss: <class 'torch.Tensor'>
tot_loss: tensor(1.1261, dtype=torch.float64, grad_fn=<AddBackward0>)

But when I call tot_loss.backward(), it throws the following error:

The optimizer is created as:
optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)

File "/usr/local/lib/python3.7/dist-packages/torch/", line 363, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/usr/local/lib/python3.7/dist-packages/torch/autograd/", line 175, in backward
allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass
RuntimeError: Found dtype Double but expected Float
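In this message "Double" is torch.float64 and "Float" is torch.float32. The printout shows only the losses, not the model parameters or the tensors fed into the MSE loss, so one way to narrow things down is to list every relevant dtype in one place. collect_dtypes below is a hypothetical helper, and the float64 input is contrived to show what a mismatch looks like.

```python
import torch

def collect_dtypes(model, **tensors):
    """Hypothetical helper: map each parameter and each named tensor to its dtype."""
    dtypes = {f"param:{name}": p.dtype for name, p in model.named_parameters()}
    dtypes.update({name: t.dtype for name, t in tensors.items()})
    return dtypes

model = torch.nn.Linear(3, 3)                 # parameters default to float32
encoder_inputs = torch.randn(2, 3).double()   # a float64 tensor slipping into the graph

dtypes = collect_dtypes(model, encoder_inputs=encoder_inputs)
for name, dt in dtypes.items():
    print(f"{name}: {dt}")
print("mixed dtypes:", len(set(dtypes.values())) > 1)
```

In the real training loop, one would pass encoder_inputs, dec_out, and the kernel tensors as keyword arguments.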

Before passing the encoder inputs to the model, I convert each of them to a float tensor.
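Note that Tensor.float() converts to torch.float32, yet the printed losses are torch.float64, so some tensor in the graph is probably still float64. A common remedy, sketched here under the assumption that the model's parameters are float32, is to cast inputs and targets to the parameters' dtype so that prediction and target share one dtype before the loss. The two-layer model and the float64 raw data are hypothetical. Calling model.double() instead, to run everything in float64, is the other consistent choice.

```python
import torch

torch.manual_seed(0)
# Hypothetical encoder-decoder stand-in.
model = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.Linear(2, 4))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

raw_inputs = torch.randn(8, 4, dtype=torch.float64)  # e.g. data loaded from NumPy as float64

# Cast to whatever dtype the model's parameters use (float32 here).
param_dtype = next(model.parameters()).dtype
encoder_inputs = raw_inputs.to(param_dtype)

dec_out = model(encoder_inputs)
reconstruct_loss = torch.nn.MSELoss()(dec_out, encoder_inputs)

optimizer.zero_grad()
reconstruct_loss.backward()
optimizer.step()
print(dec_out.dtype, encoder_inputs.dtype, reconstruct_loss.dtype)
```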

Could someone please take a look at this?

Looking more closely, I can see:

RECONS LOSS: tensor(0.9852, dtype=torch.float64, grad_fn=<MseLossBackward0>)
K_LOSS: tensor(1.0471, dtype=torch.float64, grad_fn=<CopyBackwards>)
tot_loss: tensor(1.1261, dtype=torch.float64, grad_fn=<AddBackward0>)

The grad_fn for recons_loss is "MseLossBackward0", for k_loss it is "CopyBackwards", and for tot_loss it is "AddBackward0". Also, the backward pass works fine if I call k_loss.backward(), but it fails for both recons_loss and tot_loss, even though all three have the same dtype (torch.float64).
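Calling backward() on each term separately, as done above with k_loss, is a good way to isolate the failing branch. The sketch below wraps that idea in a hypothetical helper, failing_terms, and also prints each term's grad_fn via Node.name(). The toy losses are illustrative; the constant term is included only because its backward() reliably raises a RuntimeError (it has no grad_fn), standing in for a failing term.

```python
import torch
import torch.nn.functional as F

def failing_terms(named_losses):
    """Hypothetical helper: call backward() on each loss term separately and
    return {name: error message} for the terms whose backward pass raises."""
    failures = {}
    for name, loss in named_losses.items():
        try:
            # retain_graph=True keeps any shared part of the graph alive for the next term.
            loss.backward(retain_graph=True)
        except RuntimeError as err:
            failures[name] = str(err)
    return failures

w = torch.randn(3, requires_grad=True)
losses = {
    "k_loss": (w ** 2).sum(),
    "recons_loss": F.mse_loss(w, torch.zeros(3)),
    "constant": torch.tensor(1.0),  # has no grad_fn, so its backward() raises
}

# grad_fn is the autograd node that produced each tensor (None for the constant).
for name, loss in losses.items():
    print(name, loss.grad_fn.name() if loss.grad_fn is not None else None)

result = failing_terms(losses)
print(result)
w.grad = None  # clear the gradients accumulated while probing
```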

The issue was resolved by replacing nn.MSELoss with a manual computation using plain tensor operations: torch.mean((target - output) ** 2).
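A plausible explanation (not confirmed in the thread) is that ordinary element-wise operations promote mixed float32/float64 operands, and autograd casts each gradient back to the dtype of the tensor it flows into, so the manual form tolerates a dtype mismatch between prediction and target. The sketch below demonstrates this with a deliberately mixed pair. It also shows why the parentheses matter: torch.mean((target - output) ** 2) is the MSE, whereas torch.mean(target - output) ** 2 squares the mean error, which is a different and generally smaller quantity.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
output = torch.randn(4, 5, requires_grad=True)   # float32, like a model output
target = torch.randn(4, 5, dtype=torch.float64)  # float64, like unconverted data

# Element-wise ops promote float32 - float64 to float64.
mse_manual = torch.mean((target - output) ** 2)

# Squaring the mean instead gives (mean error)^2, which is not the MSE.
mean_then_square = torch.mean(target - output) ** 2

# Reference value, computed with both operands in float64.
reference = F.mse_loss(output.detach().double(), target)

mse_manual.backward()
# The gradient arrives in output's own dtype (float32).
print(mse_manual.dtype, output.grad.dtype)
```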