You could use the function apply() to recursively apply a function to the network and each of its sub-layers. Calling resnet.apply(weight_init_fun) will call weight_init_fun on every sub-layer (and on resnet itself), so make it a function that takes a torch.nn.Module, checks compatibility (for example, whether the module is a layer type you want to initialize) and changes its weights.
According to the documentation (https://pytorch.org/docs/stable/nn.html), this is a typical use case for this function.
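A minimal sketch of such an init function follows. It assumes you want Xavier-uniform weights and zero biases for Conv2d/Linear layers, and ones/zeros for BatchNorm2d; the choice of initializers and the small stand-in network are hypothetical, so swap in resnet and whatever scheme you actually need. The isinstance checks are the "compatibility" test: apply() hands every module to the function, including containers and activations that have no weights.

```python
import torch
import torch.nn as nn


def weight_init_fun(module: nn.Module) -> None:
    """Called once per module by Module.apply(); only touches layer types it recognizes."""
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        # Assumed scheme: Xavier-uniform weights, zero bias.
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm2d):
        # Affine BatchNorm: scale = 1, shift = 0.
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
    # Any other module (ReLU, Sequential, pooling, ...) is left unchanged.


# Hypothetical stand-in network; the same call works on e.g. torchvision's resnet18().
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

# apply() visits every submodule recursively (children first), then net itself,
# and returns net.
net.apply(weight_init_fun)

# Record which modules apply() visits, in order.
visited = []
net.apply(lambda m: visited.append(type(m).__name__))
```

The nn.init functions modify the tensors in place without tracking gradients, so no extra torch.no_grad() block is needed here.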