I’m trying to figure out how I can periodically re-initialize certain layers of my network (for example, the last m layers of a CNN such as a ResNet) when using DistributedDataParallel.
In the non-distributed setting, I think I could achieve this by iterating over the modules. For example:
import torch.nn as nn

# Re-initialize every Conv2d/Linear module from index m onwards.
for i, module in enumerate(algorithm.model.modules()):
    if i >= m:
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(
                module.weight, mode="fan_out", nonlinearity="leaky_relu"
            )
        elif isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            nn.init.zeros_(module.bias)
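For what it’s worth, to pick m I’ve just been inspecting the flattened list that .modules() iterates over, roughly like the snippet below (using a torchvision ResNet purely as an example), so I know which indices correspond to the layers I want to reset:

import torchvision

model = torchvision.models.resnet18()

# Print the flattened module list, to see which indices the
# "last m layers" actually correspond to.
for i, (name, module) in enumerate(model.named_modules()):
    print(i, name, type(module).__name__)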
However, when I try to integrate this into my code, which uses DistributedDataParallel, iterating over the modules raises the error:
TypeError: cannot unpack non-iterable DistributedDataParallel object
Looking more closely at the iterable returned by model.modules() when the model is wrapped by DistributedDataParallel, I can see that the wrapper adds an extra level of nesting (the DistributedDataParallel object itself shows up first, with the original network underneath it), which shifts the module indices.
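From what I can tell, the original network ends up under the wrapper’s .module attribute, so presumably I’d need to unwrap it before iterating, along the lines of the (untested) sketch below, but I’m not sure whether this is the right way to handle per-layer re-initialization under DDP:

from torch.nn.parallel import DistributedDataParallel

def unwrap(model):
    # DDP keeps the original network under .module; return that if the
    # model is wrapped, otherwise return the model unchanged.
    return model.module if isinstance(model, DistributedDataParallel) else model

# ...and then iterate over unwrap(algorithm.model).modules() exactly as in
# the non-distributed loop above.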
Has anyone tried something similar? If so, do you have any tips on the best way to go about this?
Thanks in advance!