I want to:
- Iterate through model parameters (this part is easy).
- See what module type each parameter belongs to, e.g., `type(param.MODULE) == nn.Conv2d`
However, I'm not sure how to get the param's module. Does `model.named_parameters()` return something I can use here to do this?
You could use the returned name to get the module. E.g., this code would print the registered modules at the highest hierarchy level:
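```python
import torchvision

model = torchvision.models.resnet50()
for name, param in model.named_parameters():
    # name looks like "layer1.0.conv1.weight"; its first component is the top-level child
    print(getattr(model, name.split('.')[0]))  # print "highest" module
```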
I'm not sure this would print something like `nn.BatchNorm2d`, but it suits my purposes, thanks.