You can use a forward hook for this. A forward hook is a callable that PyTorch invokes right after a module's forward pass finishes, passing it the module, its input, and its output:
import torch
import torch.nn as nn

class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.Conv2d(1, 2, 3)
        self.m2 = nn.BatchNorm2d(2)
        self.m3 = nn.ReLU()
        self.m4 = nn.Conv2d(2, 3, 3)

    def forward(self, x):
        x = self.m1(x)
        x = self.m2(x)
        x = self.m3(x)
        x = self.m4(x)
        return x

modules = []

def add_hook(m):
    def forward_hook(module, input, output):
        # record the module whose forward pass just finished
        modules.append(module)
    m.register_forward_hook(forward_hook)

foo = Foo()
foo.apply(add_hook)  # `apply` runs `add_hook` on every submodule, including `foo` itself

x = torch.rand(1, 1, 10, 10)
foo(x)  # the hooks fire in execution order, from the first submodule to the top-level module
print(modules)
which prints out:
[Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)),
 BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(),
 Conv2d(2, 3, kernel_size=(3, 3), stride=(1, 1)),
 Foo(
   (m1): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1))
   (m2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (m3): ReLU()
   (m4): Conv2d(2, 3, kernel_size=(3, 3), stride=(1, 1))
 )]
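
A common follow-up is to capture each submodule's output rather than the module object itself, and to clean the hooks up afterwards: register_forward_hook returns a RemovableHandle whose remove() method detaches the hook. Here is a minimal sketch of that pattern; the activations dict and the save_activation helper are illustrative names I chose, not part of the PyTorch API:

activations = {}
handles = []

def save_activation(name):
    def hook(module, input, output):
        # store a detached copy of this submodule's output under its name
        activations[name] = output.detach()
    return hook

foo = Foo()
for name, module in foo.named_modules():
    if name:  # skip the top-level module itself, whose name is the empty string
        handles.append(module.register_forward_hook(save_activation(name)))

foo(torch.rand(1, 1, 10, 10))
print(list(activations))  # ['m1', 'm2', 'm3', 'm4']

for handle in handles:
    handle.remove()  # detach the hooks once you no longer need them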