[solved] RNN - access hidden states of every time step

I’m training a vanilla RNN on MNIST (sequential, len=784, one scalar value at each time step), and I would like to visualize the hidden states at every time step, not just at the final time step. How can I access the hidden states prior to the final time step? It seems that the torch.nn.RNN API only gives me access to the hidden state at the final time step.

I can probably accomplish this by using torch.nn.RNNCell instead of torch.nn.RNN: looping through the input sequence manually and saving all the hidden states; however, I’m a bit concerned about the performance drop if I implement an RNN using RNNCell instead of RNN, both in terms of speed (see here) and classification accuracy (see here).
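The RNNCell approach I have in mind would look roughly like this (a minimal sketch; the sizes here — hidden_size=128, batch of 32 — are just illustrative, not from my actual setup):

```python
import torch
import torch.nn as nn

# Illustrative sizes: one scalar per step, 784 steps per MNIST image
input_size, hidden_size, seq_len, batch = 1, 128, 784, 32

cell = nn.RNNCell(input_size, hidden_size)
x = torch.rand(seq_len, batch, input_size)

h = torch.zeros(batch, hidden_size)
hidden_states = []
for t in range(seq_len):
    h = cell(x[t], h)            # one time step
    hidden_states.append(h)      # save the hidden state at step t

# (seq_len, batch, hidden_size): hidden state at every time step
hidden_states = torch.stack(hidden_states)
```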

Any suggestions/advice would be greatly appreciated!

When you run an RNN over a sequence of data, the forward call returns two tensors: (output, h_n). While h_n is the hidden state from the last time step, the tensor output actually contains the hidden states for every time step in the sequence.

import torch
import torch.nn as nn

rnn_net = nn.RNN(4, 3)       # input_size=4, hidden_size=3
X = torch.rand((6, 5, 4))    # (seq_len, batch, input_size)
out, hidden = rnn_net(X)

You will see that the shape of out is torch.Size([6, 5, 3]), which corresponds to (seq_len, batch, hidden_size).
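A quick sanity check (repeating the same single-layer setup, since nothing here is specific to these sizes) confirms that the last time-step slice of output is exactly h_n:

```python
import torch
import torch.nn as nn

rnn_net = nn.RNN(4, 3)       # input_size=4, hidden_size=3
X = torch.rand((6, 5, 4))    # (seq_len, batch, input_size)
out, hidden = rnn_net(X)

# out:    (seq_len, batch, hidden_size)   -> hidden state at every step
# hidden: (num_layers, batch, hidden_size) -> hidden state at the last step
same = torch.allclose(out[-1], hidden[-1])  # last slice of out == h_n
```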


Thanks for pointing that out! Looks like I have been very sloppy when reading documentation…

But output contains hidden states from only the last LSTM layer. What should we do if we want the hidden states from all the layers at each time step?
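One workaround (a sketch, with made-up sizes) is to step the multi-layer RNN through the sequence one time step at a time, carrying h_n forward yourself and saving it at every step; h_n holds the hidden state of every layer, so stacking the saved states gives you all layers at all time steps:

```python
import torch
import torch.nn as nn

# Illustrative sizes, not from the thread
input_size, hidden_size, num_layers = 4, 3, 2
seq_len, batch = 6, 5

rnn_net = nn.RNN(input_size, hidden_size, num_layers=num_layers)
X = torch.rand(seq_len, batch, input_size)

h = torch.zeros(num_layers, batch, hidden_size)
all_layer_states = []
for t in range(seq_len):
    # Feed a length-1 sequence and carry the hidden state across steps
    _, h = rnn_net(X[t:t + 1], h)
    all_layer_states.append(h)

# (seq_len, num_layers, batch, hidden_size)
all_layer_states = torch.stack(all_layer_states)
```

The downside is that this reintroduces a Python-level loop over time steps, so you lose the speed of the fused single-call nn.RNN forward.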

Looping makes it a lot slower!