I’m training a vanilla RNN on sequential MNIST (sequence length 784, one scalar pixel value per time step), and I would like to visualize the hidden state at every time step, not just at the final one. How can I access the hidden states before the final time step? It seems that the `torch.nn.RNN` API only gives me access to the hidden state at the final time step.
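For reference, the documented return value of `torch.nn.RNN` is a tuple `(output, h_n)`. `output` holds the last layer's hidden state at every time step, while `h_n` holds only the final time step for each layer. The sketch below assumes a single-layer, unidirectional RNN with `batch_first=True`. The sizes are hypothetical, and random tensors stand in for the flattened MNIST images.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for sequential MNIST: 784 steps, 1 input feature per step.
batch_size, seq_len, input_size, hidden_size = 4, 784, 1, 32

rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)
x = torch.randn(batch_size, seq_len, input_size)  # stand-in for flattened images

# output: last layer's hidden state at every time step, shape (batch, seq_len, hidden)
# h_n: hidden state at the final time step for every layer, shape (num_layers, batch, hidden)
output, h_n = rnn(x)

print(output.shape)  # torch.Size([4, 784, 32])
print(h_n.shape)     # torch.Size([1, 4, 32])

# For a single-layer, unidirectional RNN, the last slice of output matches h_n.
print(torch.allclose(output[:, -1, :], h_n[0]))
```

With `num_layers > 1`, `output` still covers only the top layer. The intermediate layers' per-step states would need separate single-layer RNNs stacked by hand.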
I can probably accomplish this by using `torch.nn.RNNCell` instead of `torch.nn.RNN`, looping through the input sequence manually and saving every hidden state. However, I’m a bit concerned that implementing the RNN with `RNNCell` instead of `RNN` will hurt both speed (see here) and classification accuracy (see here).
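The `RNNCell` approach described above might look like the following minimal sketch. It unrolls the cell in a Python loop and stacks every `h_t`. The helper name `run_cell` and all sizes are hypothetical. As a sanity check, the cell's parameters are copied into an `nn.RNN` (parameter names `weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0`, `bias_hh_l0`). With identical weights, the two should agree up to floating-point tolerance. Any speed gap then comes from the per-step Python loop rather than from different math.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes for sequential MNIST.
batch_size, seq_len, input_size, hidden_size = 4, 784, 1, 32
x = torch.randn(batch_size, seq_len, input_size)

cell = nn.RNNCell(input_size, hidden_size)  # tanh nonlinearity by default

def run_cell(cell, x):
    """Hypothetical helper: unroll an RNNCell over x (batch, seq_len, input_size), keeping every h_t."""
    h = torch.zeros(x.size(0), cell.hidden_size, device=x.device, dtype=x.dtype)
    hidden_states = []
    for t in range(x.size(1)):
        # h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)
        h = cell(x[:, t, :], h)
        hidden_states.append(h)
    return torch.stack(hidden_states, dim=1)  # (batch, seq_len, hidden_size)

all_h = run_cell(cell, x)

# Sanity check: copy the cell's weights into an nn.RNN and compare per-step outputs.
rnn = nn.RNN(input_size, hidden_size, batch_first=True)
with torch.no_grad():
    rnn.weight_ih_l0.copy_(cell.weight_ih)
    rnn.weight_hh_l0.copy_(cell.weight_hh)
    rnn.bias_ih_l0.copy_(cell.bias_ih)
    rnn.bias_hh_l0.copy_(cell.bias_hh)
    output, _ = rnn(x)

print(all_h.shape)  # torch.Size([4, 784, 32])
print(torch.allclose(all_h, output, atol=1e-5))
```

If the stored states are only needed for plotting, calling `.detach()` on each `h` before appending it avoids keeping references to the autograd graph through the stored list.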
Any suggestions/advice would be greatly appreciated!