[solved] RNN - access hidden states of every time step


(Moonlightlane) #1

I’m training a vanilla RNN on sequential MNIST (sequence length 784, one scalar pixel value per time step), and I would like to visualize the hidden state at every time step, not just at the final one. How can I access the hidden states before the final time step? It seems that the torch.nn.RNN API only lets me access the hidden state at the final time step.

I could probably do this with torch.nn.RNNCell instead of torch.nn.RNN, looping through the input sequence manually and saving every hidden state. However, I’m a bit concerned about a performance drop if I implement the RNN with RNNCell instead of RNN, both in speed (see here) and in classification accuracy (see here).
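A rough sketch of the RNNCell approach described above, assuming a single-layer tanh RNN and batch_first input of shape (batch, seq_len, 1) as in sequential MNIST. To make the manual loop comparable with nn.RNN, it copies the layer-0 weights of an nn.RNN into an nn.RNNCell. The helper run_cell_loop is a hypothetical name and the sizes are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch, seq_len, input_size, hidden_size = 2, 784, 1, 16

rnn = nn.RNN(input_size, hidden_size, batch_first=True)  # single layer, tanh
cell = nn.RNNCell(input_size, hidden_size)

# Copy nn.RNN's layer-0 parameters into the cell so both compute the same recurrence.
with torch.no_grad():
    cell.weight_ih.copy_(rnn.weight_ih_l0)
    cell.weight_hh.copy_(rnn.weight_hh_l0)
    cell.bias_ih.copy_(rnn.bias_ih_l0)
    cell.bias_hh.copy_(rnn.bias_hh_l0)


def run_cell_loop(cell, x):
    """Hypothetical helper: step an RNNCell over x of shape (batch, seq_len, input_size)
    and return every hidden state as a (batch, seq_len, hidden_size) tensor."""
    h = torch.zeros(x.size(0), cell.hidden_size)  # initial hidden state h_0 = 0
    states = []
    for t in range(x.size(1)):
        h = cell(x[:, t], h)  # x[:, t] has shape (batch, input_size)
        states.append(h)
    return torch.stack(states, dim=1)


x = torch.rand(batch, seq_len, input_size)
all_h = run_cell_loop(cell, x)
out, h_n = rnn(x)
print(all_h.shape)  # torch.Size([2, 784, 16])
print(torch.allclose(all_h, out, atol=1e-5))
```

With identical weights, the loop and nn.RNN should give the same hidden states up to floating-point rounding; the Python-level loop over 784 steps is what typically makes this version slower than a single nn.RNN call.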

Any suggestions/advice would be greatly appreciated!


(John Smith) #2

When you run an RNN over a sequence, it returns a tuple of two tensors, (output, h_n). h_n is the hidden state from the last time step, but output actually contains the hidden state at every time step in the sequence (for a multi-layer RNN, these are the hidden states of the last layer).

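```python
import torch
import torch.nn as nn

rnn_net = nn.RNN(4, 3)     # input_size=4, hidden_size=3
X = torch.rand((6, 5, 4))  # (seq_len, batch, input_size)
out, hidden = rnn_net(X)
print(out.shape)
print(hidden.shape)
```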

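You will see that out has shape torch.Size([6, 5, 3]), i.e. (seq_len, batch, hidden_size), so out[t] is the hidden state at time step t. hidden has shape torch.Size([1, 5, 3]), i.e. (num_layers * num_directions, batch, hidden_size), and hidden[0] equals out[-1].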
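Applied to the sequential MNIST setup from the question, a minimal sketch might look like the following. It assumes batch_first=True, so out is (batch, seq_len, hidden_size) while h_n stays (num_layers, batch, hidden_size); random data stands in for images, and the sizes (8 images, hidden_size=32) are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for a batch of sequential MNIST: 8 images, 784 steps, 1 pixel per step.
x = torch.rand(8, 784, 1)

rnn = nn.RNN(input_size=1, hidden_size=32, batch_first=True)
out, h_n = rnn(x)

print(out.shape)  # torch.Size([8, 784, 32]): hidden state at every time step
print(h_n.shape)  # torch.Size([1, 8, 32]): h_n is not batch-first

# All 784 hidden states of the first image, detached from the graph, e.g. for a heatmap.
states = out[0].detach()  # shape (784, 32)

# The last time step of `out` is the final hidden state.
print(torch.allclose(out[:, -1], h_n[0]))  # True

# With num_layers > 1, `out` holds only the top layer's states,
# while h_n stacks the final state of each layer.
rnn2 = nn.RNN(input_size=1, hidden_size=32, num_layers=2, batch_first=True)
out2, h_n2 = rnn2(x)
print(out2.shape)  # torch.Size([8, 784, 32])
print(h_n2.shape)  # torch.Size([2, 8, 32])
```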


(Moonlightlane) #3

Thanks for pointing that out! Looks like I was very sloppy when reading the documentation…