RuntimeError: Found dtype Double but expected Float during loss.backward()

I encountered a similar issue today. The weird part is that everything works fine with l1_loss, but as soon as I switch to mse_loss, the RuntimeError appears.
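This error usually means the target is a float64 ("Double") tensor while the model output is float32 ("Float"). A common way this happens is building the target from a NumPy array, since NumPy defaults to float64. Whether the mismatch actually raises an error can depend on the loss function and the PyTorch version, which may explain why l1_loss appeared to work while mse_loss did not. Below is a minimal sketch of the usual fix, which casts the target to the prediction's dtype before computing the loss. The model, shapes, and data are hypothetical stand-ins, not taken from the original post:

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in model: parameters (and therefore outputs) are float32 by default.
model = nn.Linear(4, 1)
x = torch.randn(8, 4)

# NumPy arrays default to float64, so torch.from_numpy produces a Double tensor.
target_np = np.random.rand(8, 1)
target = torch.from_numpy(target_np)  # dtype: torch.float64

pred = model(x)  # dtype: torch.float32

# Fix: cast the target to the prediction's dtype so the forward and backward passes agree.
loss = F.mse_loss(pred, target.to(pred.dtype))
loss.backward()

print(target.dtype, pred.dtype, model.weight.grad.dtype)
```

You can get the same effect with `target.float()`, or by converting the array up front with `target_np.astype(np.float32)`. Casting at the point where the data is created keeps every loss function consistent, not just the one currently in use.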