You could use the apply() method to recursively apply a function to the network and each of its sub-layers. Calling resnet.apply(weight_init_fun) will call weight_init_fun on every sub-layer (and on the network itself), so make weight_init_fun a function that takes a torch.nn.Module, checks whether it is a layer type you want to modify, and if so changes its weights.
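A minimal sketch of such a function follows. The specific choices are illustrative assumptions, not part of the answer above: Kaiming-normal weights with zero biases for Conv2d and Linear layers, and weight 1 / bias 0 for BatchNorm2d. The small nn.Sequential model is a stand-in so the snippet runs without torchvision. The isinstance checks do the "check compatibility" step, so containers, ReLU, pooling and so on are simply skipped.

```python
import torch
import torch.nn as nn


def weight_init_fun(m: nn.Module) -> None:
    """Called by Module.apply() once for every submodule and for the root module."""
    # Only touch layer types we know how to initialise; everything else is skipped.
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # Illustrative choice: Kaiming (He) normal init, suited to ReLU networks.
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        # Identity affine transform: scale 1, shift 0.
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)


# Small stand-in model; a torchvision ResNet is handled the same way.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

# apply() recurses into every submodule and returns the module itself.
model.apply(weight_init_fun)
```

With torchvision installed, the same call works on a real network, e.g. resnet = torchvision.models.resnet18() followed by resnet.apply(weight_init_fun). Because apply() already handles the recursion, weight_init_fun only ever needs to reason about a single module at a time.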