As the question states, I have loaded the pretrained ResNet101 (model = models.resnet50(pretrained=True)) model in PyTorch and would like to know how to selectively modify the weights of layers and test the model.
Let's say for simplicity that there are only 5 bottlenecks b1, b2, b3, b4, b5 in the model, followed by one FC layer fc1. I would like to keep the weights for the layers in b1 (the first bottleneck) while setting the weights of every layer in the bottlenecks after it to 0, so I can see how the model performs using just the b1 weights.
Here is a good visualization of the ResNet architecture: Resnet50
And here is what b1 would look like starting at the pool1 layer all the way up to res2a:
You can do something like this to target individual modules:
import torchvision.models as models
resnet = models.resnet101(pretrained=True)
for (name, layer) in resnet._modules.items():
    # iterate over the outer layers
    print((name, layer))
resnet._modules['layer1'][0]._modules['bn1'].weight.data.zero_()
The ._modules attribute exposes layers as an ordered dict (it's also private, so maybe this is not future-proof, but the PT devs have referenced it on other threads), so you can index by key normally. Hence, printing them in the loop above helps to navigate complex models like ResNet. It's a little confusing because the nn.Sequential (keyed as 'layer1') is indexed like a list with integers.
You can use nested versions of that loop to get at things if you are changing a lot of stuff. If you are just changing a few, it might actually be less confusing to do them individually as one-liners. You can also set .requires_grad = False on the parameters of individual layers if you just don't want them to be updated during training.
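And, of course, you could also just delete the module (though the block's forward() still refers to bn1, so a forward pass will fail afterwards unless something else takes its place):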
del(resnet._modules['layer1'][0]._modules['bn1'])
You could also just duplicate the resnet.py file from pytorch/vision and make your own version of ResNet. I have had to do this to use larger/smaller images. It is sometimes easier for experimentation, as you can save the model file with
your_resnet.__class__.__name__
in the filename and name it something like ‘ResNet_noBN5’. But whatever works for you.