Unexpected outputs for forward hooks

To get a bit more familiar with hooks (just forward hooks for now), I’ve created a very simple LSTM model:

import torch
import torch.nn as nn

class SimpleLSTM(nn.Module):
    
    def __init__(self, vocab_size, out_size):
        super().__init__()
        
        self.embed = nn.Embedding(vocab_size, 5)
        self.lstm = nn.LSTM(5, 4, bidirectional=True)
        self.linear1 = nn.Linear(4*2, 4)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.linear2 = nn.Linear(4, out_size)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, X):
        batch_size = X.shape[0]
        out = self.embed(X)
        out = out.transpose(0, 1)
        out, (h, c) = self.lstm(out)
        h = h.view(1, 2, batch_size, 4)[-1]
        h = torch.concat((h[0], h[1]), dim=1)
        out = self.linear1(h)
        out = self.relu1(out)
        out = self.dropout1(out)
        out = self.linear2(out)
        out = self.log_softmax(out)
        return out

I’ve then attached a forward hook to each layer, with the sole purpose of printing the input and output shapes. When I run a sample batch of shape (3, 5), meaning (batch_size, seq_len), I get the following output:

------- Embedding(10, 5) -------
embed [ in] torch.Size([3, 5])
embed [out] torch.Size([5, 5])

------- LSTM(5, 4, bidirectional=True) -------
lstm [ in] torch.Size([5, 3, 5])
lstm [out] torch.Size([5, 3, 8])

------- Linear(in_features=8, out_features=4, bias=True) -------
linear1 [ in] torch.Size([3, 8])
linear1 [out] torch.Size([4])

------- ReLU() -------
relu1 [ in] torch.Size([3, 4])
relu1 [out] torch.Size([4])

------- Dropout(p=0.5, inplace=False) -------
dropout1 [ in] torch.Size([3, 4])
dropout1 [out] torch.Size([4])

------- Linear(in_features=4, out_features=2, bias=True) -------
linear2 [ in] torch.Size([3, 4])
linear2 [out] torch.Size([2])

------- LogSoftmax(dim=1) -------
log_softmax [ in] torch.Size([3, 2])
log_softmax [out] torch.Size([2])

It looks like the output shape is always missing the batch dimension. The only exception is the lstm layer, I assume because batch_first=False.

Why is the output seen by the forward hook not the complete tensor? Or what might I be doing wrong here?

I’m not sure what your forward hook looks like, but I would guess you might be indexing into input and output in a way that drops the batch dimension.
I’m using a small and naive helper utility to flatten the potentially nested input/output tuples, and I get this output using your code:

import torch
import torch.nn as nn

class SimpleLSTM(nn.Module):
    def __init__(self, vocab_size, out_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, 5)
        self.lstm = nn.LSTM(5, 4, bidirectional=True)
        self.linear1 = nn.Linear(4*2, 4)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.linear2 = nn.Linear(4, out_size)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, X):
        batch_size = X.shape[0]
        out = self.embed(X)
        out = out.transpose(0, 1)
        out, (h, c) = self.lstm(out)
        h = h.view(1, 2, batch_size, 4)[-1]
        h = torch.concat((h[0], h[1]), dim=1)
        out = self.linear1(h)
        out = self.relu1(out)
        out = self.dropout1(out)
        out = self.linear2(out)
        out = self.log_softmax(out)
        return out

def flatten(tensors):
    # recursively flatten potentially nested tuples into a flat tuple of tensors
    if isinstance(tensors, tuple):
        if len(tensors) == 0:
            return ()
        else:
            return flatten(tensors[0]) + flatten(tensors[1:])
    else:
        return (tensors,)

def get_hook(name):
    # forward hook that prints the shape of every tensor found in input and output
    def hook(m, input, output):
        print(f"layer {name}")
        for i in flatten(input):
            print(f"input shape {i.shape}")
        for o in flatten(output):
            print(f"output shape {o.shape}")
    return hook

model = SimpleLSTM(vocab_size=10, out_size=20)
for name, module in model.named_modules():
    module.register_forward_hook(get_hook(name))

x = torch.randint(0, 10, (3, 5))
out = model(x)
# layer embed
# input shape torch.Size([3, 5])
# output shape torch.Size([3, 5, 5])
# layer lstm
# input shape torch.Size([5, 3, 5])
# output shape torch.Size([5, 3, 8])
# output shape torch.Size([2, 3, 4])
# output shape torch.Size([2, 3, 4])
# layer linear1
# input shape torch.Size([3, 8])
# output shape torch.Size([3, 4])
# layer relu1
# input shape torch.Size([3, 4])
# output shape torch.Size([3, 4])
# layer dropout1
# input shape torch.Size([3, 4])
# output shape torch.Size([3, 4])
# layer linear2
# input shape torch.Size([3, 4])
# output shape torch.Size([3, 20])
# layer log_softmax
# input shape torch.Size([3, 20])
# output shape torch.Size([3, 20])
# layer 
# input shape torch.Size([3, 5])
# output shape torch.Size([3, 20])

which should show all batch dimensions.
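
If you only care about one specific module such as the nn.LSTM, you could also unpack its output tuple directly inside the hook instead of flattening it. This is just a minimal sketch of that idea:

import torch
import torch.nn as nn

def lstm_hook(module, input, output):
    # nn.LSTM returns (output, (h_n, c_n)), so the hook receives that nested tuple
    out, (h_n, c_n) = output
    print(f"output {out.shape}, h_n {h_n.shape}, c_n {c_n.shape}")

lstm = nn.LSTM(5, 4, bidirectional=True)
lstm.register_forward_hook(lstm_hook)
_ = lstm(torch.randn(5, 3, 5))
# output torch.Size([5, 3, 8]), h_n torch.Size([2, 3, 4]), c_n torch.Size([2, 3, 4])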


Right, the inputs and outputs arrive as different types. For the LSTM output it should have been obvious, but I missed that as well.

It still feels a bit unintuitive that the inputs are always tuples, while the outputs are tensors for nn.Linear, nn.Embedding, etc., but tuples for nn.LSTM (again, this one makes sense, of course):

------- Embedding(10, 5) -------
<class 'tuple'>
<class 'torch.Tensor'>

------- LSTM(5, 4, bidirectional=True) -------
<class 'tuple'>
<class 'tuple'>

------- Linear(in_features=8, out_features=4, bias=True) -------
<class 'tuple'>
<class 'torch.Tensor'>

------- ReLU() -------
<class 'tuple'>
<class 'torch.Tensor'>

------- Dropout(p=0.5, inplace=False) -------
<class 'tuple'>
<class 'torch.Tensor'>

------- Linear(in_features=4, out_features=2, bias=True) -------
<class 'tuple'>
<class 'torch.Tensor'>

------- LogSoftmax(dim=1) -------
<class 'tuple'>
<class 'torch.Tensor'>
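
(The listing above comes from a hook that only prints the types, roughly like this sketch; the exact hook I used isn’t important:)

def type_hook(module, input, output):
    # print only the types the forward hook receives for each layer
    print(f"------- {module} -------")
    print(type(input))
    print(type(output))

# hypothetical sketch: register on each submodule of the SimpleLSTM model from above
for module in model.children():
    module.register_forward_hook(type_hook)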

But yes, since I thought it’s always a tuple, I used

print(input[0])
print(output[0])

which is fine for the inputs (always a tuple), but only prints the first sample in the batch if the output is a tensor.
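
For anyone finding this later: a hook that handles both cases explicitly could look roughly like this (a minimal sketch that only looks at the first tensor of tuple outputs):

import torch

def shape_hook(module, input, output):
    # input is always a tuple of the positional arguments passed to forward
    print(f"input shape {input[0].shape}")
    # output can be a plain tensor (Linear, Embedding, ...) or a tuple (LSTM, ...)
    if isinstance(output, torch.Tensor):
        print(f"output shape {output.shape}")
    else:
        print(f"output shape {output[0].shape}")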