I’m training a vanilla RNN on sequential MNIST (sequence length 784, one scalar value per time step), and I would like to visualize the hidden state at every time step, not just at the final one. How can I access the hidden states from the time steps before the last? It seems that the torch.nn.RNN API only gives me access to the hidden state at the final time step.
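For reference, here is a minimal sketch of what a forward call to nn.RNN returns. It assumes a single-layer, unidirectional nn.RNN with batch_first=True, a batch of 32 and a hidden size of 128; these sizes are illustrative, and random data stands in for the flattened images. According to the PyTorch documentation, the first return value, output, holds the last layer's hidden state at every time step. Only the second value, h_n, is limited to the final time step.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumed): 32 sequences of length 784, one scalar per step.
batch_size, seq_len, hidden_size = 32, 784, 128

rnn = nn.RNN(input_size=1, hidden_size=hidden_size, batch_first=True)
x = torch.randn(batch_size, seq_len, 1)  # random stand-in for flattened MNIST

# output: hidden state of the last layer at every time step
# h_n:    hidden state at the final time step only, one entry per layer
output, h_n = rnn(x)

print(output.shape)  # torch.Size([32, 784, 128])
print(h_n.shape)     # torch.Size([1, 32, 128])

# For a single-layer, unidirectional RNN, the last time step of `output`
# is the same as the final hidden state h_n.
assert torch.allclose(output[:, -1, :], h_n[0])
```

With num_layers > 1, output still holds only the top layer's states. The per-step states of the lower layers are not returned, so a manual loop would be needed for those.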
I could probably accomplish this by using torch.nn.RNNCell instead of torch.nn.RNN, looping through the input sequence manually and saving every hidden state. However, I’m a bit concerned that implementing the RNN with RNNCell instead of RNN would cost performance, both in speed (see here) and in classification accuracy (see here).
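The RNNCell approach described above could look roughly like the following minimal sketch. The sizes and variable names are illustrative assumptions, not a prescribed API. It unrolls the cell over time, collects every hidden state, and then copies the cell's weights into an nn.RNN as a sanity check that both produce matching hidden states.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumed), matching sequential MNIST's shape.
batch_size, seq_len, hidden_size = 32, 784, 128
x = torch.randn(batch_size, seq_len, 1)  # random stand-in for flattened MNIST

cell = nn.RNNCell(input_size=1, hidden_size=hidden_size)  # tanh by default

# Manually unroll the cell over time, keeping every hidden state.
h = torch.zeros(batch_size, hidden_size)  # initial hidden state h_0
hidden_states = []
for t in range(seq_len):
    h = cell(x[:, t, :], h)  # x[:, t, :] has shape (batch, 1)
    hidden_states.append(h)
all_h = torch.stack(hidden_states, dim=1)  # (batch, seq_len, hidden_size)

# Sanity check: an nn.RNN given the same weights yields the same states.
rnn = nn.RNN(input_size=1, hidden_size=hidden_size, batch_first=True)
with torch.no_grad():
    rnn.weight_ih_l0.copy_(cell.weight_ih)
    rnn.weight_hh_l0.copy_(cell.weight_hh)
    rnn.bias_ih_l0.copy_(cell.bias_ih)
    rnn.bias_hh_l0.copy_(cell.bias_hh)
    output, _ = rnn(x)

print(all_h.shape)  # torch.Size([32, 784, 128])
print(torch.allclose(all_h, output, atol=1e-5))
```

Because the two formulations compute the same forward function, switching to RNNCell should not by itself change accuracy. The more likely cost is speed: the Python loop runs one small step at a time, while nn.RNN can hand the whole sequence to an optimized kernel (for example, cuDNN on a GPU).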
Any suggestions/advice would be greatly appreciated!