How to access each layer of torchvision ResNet

You can use apply() to run a function recursively on the network and every sub-layer. Calling resnet.apply(weight_init_fun) calls weight_init_fun on each sub-module and finally on resnet itself, so write it as a function that takes a torch.nn.Module, checks whether the module is of a compatible type, and changes its weights if it is.

According to the documentation (https://pytorch.org/docs/stable/nn.html), initializing a model's parameters is a typical use case for this function.
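Here is a minimal sketch of what such a function could look like. The layer types it checks and the init schemes it picks (Kaiming for conv/linear, ones and zeros for batch norm) are my own assumptions for illustration, not something the question prescribes, and the small nn.Sequential is only a stand-in so the example does not depend on torchvision.

```python
import torch
import torch.nn as nn


def weight_init_fun(module: nn.Module) -> None:
    """Initialize one sub-layer in place; modules of other types are left alone."""
    # Compatibility check: only touch layer types we know how to initialize.
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        # Assumed choice of scheme; swap in whatever init you need.
        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)


# Stand-in network with a nested block, so the recursion is visible.
model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, bias=False), nn.BatchNorm2d(8), nn.ReLU()),
    nn.Flatten(),
    nn.Linear(8 * 30 * 30, 10),
)

# Record the type of every module that apply() reaches (children first, then the parent).
visited = []
model.apply(lambda m: visited.append(type(m).__name__))

# Run the actual initialization over all sub-layers.
model.apply(weight_init_fun)
```

With torchvision installed, the same call should work on a ResNet, e.g. resnet = torchvision.models.resnet18() followed by resnet.apply(weight_init_fun). Containers such as nn.Sequential are passed to the function too, which is why the type check matters.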
