How to access each layer of torchvision ResNet

I am using the ResNet18 model from torchvision.
I would like to access each layer of the network and assign new weights to it.

Currently I am using something like this:

for name, layer in resnet._modules.items():
    # assign new weights to layer.weight here
    pass

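However, this only visits the top-level modules, so it never reaches the layers nested inside layer1, layer2, and so on, or inside their blocks. We could recurse manually until we reach the most basic layers and assign the new weights there.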

But is there a simpler way to do this?

You could use the apply() method to recursively apply a function to the network and each of its sub-layers. Calling resnet.apply(weight_init_fun) calls weight_init_fun on every submodule, and on resnet itself, so write it as a function that takes a torch.nn.Module, checks whether the module is of a type it should handle, and changes its weights.

According to the documentation (https://pytorch.org/docs/stable/nn.html), initializing a model's parameters is a typical use case for this method.
