I want to:
- Iterate through the model parameters. This is easily done with:
model.parameters()
And:
- See what module type they belong to, e.g.,
type(param.MODULE) == nn.Conv2d
However, I’m not sure how to get a parameter’s module. Does model.named_parameters() return something I can use for this?
You could use the returned name to get the module. E.g., this code would print, for each parameter, the module registered at the highest hierarchy level:
import torchvision

model = torchvision.models.resnet50()
for name, param in model.named_parameters():
    print(name)
    print(getattr(model, name.split('.')[0]))  # print the "highest" (top-level) module
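If you need the module that directly owns each parameter (e.g. a BatchNorm2d nested inside a block) rather than the top-level one, the full name can be split at its last dot and resolved with nn.Module.get_submodule (available in PyTorch 1.9 and later). A minimal sketch, assuming a small hypothetical nn.Sequential instead of resnet50 and a hypothetical helper called owning_module:

```python
import torch.nn as nn

# Hypothetical small model standing in for resnet50, to keep the example light
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 4, kernel_size=1), nn.BatchNorm2d(4)),
)


def owning_module(model, param_name):
    """Hypothetical helper: return the module that directly registers param_name."""
    # "3.1.weight" -> module path "3.1", attribute name "weight"
    module_path, _, _ = param_name.rpartition(".")
    # get_submodule("") returns the model itself (parameters registered on the root)
    return model.get_submodule(module_path)


for name, param in model.named_parameters():
    # Prints e.g. "3.1.weight BatchNorm2d"
    print(name, type(owning_module(model, name)).__name__)
```

With this, the check from the question becomes isinstance(owning_module(model, name), nn.Conv2d).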
I’m not sure this would print something like nn.BatchNorm2d, but it suits my purposes, thanks.
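Going the other way round also works: iterate over model.named_modules(), keep the modules of the type you care about, and ask each one for only the parameters it registers itself with named_parameters(recurse=False). A minimal sketch, assuming a hypothetical toy model and a hypothetical helper params_of_type; it uses isinstance, which (unlike type(...) == nn.Conv2d) also matches subclasses:

```python
import torch.nn as nn

# Hypothetical toy model for illustration
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Linear(8, 2),
)


def params_of_type(model, module_type):
    """Hypothetical helper: yield (qualified_name, param) for parameters
    owned directly by modules that are instances of module_type."""
    for module_name, module in model.named_modules():
        if isinstance(module, module_type):
            # recurse=False: only parameters registered on this module itself
            for param_name, param in module.named_parameters(recurse=False):
                full_name = f"{module_name}.{param_name}" if module_name else param_name
                yield full_name, param


conv_params = dict(params_of_type(model, nn.Conv2d))
bn_params = dict(params_of_type(model, nn.BatchNorm2d))
print(list(conv_params))  # ['0.weight']
print(list(bn_params))    # ['1.weight', '1.bias']
```

This is handy when, for example, different parameter groups should get different optimizer settings depending on the layer type.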