Is it required that input and hidden for GRU have the same dtype (float32)?

I had an issue with a project of mine where a GRU failed to optimize and always predicted the same token. After a lot of debugging I found out that the input was of dtype float, but the initial hidden state I passed was a double. This caused the outputs and hidden states to explode, which in turn seemed to cause issues for the optimizer, and eventually the model would always predict the same token.

My question is, why are different dtypes an issue? Why does it lead to an explosion in the GRU? If different dtypes are problematic, perhaps PyTorch should throw a warning?

I tried to reproduce this issue with a small example, but if I try to pass the initial hidden state as a DoubleTensor, I get an error:

import torch
from torch import nn

torch.manual_seed(2809)
rnn = nn.GRU(10, 20, 2)
for _ in range(100):
    input = torch.randn(5, 3, 10)
    h0 = torch.randn(2, 3, 20).double()
    output, hn = rnn(input, h0)
    print(output.abs().max(), hn.abs().max())

> RuntimeError: expected scalar type Float but found Double

Could you share a code snippet so that we could have a look?

After testing your script, it seems that this only occurs on the GPU. On the CPU, PyTorch raises the error you mentioned, but on the GPU the code runs through without complaint and produces very large or very small output values. You can try the snippet below; toggling the USE_GPU flag should reproduce the issue I'm experiencing. Thanks for your time!

import torch
from torch import nn

USE_GPU = True

torch.manual_seed(2809)
torch.cuda.manual_seed(2809)


device = torch.device("cuda" if USE_GPU else "cpu")
rnn = nn.GRU(10, 10).to(device)
for _ in range(100):
    input = torch.randn(1, 3, 10).to(device)
    h0 = torch.randn(1, 3, 10).double().to(device)
    print("input dtype", input.dtype)
    print("h0 dtype", h0.dtype)
    output, hn = rnn(input, h0)
    print(output.abs().max(), hn.abs().max())
    print()
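Until an error is raised for you, a cheap guard is to compare dtypes yourself before the forward pass. Below is a minimal sketch, assuming a standard nn.GRU whose parameters all share a single dtype; check_gru_dtypes is a hypothetical helper name, not a PyTorch API.

```python
import torch
from torch import nn


def check_gru_dtypes(rnn: nn.GRU, x: torch.Tensor, h0: torch.Tensor) -> None:
    """Hypothetical helper: raise if x or h0 has a different dtype than the GRU's parameters."""
    # Assumes all GRU parameters share one dtype, so the first one is representative.
    param_dtype = next(rnn.parameters()).dtype
    for name, tensor in (("input", x), ("h0", h0)):
        if tensor.dtype != param_dtype:
            raise TypeError(
                f"{name} has dtype {tensor.dtype}, but GRU parameters are {param_dtype}"
            )


rnn = nn.GRU(10, 10)
x = torch.randn(1, 3, 10)                # float32, matches the parameters
h0 = torch.randn(1, 3, 10).double()      # float64, the mismatch from the thread
try:
    check_gru_dtypes(rnn, x, h0)
except TypeError as err:
    print(err)  # h0 has dtype torch.float64, but GRU parameters are torch.float32
```

The check runs on the CPU tensors here, but it only inspects dtypes, so it behaves the same if the module and tensors have been moved to the GPU.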

Thanks for the code snippet. Which PyTorch version are you using? I'm still getting the dtype mismatch error in 1.7.0.dev20200830.

I’m on 1.6 + CUDA 10.1, and the problem occurs on both Windows and Linux.

Could you try the latest nightly and verify that it’s fixed, please?

I can only test on Linux at the moment, but yes, the issue is fixed there. However, the error message differs between CPU and GPU:

CPU: RuntimeError: expected scalar type Float but found Double
GPU: RuntimeError: Input and parameter tensors are not the same dtype, found input tensor with Double and parameter tensor with Float

This is probably not a big issue, although it might be better to unify the error messages across CUDA and the other devices.

Thanks for the hint; it’s nice to see that users cannot run into this issue anymore in 1.7!

Yeah, it makes sense to have the same error message on CPU and the GPU. Would you mind creating an issue on GitHub so that we can track it?

Done. I’ll close this topic now, as it was a bug in 1.6. For future readers who are still on an older version: if you run into this, you may want to call .float() on your tensors before passing them to your GRU.
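A minimal sketch of that workaround, assuming a standard nn.GRU whose parameters all share one dtype and device: instead of hard-coding .float(), it reads the dtype and device from the module's first parameter and casts both the input and the initial hidden state to match, so it also works if the model itself was converted with .double(). run_gru is a hypothetical helper name, not part of PyTorch.

```python
import torch
from torch import nn


def run_gru(rnn: nn.GRU, x: torch.Tensor, h0: torch.Tensor):
    """Hypothetical helper: cast x and h0 to the GRU's parameter dtype/device, then run it."""
    # Assumes all GRU parameters share one dtype and device.
    param = next(rnn.parameters())
    x = x.to(device=param.device, dtype=param.dtype)
    h0 = h0.to(device=param.device, dtype=param.dtype)
    return rnn(x, h0)


rnn = nn.GRU(10, 20, 2)
x = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20).double()  # mismatched dtype, as in the thread
output, hn = run_gru(rnn, x, h0)
print(output.dtype, hn.dtype)  # torch.float32 torch.float32
```

Note that silently casting hides the mismatch rather than reporting it; if the double tensor comes from an unexpected place in your pipeline, fixing it at the source is usually cleaner.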
