Decoder always predicts the same token

I posted a separate issue about this in case anyone has an idea. Do the input and the hidden state passed to a GRU have to have the same dtype? If the hidden state is required to be float32, it seems PyTorch should at least give a warning when it is not.
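For anyone who wants to rule out a dtype mismatch while debugging, here is a minimal sketch that casts both the input and the initial hidden state to the dtype of the GRU's parameters before the forward pass. The layer sizes and tensor shapes are made up for illustration, and the float64 input only stands in for data that might come from a NumPy array; none of this is taken from the original model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for this illustration.
gru = nn.GRU(input_size=4, hidden_size=8, batch_first=True)

# Input deliberately created as float64 (assumed scenario: data converted from NumPy).
x = torch.randn(2, 5, 4, dtype=torch.float64)
# Initial hidden state: (num_layers, batch, hidden_size), float32 by default.
h0 = torch.zeros(1, 2, 8)

# Cast both tensors to the dtype of the GRU's own parameters.
param_dtype = next(gru.parameters()).dtype
x = x.to(param_dtype)
h0 = h0.to(param_dtype)

output, hn = gru(x, h0)
print(x.dtype, h0.dtype, output.dtype)
```

Reading the dtype from the module's parameters rather than hard-coding float32 keeps the sketch valid if the model is later converted with something like `gru.double()`.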