LSTM Hidden State Changing Dimensions Error

I seem to be having an odd issue when using an LSTM. I pass in an initial hidden state of size (3,1,3), but PyTorch throws an error saying it is only (1,3) and that it expected (3,1,3), even though I can verify before passing it in that the tensor is (3,1,3).

import torch
import torch.nn as nn

class AdRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(AdRNN, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.LSTM(input_size=input_size, num_layers=num_layers, hidden_size=hidden_size, batch_first=True)

    def forward(self, x):
        h_0 = torch.rand(self.num_layers, x.size(0), self.hidden_size)
        nn.init.xavier_normal_(h_0)

        print(h_0.shape)
        
        output, _ = self.rnn(x, h_0)
        return output

Note that I am printing the shape of h_0 for debugging. It comes back as (3,1,3), but PyTorch raises the following error in the forward method when h_0 is passed to the LSTM:

RuntimeError: Expected hidden[0] size (3, 1, 3), got (1, 3)

If I change the rnn type to GRU or vanilla RNN in __init__, everything works just fine, but the LSTM is being cranky. I am using PyTorch 0.4.1. I can of course just not pass in h_0 and let the LSTM initialize its hidden state to zeros, but I would rather not. Any thoughts?

LSTM takes a tuple of hidden states:

self.rnn(x, (h_0, c_0))

It looks like you haven't passed in the second state, the cell state c_0.
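
For reference, here is a minimal sketch of how the module from the question might look with the tuple passed in. Initializing c_0 with xavier_normal_ (the same way as h_0) and the input sizes in the usage lines are assumptions for illustration, not something the original post specified.

```python
import torch
import torch.nn as nn


class AdRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                           num_layers=num_layers, batch_first=True)

    def forward(self, x):
        batch = x.size(0)
        # Both states have shape (num_layers, batch, hidden_size),
        # even when batch_first=True.
        h_0 = torch.empty(self.num_layers, batch, self.hidden_size, device=x.device)
        c_0 = torch.empty(self.num_layers, batch, self.hidden_size, device=x.device)
        nn.init.xavier_normal_(h_0)
        nn.init.xavier_normal_(c_0)  # assumption: same init as h_0

        # The LSTM expects the initial state as one tuple: (h_0, c_0).
        output, (h_n, c_n) = self.rnn(x, (h_0, c_0))
        return output


# Hypothetical sizes, chosen so the states are (3, 1, 3) as in the question.
model = AdRNN(input_size=5, hidden_size=3, num_layers=3)
x = torch.randn(1, 7, 5)  # (batch, seq_len, input_size)
out = model(x)
```

With batch_first=True the output comes back as (batch, seq_len, hidden_size), while h_0 and c_0 keep the layer dimension first.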

Aha! I didn’t realize the argument was a tuple for LSTM; I was thinking it was (x, h_0, c_0). So you have to provide either both the hidden and the cell state or neither (or, I suppose, pass in a zero tensor for one yourself if you only want to initialize the other).
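
A rough sketch of that last option, initializing only the hidden state and passing zeros for the cell state. The sizes below are made up for illustration; zeros match what nn.LSTM uses by default when no state is given.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration.
num_layers, batch, hidden_size, input_size = 3, 1, 3, 5

lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
               num_layers=num_layers, batch_first=True)
x = torch.randn(batch, 7, input_size)  # (batch, seq_len, input_size)

# Custom initial hidden state.
h_0 = torch.empty(num_layers, batch, hidden_size)
nn.init.xavier_normal_(h_0)

# Zero cell state with the same shape, dtype, and device as h_0.
c_0 = torch.zeros_like(h_0)

output, (h_n, c_n) = lstm(x, (h_0, c_0))
```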

That did it, thank you!

Maybe we should have a better API for this. You’re right that there is no way to pass just one state without initializing the other one.

Just pass (h_0, c_0) as a tuple when you call the LSTM:
input_h, (input_h_t, input_c_t) = self.rnn(input_emb, (h_0, c_0))
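
In case the return values are unclear: the first item is the per-timestep output of the top layer, and the tuple holds the final hidden and cell states for every layer. A small sketch with made-up sizes, assuming a unidirectional LSTM with batch_first=True:

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration.
num_layers, batch, seq_len, input_size, hidden_size = 3, 2, 7, 5, 3

lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
               num_layers=num_layers, batch_first=True)
x = torch.randn(batch, seq_len, input_size)
h_0 = torch.zeros(num_layers, batch, hidden_size)
c_0 = torch.zeros(num_layers, batch, hidden_size)

output, (h_n, c_n) = lstm(x, (h_0, c_0))

# output: (batch, seq_len, hidden_size), top layer at every timestep.
# h_n, c_n: (num_layers, batch, hidden_size), final state of every layer.
# For a unidirectional LSTM, the last timestep of output is the
# top layer's final hidden state.
last_step_matches = torch.allclose(output[:, -1, :], h_n[-1])
```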