Get layers in order of the data flow from input to output

Hi -

Is there a way to get the layers in the order of the data flow? I need to change the input and output channels of each Conv2d one by one, which means each conv's number of input channels must be set to the number of output channels of the previous convolution. I tried model.children(), but it doesn't return the layers in data-flow order.

You can use a forward hook like this:

```python
import torch
import torch.nn as nn


class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.Conv2d(1, 2, 3)
        self.m2 = nn.BatchNorm2d(2)
        self.m3 = nn.ReLU()
        self.m4 = nn.Conv2d(2, 3, 3)

    def forward(self, x):
        x = self.m1(x)
        x = self.m2(x)
        x = self.m3(x)
        x = self.m4(x)
        return x


modules = []


def add_hook(m):
    # A forward hook runs right after m.forward() returns.
    def forward_hook(module, input, output):
        modules.append(module)
    m.register_forward_hook(forward_hook)


foo = Foo()
foo.apply(add_hook)  # `add_hook` is applied to every submodule, including foo itself

x = torch.rand(1, 1, 10, 10)
foo(x)  # hooks fire in the order each forward() finishes: m1..m4, then Foo last
print(modules)
```

which prints out:

```
[Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)),
 BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(),
 Conv2d(2, 3, kernel_size=(3, 3), stride=(1, 1)),
 Foo(
  (m1): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1))
  (m2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (m3): ReLU()
  (m4): Conv2d(2, 3, kernel_size=(3, 3), stride=(1, 1))
)]
```
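Since the original goal is to re-chain the Conv2d widths, here is a minimal sketch that builds on the same hook idea. It assumes a plain chain (no branches or residual adds), that every leaf module runs exactly once per forward pass, that convs use groups=1, and that freshly initialised layers are acceptable (old weights are not copied). The helper names `leaf_modules_in_call_order` and `set_conv_widths` are hypothetical. Hooking only leaf modules keeps the parent `Foo` out of the list, and the handles returned by `register_forward_hook` are removed afterwards. `Module.get_submodule` requires a reasonably recent PyTorch release.

```python
import torch
import torch.nn as nn


class Foo(nn.Module):
    # Same toy model as in the answer above.
    def __init__(self):
        super().__init__()
        self.m1 = nn.Conv2d(1, 2, 3)
        self.m2 = nn.BatchNorm2d(2)
        self.m3 = nn.ReLU()
        self.m4 = nn.Conv2d(2, 3, 3)

    def forward(self, x):
        return self.m4(self.m3(self.m2(self.m1(x))))


def leaf_modules_in_call_order(model, example_input):
    """Hypothetical helper: (qualified_name, module) pairs for leaf modules, in run order."""
    order, handles = [], []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf: has no submodules
            handles.append(module.register_forward_hook(
                lambda mod, inp, out, name=name: order.append((name, mod))))
    was_training = model.training
    model.eval()  # avoid updating BatchNorm running stats during the probe pass
    with torch.no_grad():
        model(example_input)
    model.train(was_training)
    for h in handles:
        h.remove()  # detach the hooks again
    return order


def set_conv_widths(model, example_input, new_out_channels):
    """Hypothetical helper: rebuild each Conv2d/BatchNorm2d so widths chain correctly.

    Assumes a plain chain with groups=1; new layers are freshly initialised.
    """
    widths = iter(new_out_channels)
    channels = example_input.shape[1]  # channels entering the first layer
    for name, module in leaf_modules_in_call_order(model, example_input):
        if isinstance(module, nn.Conv2d):
            out_ch = next(widths)
            new = nn.Conv2d(channels, out_ch, module.kernel_size,
                            stride=module.stride, padding=module.padding,
                            dilation=module.dilation, bias=module.bias is not None)
            channels = out_ch
        elif isinstance(module, nn.BatchNorm2d):
            new = nn.BatchNorm2d(channels)  # default eps/momentum
        else:
            continue  # e.g. ReLU has no channel-dependent parameters
        # Swap the layer on its parent module, e.g. attribute "m1" on the root.
        parent_name, _, attr = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, attr, new)
    return model


foo = set_conv_widths(Foo(), torch.rand(1, 1, 10, 10), [4, 5])
out = foo(torch.rand(1, 1, 10, 10))
print(foo)
print(out.shape)  # torch.Size([1, 5, 6, 6])
```

For models with residual adds, the channel count entering a layer is no longer just "the previous conv's output", so this simple running counter does not apply.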

Also note that model.modules() and model.children() do not guarantee that submodules come out in order from model input to output: they yield modules in the order they were registered, which is typically the order they are assigned in __init__.
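A minimal way to see the difference, using a hypothetical module (`Reordered`) whose layers are assigned in `__init__` in the opposite order to how `forward()` calls them:

```python
import torch
import torch.nn as nn


class Reordered(nn.Module):
    # Hypothetical example: registration order differs from call order.
    def __init__(self):
        super().__init__()
        self.second = nn.Conv2d(2, 3, 3)  # registered first
        self.first = nn.Conv2d(1, 2, 3)   # registered second

    def forward(self, x):
        return self.second(self.first(x))  # but `first` runs first


model = Reordered()
registration_order = [name for name, _ in model.named_children()]

call_order = []
handles = [
    m.register_forward_hook(lambda mod, inp, out, name=name: call_order.append(name))
    for name, m in model.named_children()
]
model(torch.rand(1, 1, 10, 10))
for h in handles:
    h.remove()

print(registration_order)  # ['second', 'first']
print(call_order)          # ['first', 'second']
```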

Thanks. I'm also looking at torch.nn.Sequential(*list(model.children())), but I'm not sure whether it works with residual connections.
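A quick way to check this on a toy block (`TinyResidual` is made up for illustration): rebuilding it as `nn.Sequential(*list(block.children()))` keeps the same layers but drops the `+ x`, because that addition lives in `forward()` rather than in a submodule.

```python
import torch
import torch.nn as nn


class TinyResidual(nn.Module):
    """Hypothetical toy block: out = relu(conv(x)) + x."""
    def __init__(self, channels=2):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x)) + x  # the "+ x" is not a module


torch.manual_seed(0)
block = TinyResidual()
flat = nn.Sequential(*list(block.children()))  # reuses the same conv/relu objects

x = torch.rand(1, 2, 5, 5)
with torch.no_grad():
    same = torch.allclose(block(x), flat(x))
    # The only difference is the skip term that Sequential lost.
    diff_is_x = torch.allclose(block(x) - flat(x), x, atol=1e-6)
print(same)       # False
print(diff_is_x)  # True
```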

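For residual connections, I'd recommend writing a custom module class instead: nn.Sequential just calls its modules one after another, so it can't express a skip connection, whose addition happens in forward() rather than in a module. See the official torchvision code here: vision/resnet.py at master · pytorch/vision · GitHub.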
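A minimal sketch of such a custom class, loosely in the spirit of torchvision's ResNet blocks but not copied from them (`ResidualBlock` and its arguments are hypothetical). Because the channel widths are constructor arguments, changing them means building new blocks where each block's `in_channels` equals the previous block's `out_channels`; a 1x1 projection on the shortcut keeps the addition shape-compatible when the widths differ.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Hypothetical residual block with configurable widths."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if in_channels != out_channels:
            # Project the shortcut so `out + shortcut(x)` has matching channels.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + self.shortcut(x))  # the skip connection


# Chain blocks so each block's input width is the previous block's output width.
widths = [1, 4, 8]
model = nn.Sequential(*[ResidualBlock(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])])
y = model(torch.rand(2, 1, 10, 10))
print(y.shape)  # torch.Size([2, 8, 10, 10])
```

Wrapping whole blocks in nn.Sequential is fine here, since each block's skip connection is contained in its own forward().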