I am wondering whether there is anything special about the forward method that DDP relies on. Sometimes it is convenient for a model to expose different methods for different training tasks, rather than bundling everything into forward with more complicated branching logic. For example, say I have the following module:
import torch.nn as nn

class Model(nn.Module):
    def forward_pretrain(self, x):
        ...

    def forward_finetune(self, x):
        ...
Is it OK to wrap this in DistributedDataParallel and call model.module.forward_pretrain(x) instead of bundling everything into the forward function? Am I breaking DDP's functionality in any way by calling the underlying module directly?
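
Concretely, here is a minimal sketch of the usage I have in mind (assuming the process group is already initialized; device, local_rank, x, target, and criterion are placeholders, not part of the module above):

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# assumes torch.distributed.init_process_group(...) has already run,
# and that device / local_rank come from the launcher (placeholders here)
model = Model().to(device)
ddp_model = DDP(model, device_ids=[local_rank])

# the pattern in question: bypassing the DDP wrapper's forward() via .module
out = ddp_model.module.forward_pretrain(x)  # x is a placeholder input batch
loss = criterion(out, target)               # criterion, target are placeholders
loss.backward()                             # does gradient synchronization still happen here?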