Inspecting LSTM gates with hooks?

I have a trained model in which I want to inspect intermediate activations and gates. It's a language model: a 2-layer LSTM preceded by an embedding layer and followed by a softmax layer. Currently I am looking at activation values and gradients through hooks.

The problem is that I want to see the internal (hidden) activations as well as the internal gates, and these components are not exposed to hooks. For example:

```python
import torch
import torch.nn as nn

def dummy_hook(module, input, output):
    print("\n module", module)
    print("input", len(input), input[0].size(), input[1][0].size(), input[1][1].size())
    print("output", len(output), output[0].size(), output[1][0].size(), output[1][1].size())
    print("=====")

rnn = nn.LSTM(3, 5, num_layers=2)

rnn.register_forward_hook(dummy_hook)

inputs = (torch.randn(1, 3) for _ in range(5))  # make a sequence of length 5

# Initialize the hidden state: (h_0, c_0), each of shape (num_layers, batch, hidden_size).
hidden = (torch.randn(2, 1, 5), torch.randn(2, 1, 5))
for i in inputs:
    # Step through the sequence one element at a time.
    # After each step, hidden contains the hidden state.
    out, hidden = rnn(i.view(1, 1, -1), hidden)
```

This gives the output:

```
 module LSTM(3, 5, num_layers=2)
input 2 torch.Size([1, 1, 3]) torch.Size([2, 1, 5]) torch.Size([2, 1, 5])
output 2 torch.Size([1, 1, 5]) torch.Size([2, 1, 5]) torch.Size([2, 1, 5])
=====
...
```

In my actual code, I look at the outputs themselves and use them. Unfortunately, none of the internals seem to be exposed. I am open to explicitly rerunning the internal operations of the LSTM piece by piece if necessary, though I would rather not. Even then, I don't know which of these input and output elements correspond to which state, and I don't know how to apply the internal parameters (as exposed by, e.g., rnn.weight_hh_l0) explicitly so as to get identical results. The LSTM module isn't very readable, because the way these parameters are used is spread all over the nn.RNN module code.
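A minimal sketch of what the hook's arguments actually are, assuming the LSTM is called with an explicit `(h_0, c_0)` tuple as in the code above (if the state is omitted, `inputs` is just `(x,)` and the unpacking below fails). `labeled_hook` and `captured` are illustrative names. Per the nn.LSTM documentation, `output` holds the top layer's h_t at every time step, while `h_n` and `c_n` hold the final h and c of each layer; the gates and the lower layer's per-step outputs appear nowhere in these tensors.

```python
import torch
import torch.nn as nn

captured = {}  # illustrative storage for the labeled tensors

def labeled_hook(module, inputs, outputs):
    # Positional inputs to nn.LSTM.forward: (x, (h_0, c_0)).
    x, (h_0, c_0) = inputs
    # Return value of nn.LSTM.forward: (output, (h_n, c_n)).
    output, (h_n, c_n) = outputs
    captured.update(x=x, h_0=h_0, c_0=c_0, output=output, h_n=h_n, c_n=c_n)

rnn = nn.LSTM(3, 5, num_layers=2)
rnn.register_forward_hook(labeled_hook)

x = torch.randn(1, 1, 3)                               # (seq_len, batch, input_size)
hidden = (torch.randn(2, 1, 5), torch.randn(2, 1, 5))  # (num_layers, batch, hidden_size)
out, hidden = rnn(x, hidden)

# output: last layer's h_t for every step; h_n / c_n: final h and c of EACH layer.
print(captured["output"].shape)  # torch.Size([1, 1, 5])
print(captured["h_n"].shape)     # torch.Size([2, 1, 5])
# For a unidirectional LSTM, the top layer's final h equals the last output step.
print(torch.allclose(captured["h_n"][-1], captured["output"][-1]))  # True
```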

What are my best options for inspecting the gate values and the internal state (e.g., the state carried from the lower layer to the upper layer)?
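If rerunning the computation by hand is acceptable, the nn.LSTM documentation describes the parameter layout: `weight_ih_l{k}` stacks the input, forget, cell, and output blocks (i, f, g, o) in that order, and so do `weight_hh_l{k}`, `bias_ih_l{k}`, and `bias_hh_l{k}`. The sketch below (the helper `lstm_layer_step` is a hypothetical name) recomputes one time step of a 2-layer LSTM from those parameters, returns the gates, and checks the result against the module itself. It assumes a unidirectional LSTM with `dropout=0`; with dropout enabled, the input passed from layer 0 to layer 1 would also be dropped out during training.

```python
import torch
import torch.nn as nn

def lstm_layer_step(rnn, layer, x, h, c):
    """Recompute one time step of layer `layer` of an nn.LSTM by hand.

    Returns (h_new, c_new, gates), where gates is a dict with keys i, f, g, o.
    """
    w_ih = getattr(rnn, f"weight_ih_l{layer}")
    w_hh = getattr(rnn, f"weight_hh_l{layer}")
    b_ih = getattr(rnn, f"bias_ih_l{layer}")
    b_hh = getattr(rnn, f"bias_hh_l{layer}")
    pre = x @ w_ih.t() + b_ih + h @ w_hh.t() + b_hh
    # The four gate blocks are stacked in the order input, forget, cell, output.
    i, f, g, o = pre.chunk(4, dim=-1)
    i, f, g, o = torch.sigmoid(i), torch.sigmoid(f), torch.tanh(g), torch.sigmoid(o)
    c_new = f * c + i * g
    h_new = o * torch.tanh(c_new)
    return h_new, c_new, {"i": i, "f": f, "g": g, "o": o}

rnn = nn.LSTM(3, 5, num_layers=2)
x = torch.randn(1, 1, 3)                        # one time step, batch of 1
h0, c0 = torch.randn(2, 1, 5), torch.randn(2, 1, 5)

with torch.no_grad():
    out, (h_n, c_n) = rnn(x, (h0, c0))
    # Layer 0 reads the input; layer 1 reads layer 0's new hidden state h1,
    # which is the state carried from the lower layer to the upper layer.
    h1, c1, gates0 = lstm_layer_step(rnn, 0, x[0], h0[0], c0[0])
    h2, c2, gates1 = lstm_layer_step(rnn, 1, h1, h0[1], c0[1])

print(torch.allclose(h2, out[0], atol=1e-5))   # True
print(torch.allclose(c1, c_n[0], atol=1e-5))   # True
```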

Thank you for posting this question 8 months before I even knew how to ask it. At the moment, I am just trying to demonstrate the internals of the LSTM to students, to encourage them to poke and prod and thus learn what is going on inside.

Did you ever learn the answer to your question?

I would suggest that you build your own LSTM cell for such experiments.
You could use the cell we use in benchmarking as a starting point and then build your own nn.Module based on that function.
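A sketch of a function-style cell along those lines. It is not the benchmark code itself: the name `lstm_cell_with_gates` is hypothetical, and it assumes the weight layout of `nn.LSTMCell` (`weight_ih`, `weight_hh`, `bias_ih`, `bias_hh`, gate order i, f, g, o). Because it returns the gates as ordinary tensors, they can be inspected directly, and comparing against `nn.LSTMCell` gives a quick sanity check.

```python
import torch
import torch.nn as nn

def lstm_cell_with_gates(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
    """One LSTM step that also returns the gate activations (hypothetical helper).

    w_ih: (4*hidden, input), w_hh: (4*hidden, hidden); gate order i, f, g, o.
    """
    gates = x @ w_ih.t() + b_ih + hx @ w_hh.t() + b_hh
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, dim=1)
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)
    cy = forgetgate * cx + ingate * cellgate   # new cell state
    hy = outgate * torch.tanh(cy)              # new hidden state
    return hy, cy, (ingate, forgetgate, cellgate, outgate)

cell = nn.LSTMCell(3, 5)
x, hx, cx = torch.randn(4, 3), torch.randn(4, 5), torch.randn(4, 5)  # batch of 4

with torch.no_grad():
    ref_h, ref_c = cell(x, (hx, cx))
    hy, cy, (i, f, g, o) = lstm_cell_with_gates(
        x, hx, cx, cell.weight_ih, cell.weight_hh, cell.bias_ih, cell.bias_hh)

print(torch.allclose(hy, ref_h, atol=1e-5), torch.allclose(cy, ref_c, atol=1e-5))  # True True
```

Since the gates are regular autograd tensors, calling `.retain_grad()` on them (or attaching `Tensor.register_hook`) before `backward()`, outside `torch.no_grad()`, should also expose their gradients, which module-level hooks on `nn.LSTM` cannot do.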

Best regards

Thomas

Thank you, Tom!

Does “build your own nn.Module” mean more than just something like

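```python
import torch.nn as nn

class MyLSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        ...  # create the parameters here
    def forward(self, x, state):
        ...  # compute the gates and the new state here
```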

?

In other words, I would subclass nn.Module and then define appropriate __init__() and forward() methods?

I think that is exactly what you mean, but I want to confirm it, for myself and any others who may arrive here later.
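For the record, a filled-in version of that skeleton might look like the following sketch. `GateInspectingLSTM` is a hypothetical name, and it assumes a unidirectional, sequence-first LSTM without dropout or projections. Naming the parameters exactly like nn.LSTM's (`weight_ih_l0`, `bias_hh_l1`, ...) means a trained `nn.LSTM`'s `state_dict` can be loaded straight into it, so the gates of an already-trained model can be examined.

```python
import torch
import torch.nn as nn

class GateInspectingLSTM(nn.Module):
    """Multi-layer LSTM that records every gate (hypothetical sketch).

    Parameter names mirror nn.LSTM so a trained nn.LSTM state_dict loads directly.
    Input is sequence-first: (seq_len, batch, input_size).
    """

    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        for k in range(num_layers):
            in_size = input_size if k == 0 else hidden_size
            self.register_parameter(f"weight_ih_l{k}", nn.Parameter(torch.empty(4 * hidden_size, in_size)))
            self.register_parameter(f"weight_hh_l{k}", nn.Parameter(torch.empty(4 * hidden_size, hidden_size)))
            self.register_parameter(f"bias_ih_l{k}", nn.Parameter(torch.empty(4 * hidden_size)))
            self.register_parameter(f"bias_hh_l{k}", nn.Parameter(torch.empty(4 * hidden_size)))
        bound = hidden_size ** -0.5
        for p in self.parameters():
            nn.init.uniform_(p, -bound, bound)  # same uniform init range nn.LSTM uses
        self.gates = []  # after forward: gates[t][k] = dict of i, f, g, o, c, h

    def forward(self, x, state):
        h = list(state[0].unbind(0))  # per-layer hidden states
        c = list(state[1].unbind(0))  # per-layer cell states
        self.gates, outputs = [], []
        for x_t in x:                 # iterate over time steps
            inp, step = x_t, []
            for k in range(self.num_layers):
                pre = (inp @ getattr(self, f"weight_ih_l{k}").t() + getattr(self, f"bias_ih_l{k}")
                       + h[k] @ getattr(self, f"weight_hh_l{k}").t() + getattr(self, f"bias_hh_l{k}"))
                i, f, g, o = pre.chunk(4, dim=1)
                i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()
                c[k] = f * c[k] + i * g
                h[k] = o * c[k].tanh()
                step.append({"i": i, "f": f, "g": g, "o": o, "c": c[k], "h": h[k]})
                inp = h[k]            # state passed from layer k up to layer k + 1
            self.gates.append(step)
            outputs.append(inp)
        return torch.stack(outputs), (torch.stack(h), torch.stack(c))

ref = nn.LSTM(3, 5, num_layers=2)
mine = GateInspectingLSTM(3, 5, num_layers=2)
mine.load_state_dict(ref.state_dict())

x = torch.randn(7, 1, 3)
state = (torch.zeros(2, 1, 5), torch.zeros(2, 1, 5))
with torch.no_grad():
    out_ref, (h_ref, c_ref) = ref(x, state)
    out, (h_n, c_n) = mine(x, state)

print(torch.allclose(out, out_ref, atol=1e-5))  # True
print(mine.gates[6][1]["f"].shape)              # torch.Size([1, 5])
```

After a forward pass, `mine.gates[t][k]` holds the gates, cell state, and hidden state of layer `k` at step `t`; `mine.gates[t][0]["h"]` is the state handed from the lower layer to the upper one.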

Thanks again,

Bill

Yes.

Best regards

Thomas

I would like to know: will an implementation based on the benchmark cell be much slower than nn.LSTM?
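Nothing in this thread measures it, but as a rough guide: nn.LSTM runs its time loop in native code (and uses cuDNN on GPU), so a cell written as Python-level tensor operations is generally expected to be slower, with the gap depending on sequence length, batch size, hidden size, and device. The sketch below is one way to check on your own setup. The names `manual_lstm` and `best_time` are hypothetical, and it times CPU execution only (on GPU you would need `torch.cuda.synchronize()` before reading the clock); `torch.utils.benchmark.Timer` is a more careful alternative.

```python
import time
import torch
import torch.nn as nn

def manual_lstm(x, h, c, w_ih, w_hh, b_ih, b_hh):
    # Single-layer LSTM unrolled in a Python loop, one time step per iteration.
    outs = []
    for x_t in x:
        i, f, g, o = (x_t @ w_ih.t() + b_ih + h @ w_hh.t() + b_hh).chunk(4, dim=1)
        c = f.sigmoid() * c + i.sigmoid() * g.tanh()
        h = o.sigmoid() * c.tanh()
        outs.append(h)
    return torch.stack(outs), h, c

def best_time(fn, repeats=5):
    # Minimum wall-clock time over several runs, in seconds.
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return min(times)

seq_len, batch, n_in, n_hid = 100, 32, 64, 128
lstm = nn.LSTM(n_in, n_hid)
x = torch.randn(seq_len, batch, n_in)
h0 = torch.zeros(1, batch, n_hid)
c0 = torch.zeros(1, batch, n_hid)

with torch.no_grad():
    t_builtin = best_time(lambda: lstm(x, (h0, c0)))
    t_manual = best_time(lambda: manual_lstm(
        x, h0[0], c0[0], lstm.weight_ih_l0, lstm.weight_hh_l0, lstm.bias_ih_l0, lstm.bias_hh_l0))

print(f"nn.LSTM: {t_builtin * 1e3:.2f} ms, Python loop: {t_manual * 1e3:.2f} ms")
```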