Is it possible to iterate through all model parameters AND see which module type they belong to?

I want to:

  1. Iterate through model parameters. This is easy: model.parameters()

  2. See which module type each parameter belongs to, e.g., type(param.MODULE) == nn.Conv2d

However, I’m not sure how to get the param’s module… does model.named_parameters() return something I can use here to do this?
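As a quick look at what `named_parameters()` yields, here is a minimal sketch on a small, hypothetical two-layer `nn.Sequential` (the model exists only for illustration). Each name is a dot-separated path: the leading parts name the submodules and the last part is the parameter attribute. Running statistics of `BatchNorm2d` are buffers, not parameters, so they do not appear.

```python
import torch.nn as nn

# Hypothetical toy model: child "0" is a Conv2d, child "1" is a BatchNorm2d
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))

names = []
for name, param in model.named_parameters():
    # name has the form "<module path>.<attribute>", e.g. "0.weight"
    print(name, tuple(param.shape))
    names.append(name)
```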

You could use the returned name to get the module. E.g., this code prints, for each parameter, the module registered at the highest level of the hierarchy:

```python
import torchvision

model = torchvision.models.resnet50()
for name, param in model.named_parameters():
    # e.g. "layer1.0.conv1.weight" -> "layer1"; print that top-level module
    print(getattr(model, name.split('.')[0]))
```
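The snippet above only resolves the first component of each name, so for deeper parameters (e.g. inside `layer1`) it prints the enclosing container rather than the layer that owns the weight. One way to reach the owning module directly, sketched here on a hypothetical nested model: walk `named_modules()` and, for each module, list only the parameters it registers itself via `named_parameters(recurse=False)`.

```python
import torch.nn as nn

# Hypothetical nested model standing in for something like resnet50
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.Sequential(nn.Conv2d(8, 8, kernel_size=3, bias=False), nn.BatchNorm2d(8)),
)

owners = {}
for mod_name, module in model.named_modules():
    # recurse=False: only parameters registered directly on this module
    for p_name, param in module.named_parameters(recurse=False):
        # the root module has the empty name "", so avoid a leading dot
        full_name = f"{mod_name}.{p_name}" if mod_name else p_name
        owners[full_name] = type(module)
        print(full_name, type(module).__name__)
```

Because each parameter is visited from the module that actually holds it, a check like `owners[name] is nn.Conv2d` works at any depth.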

I’m not sure this would print something like nn.BatchNorm2d, but it suits my purposes, thanks.
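On the `nn.BatchNorm2d` question: the top-level lookup only returns it for top-level layers such as resnet50's `bn1`; deeper batch-norm layers show up as their enclosing container. A minimal sketch that keeps the `named_parameters()` loop but resolves the full module path of each name, using `dict(model.named_modules())` as the lookup table (the root module is stored under the empty name `""`). The model is hypothetical; newer PyTorch releases also offer `nn.Module.get_submodule`, which should perform the same path lookup without building the dict.

```python
import torch.nn as nn

# Hypothetical nested model
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.Sequential(nn.Conv2d(8, 8, kernel_size=3), nn.BatchNorm2d(8)),
)

# Map every module path ("", "0", "1", "1.0", "1.1") to its module
modules = dict(model.named_modules())

bn_params = []
for name, param in model.named_parameters():
    # "1.1.weight" -> "1.1"; a bare "weight" on the root maps to ""
    module_path = name.rpartition('.')[0]
    if isinstance(modules[module_path], nn.BatchNorm2d):
        bn_params.append(name)

print(bn_params)
```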