Traverse a model to get each module's name and its parents

I would like to connect each nn.Module in a model to its named parent.

import torch
import torch.nn as nn

class AddBlock(nn.Module):
    def forward(self, x, y):
        return x + y

class multi_inp(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.conv = nn.Conv2d(3, 32, kernel_size=3)
        self.add = AddBlock()
    def forward(self, x, y):
        return self.conv(self.add(x, y))

a = torch.rand(3, 128, 128)
b = torch.rand(3, 128, 128)
model = multi_inp()

Using the following (the root module is skipped because its name is the empty string):

 for n, m in model.named_modules():
     if n:
         print(n, m)

I get:

conv Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
add AddBlock()
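If "named parent" means the module that owns a submodule, it can be read straight off the dotted qualified names that named_modules() yields. A minimal sketch, assuming names of the form torchvision's resnet18 produces; containment_parents is a hypothetical helper, and top-level children map to None (the root). This recovers only the ownership tree, not the data flow, so it cannot tell which layer feeds which.

```python
def containment_parents(names):
    """Map each dotted module name to the name of the module that contains it.

    Hypothetical helper: works purely on the qualified names, e.g. the
    `n` values from `model.named_modules()`.
    """
    parents = {}
    for name in names:
        if not name:  # skip the root module, whose name is ''
            continue
        head, _, _ = name.rpartition(".")
        parents[name] = head or None  # top-level children belong to the root
    return parents


# A few names in the style resnet18's named_modules() produces
names = ["", "conv1", "layer1", "layer1.0", "layer1.0.conv1", "layer1.0.bn1"]
print(containment_parents(names))
```

With a real model this would be called as `containment_parents(n for n, _ in model.named_modules())`.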

I would also like another column listing the named parent modules of each node:

module-name | parents | module
=========== | ======= | ======
conv        |  None   |  Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
add         | [conv]  |  AddBlock()

"…I would also like to get another column…" Just for clarity, could you type out the full output you expect for the code above?

Thanks, I've updated the question.

Please see if this helps:

# Skip the root module, whose name is ''
lstModules = list(model.named_modules())[1:]
strParent = ''  # name of the previously printed module

# Initial print statements (table header)
print('{0:<20}'.format('Module Name'), '|', '{0:<10}'.format('Parent'), '|', '{0:<10}'.format('Module'))
print('{0:<20}'.format('==========='), '|', '{0:<10}'.format('======'), '|', '{0:<10}'.format('======'))

# Loop through the modules, treating the previous module as the parent
for module_name, module in lstModules:
    print('{0:<20}'.format(module_name), '|', '{0:<10}'.format(strParent), '|', module)
    strParent = module_name

Thanks @KarthikR.
Your snippet works fine for hierarchical architectures where every child is listed directly below its parent.
I am looking for a way to list all parents that also works in more complex cases such as ResNets, where a module can have multiple parents and the connections are not in sequential order.

import torchvision.models as models
model = models.resnet18()

I think this would require inspecting how tensors actually flow through the network.
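One way to get at that flow, swapping in torch.fx symbolic tracing rather than running real tensors: every submodule call becomes a `call_module` node whose `target` is the module's qualified name, and each node lists its inputs in `all_input_nodes`. The sketch below walks back from each `call_module` node through non-module operations (such as the residual `+`) to the nearest upstream modules. `module_parents` and `LeafTracer` are hypothetical names, not a PyTorch API. Assumptions: the model is symbolically traceable (no data-dependent Python control flow); user-defined blocks like AddBlock must be declared leaves, otherwise fx traces through them; and a module called more than once (such as the shared relu in resnet18's BasicBlock) gets the union of its calls' parents. Note also that in multi_inp, add feeds conv, so by data flow the trace reports conv's parent as add.

```python
import torch.fx as fx
import torch.nn as nn


class LeafTracer(fx.Tracer):
    """Hypothetical tracer that also keeps the given user module types opaque."""

    def __init__(self, extra_leaves=()):
        super().__init__()
        self.extra_leaves = tuple(extra_leaves)

    def is_leaf_module(self, m, module_qualified_name):
        return isinstance(m, self.extra_leaves) or super().is_leaf_module(
            m, module_qualified_name
        )


def module_parents(model, extra_leaves=()):
    """Map each called submodule (qualified name) to the submodules whose
    outputs feed it; graph inputs are reported as 'input'."""
    graph = LeafTracer(extra_leaves).trace(model)
    memo = {}

    def producers(node):
        # Nearest upstream modules of `node`, looking through functional ops
        # (e.g. `+`, torch.cat, torch.flatten) that are not modules.
        if node not in memo:
            if node.op == "call_module":
                memo[node] = [node.target]
            elif node.op == "placeholder":
                memo[node] = ["input"]
            else:
                found = []
                for inp in node.all_input_nodes:
                    for name in producers(inp):
                        if name not in found:
                            found.append(name)
                memo[node] = found
        return memo[node]

    parents = {}
    for node in graph.nodes:
        if node.op != "call_module":
            continue
        # A module called several times accumulates parents under one entry.
        found = parents.setdefault(node.target, [])
        for inp in node.all_input_nodes:
            for name in producers(inp):
                if name not in found:
                    found.append(name)
    return parents


class AddBlock(nn.Module):
    def forward(self, x, y):
        return x + y


class MultiInp(nn.Module):  # same layout as multi_inp in the question
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 32, kernel_size=3)
        self.add = AddBlock()

    def forward(self, x, y):
        return self.conv(self.add(x, y))


class TinyResidual(nn.Module):
    # Hypothetical residual block: relu receives both branch and stem.
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.branch = nn.Conv2d(8, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        h = self.stem(x)
        return self.relu(self.branch(h) + h)


print(module_parents(MultiInp(), extra_leaves=(AddBlock,)))
# {'add': ['input'], 'conv': ['add']}
print(module_parents(TinyResidual()))
# {'stem': ['input'], 'branch': ['stem'], 'relu': ['branch', 'stem']}
```

Applied to `models.resnet18()`, the relu inside each BasicBlock should list both the bn2 branch and the identity (or downsample) branch among its parents, which is the multi-parent case described above.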