Saving inputs of all `nn.Module` children

Hi,

I have this problem: I would like to save the inputs of all `nn.Module` children of my top-level module. For example, let’s assume I have this very simple model and that, given a batch, I want to save the inputs to conv1, conv2, fc1 and fc2:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(...)
        self.conv2 = nn.Conv2d(...)
        self.fc1 = nn.Linear(...)
        self.fc2 = nn.Linear(...)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = x.view(-1, ...)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

I can certainly do something like this, but it’s ugly and not very flexible:

    def forward(self, x):
        self.input_to_conv1 = x   # <====
        x = self.conv1(x)
        x = F.relu(x)
        self.input_to_conv2 = x   # <====
        x = self.conv2(x)
        x = F.relu(x)
        x = x.view(-1, ...)
        self.input_to_fc1 = x     # <====
        x = self.fc1(x)
        x = F.relu(x)
        self.input_to_fc2 = x     # <====
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

Moreover, imagine that I don’t have access to how these modules are built: I know that they are made of `nn.Module`s, but I have no idea of their names or structure.

Therefore the question is: Is there a “pytorch-ish” way of saving the inputs of all nn.Module children for later use?

I think you want a forward hook for this. `register_forward_hook` attaches a function to a module; after every forward call, the module calls that function with itself, its inputs and its output, so you can save whichever of them you need. https://pytorch.org/docs/stable/nn.html?highlight=forward_hook#torch.nn.Module.register_forward_hook

A working example here:

import torch
import torch.nn as nn
import torch.nn.functional as F


def print_tensor_props(module, input, output):
    # called after module.forward(); input is a tuple of the positional args
    print(input[0].shape, end=' => ')
    print(output.shape)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 4, 3, padding=1)
        self.conv2 = nn.Conv2d(4, 8, 3, padding=1)
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = x.view(-1, 512)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

net = Net()
net.conv1.register_forward_hook(print_tensor_props)
net.conv2.register_forward_hook(print_tensor_props)
net.fc1.register_forward_hook(print_tensor_props)
net.fc2.register_forward_hook(print_tensor_props)

x = torch.randn(5, 3, 8, 8)
print('x  ::  ', x.shape)
out = net(x)

which will print the following:

x  ::   torch.Size([5, 3, 8, 8])
torch.Size([5, 3, 8, 8]) => torch.Size([5, 4, 8, 8])
torch.Size([5, 4, 8, 8]) => torch.Size([5, 8, 8, 8])
torch.Size([5, 512]) => torch.Size([5, 256])
torch.Size([5, 256]) => torch.Size([5, 10])
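Since the question is only about inputs, `register_forward_pre_hook` is a close alternative: its hook runs before `forward` and receives `(module, args)`, with no output argument. A minimal sketch on a single layer (the names `captured` and `save_input` are just for illustration):

```python
import torch
import torch.nn as nn

captured = []

def save_input(module, args):
    # args is the tuple of positional arguments passed to module(...)
    captured.append(args[0].detach())
    # returning None leaves the input unchanged

layer = nn.Linear(4, 2)
layer.register_forward_pre_hook(save_input)

out = layer(torch.randn(3, 4))
print(captured[0].shape)  # torch.Size([3, 4])
```

A pre-hook can also return a modified input, but for plain recording returning nothing is enough.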

Furthermore, if you want to save the inputs to a list, you can do as follows:

g_list = []

def save_tensor(module, input, output):
    # input is a tuple of the positional args passed to the module
    g_list.append(input)

net = Net()
net.conv1.register_forward_hook(save_tensor)
net.conv2.register_forward_hook(save_tensor)
net.fc1.register_forward_hook(save_tensor)
net.fc2.register_forward_hook(save_tensor)

x = torch.randn(5, 3, 8, 8)
out = net(x)
for i, inputs in enumerate(g_list):
    print("Input to layer {} has shape ".format(i), inputs[0].shape)

which prints:

Input to layer 0 has shape  torch.Size([5, 3, 8, 8])
Input to layer 1 has shape  torch.Size([5, 4, 8, 8])
Input to layer 2 has shape  torch.Size([5, 512])
Input to layer 3 has shape  torch.Size([5, 256])
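One drawback of the list approach is that entries are identified only by position and the list grows with every forward pass. A minimal sketch of an alternative, assuming you only need the most recent batch: keep the inputs in a dict keyed by layer name, using a small closure so each hook knows which layer it is attached to. `make_saver` and `saved_inputs` are hypothetical names, not PyTorch APIs; the `Net` is the same model as above. The stored tensors are detached so they don't keep the autograd graph alive.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, 3, padding=1)
        self.conv2 = nn.Conv2d(4, 8, 3, padding=1)
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(-1, 512)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)


saved_inputs = {}  # layer name -> input tensor from the most recent forward pass


def make_saver(name):
    # hypothetical helper: builds a hook bound to one layer name
    def save_input(module, inputs, output):
        # overwrite instead of append, and detach from the autograd graph
        saved_inputs[name] = inputs[0].detach()
    return save_input


net = Net()
for name in ["conv1", "conv2", "fc1", "fc2"]:
    getattr(net, name).register_forward_hook(make_saver(name))

out = net(torch.randn(5, 3, 8, 8))
for name, t in saved_inputs.items():
    print(name, t.shape)
```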

Finally, the other issue you mentioned: if you do not know the names, or simply don’t want to register the hooks one at a time, you can use `net.children()`, which returns an iterator over the immediate child modules of `net`:

>>> net.children()
<generator object Module.children at 0x7f411dad15e8>
>>> for layer in net.children():
...    print(layer)

Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Linear(in_features=512, out_features=256, bias=True)
Linear(in_features=256, out_features=10, bias=True)
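Note that `children()` only yields the immediate children. If a child is itself a container (e.g. an `nn.Sequential`), a loop over `children()` hooks the container as a whole rather than the layers inside it. Since the structure isn't known in advance, one option is to walk `named_modules()`, which recurses through the whole tree, and hook only leaf modules (those with no children). Functional calls such as `F.relu` are not modules, so no hook can see them. A sketch, using a made-up nested model purely for illustration:

```python
import torch
import torch.nn as nn

# illustrative nested model: the first child is itself a container
model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU()),
    nn.Flatten(),
    nn.Linear(4 * 8 * 8, 10),
)

inputs = {}  # dotted module name -> input tensor

def make_hook(name):
    def hook(module, args, output):
        inputs[name] = args[0].detach()
    return hook

for name, module in model.named_modules():
    # skip containers (including the root, whose name is ""): hook leaves only
    if len(list(module.children())) == 0:
        module.register_forward_hook(make_hook(name))

model(torch.randn(5, 3, 8, 8))
for name, t in inputs.items():
    print(name, tuple(t.shape))
```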

So, we can register the hook on each layer in a loop (`print_tensor_props` is defined in the previous post):

>>> for layer in net.children():
...    layer.register_forward_hook(print_tensor_props)
... 
<torch.utils.hooks.RemovableHandle object at 0x7f4163193208>
<torch.utils.hooks.RemovableHandle object at 0x7f416140e7b8>
<torch.utils.hooks.RemovableHandle object at 0x7f4163193208>
<torch.utils.hooks.RemovableHandle object at 0x7f416140e7b8>

## testing: 
>>> out = net(x)
torch.Size([5, 3, 8, 8]) => torch.Size([5, 4, 8, 8])
torch.Size([5, 4, 8, 8]) => torch.Size([5, 8, 8, 8])
torch.Size([5, 512]) => torch.Size([5, 256])
torch.Size([5, 256]) => torch.Size([5, 10])
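The loop above discards the `RemovableHandle` objects that `register_forward_hook` returns (the REPL just echoes them). If the hooks should only be active for a while, e.g. while collecting inputs for one batch, a sketch of keeping the handles and removing the hooks afterwards (`log_call` and the small model are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
calls = []

def log_call(module, args, output):
    calls.append(type(module).__name__)

# keep the RemovableHandle objects instead of throwing them away
handles = [layer.register_forward_hook(log_call) for layer in model.children()]

model(torch.randn(1, 8))
print(calls)  # ['Linear', 'ReLU', 'Linear']

for h in handles:
    h.remove()

model(torch.randn(1, 8))
print(calls)  # unchanged: the hooks no longer fire
```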
