I have a trained model in which I want to inspect intermediate activations and gates. It is a language model: a two-layer LSTM preceded by an embedding layer and followed by a softmax layer. Currently I am looking at activation values and gradients through hooks.
The problem is, I want to be able to see the internal (hidden) activations as well as the internal gates. These components are not exposed to hooks. For example:
import torch
import torch.nn as nn
from torch.autograd import Variable

def dummy_hook(module, input, output):
    # input is (x, (h_0, c_0)); output is (out, (h_n, c_n))
    print("\n module", module)
    print("input", len(input), input[0].size(), input[1][0].size(), input[1][1].size())
    print("output", len(output), output[0].size(), output[1][0].size(), output[1][1].size())
    print("=====")

rnn = nn.LSTM(3, 5, num_layers=2)
rnn.register_forward_hook(dummy_hook)

# make a sequence of length 5
inputs = (Variable(torch.randn(1, 3)) for _ in range(5))

# initialize the hidden state.
hidden = (Variable(torch.randn(2, 1, 5)), Variable(torch.randn(2, 1, 5)))

for i in inputs:
    # Step through the sequence one element at a time.
    # after each step, hidden contains the hidden state.
    out, hidden = rnn(i.view(1, 1, -1), hidden)
This gives the output:
module LSTM(3, 5, num_layers=2)
input 2 torch.Size([1, 1, 3]) torch.Size([2, 1, 5]) torch.Size([2, 1, 5])
output 2 torch.Size([1, 1, 5]) torch.Size([2, 1, 5]) torch.Size([2, 1, 5])
=====
... etc
In my actual code, I am looking at the outputs themselves and using them. Unfortunately, it seems that none of the internals are exposed. I am open to explicitly rerunning the internal operations of the LSTM piece by piece if necessary, though I would rather not. Even if that is the only way, I still don't know which of these input and output elements correspond to which state, and I don't know how to apply the internal parameters explicitly (as exposed by, e.g., rnn.weight_hh_l0) so as to get identical results. The LSTM module isn't very readable, because the use of these parameters is spread out all over the nn.RNN module code.
What are my best options for inspecting the state of the gate and the internal state (e.g., state carried from the lower layer to the upper layer)?
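For reference, here is a rough sketch of what I mean by rerunning the internal operations piece by piece. It relies on the layout described in the nn.LSTM documentation: the forward hook sees input as (x, (h_0, c_0)) and output as (out, (h_n, c_n)), and each weight_ih_l{k} / weight_hh_l{k} stacks the input, forget, cell and output gate weights in that order. The function name lstm_step_with_gates is my own, not part of PyTorch, and the sketch assumes a unidirectional LSTM with biases, no dropout and no projection. It may well not be the best approach.

```python
import torch
import torch.nn as nn

def lstm_step_with_gates(rnn, x, hidden):
    """Hypothetical helper: run one time step through every layer by hand.

    x: (batch, input_size); hidden: (h, c), each (num_layers, batch, hidden_size).
    Returns the top-layer output, the new (h, c), and per-layer gate activations.
    """
    h_prev, c_prev = hidden
    layer_input = x
    new_h, new_c, gates = [], [], []
    for layer in range(rnn.num_layers):
        w_ih = getattr(rnn, f"weight_ih_l{layer}")  # (4*hidden, in_features)
        w_hh = getattr(rnn, f"weight_hh_l{layer}")  # (4*hidden, hidden)
        b_ih = getattr(rnn, f"bias_ih_l{layer}")
        b_hh = getattr(rnn, f"bias_hh_l{layer}")

        # Pre-activations for all four gates, stacked along dim 1.
        pre = layer_input @ w_ih.t() + b_ih + h_prev[layer] @ w_hh.t() + b_hh
        i, f, g, o = pre.chunk(4, dim=1)  # documented order: i, f, g, o
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)

        c = f * c_prev[layer] + i * g   # new cell state
        h = o * torch.tanh(c)           # new hidden state

        gates.append({"input": i, "forget": f, "cell": g, "output": o})
        new_h.append(h)
        new_c.append(c)
        layer_input = h  # state carried from this layer up to the next one

    return layer_input, (torch.stack(new_h), torch.stack(new_c)), gates

# Compare against the built-in module on a single step.
torch.manual_seed(0)
rnn = nn.LSTM(3, 5, num_layers=2)
x = torch.randn(1, 3)
hidden = (torch.randn(2, 1, 5), torch.randn(2, 1, 5))
with torch.no_grad():
    ref_out, (ref_h, ref_c) = rnn(x.view(1, 1, -1), hidden)
    out, (h, c), gates = lstm_step_with_gates(rnn, x, hidden)
print(torch.allclose(out, ref_out[0], atol=1e-5))
```

Here gates[k] holds the four gate activations of layer k for this step, and new_h[0] is the state passed from the lower layer to the upper one. Stepping through a whole sequence would mean calling the helper once per element and feeding the returned (h, c) back in.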