Is it ok to use methods other than 'forward' in DDP?

I am wondering if there is anything special about the forward method which DDP must use. Sometimes it is convenient to have a model call different methods during training for different tasks to avoid bundling everything up in the forward method with more complicated logic. For example, say I have the following module,

class Model(nn.Module):
  def forward_pretrain(self, x):
    ...

  def forward_finetune(self, x):
    ...

Is it ok to wrap this in DistributedDataParallel and call model.module.forward_pretrain(x) instead of bundling everything into the forward function? Am I ruining the functionality of DDP in any way by calling the underlying module directly?

I would not use custom forward methods, as they would skip hooks (the same happens if you manually call model.forward(x) instead of model(x)).
E.g. this code uses forward_pre_hooks to copy inputs to lower precision and could easily break.
Also, here self.module is called, which internally calls into __call__ and then forward. Calling your custom forward methods directly would bypass this as well.
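To illustrate the hook-skipping issue, here is a minimal sketch (a plain nn.Linear stands in for your model): a forward_pre_hook fires when the module is invoked via model(x), but not when forward is called directly, and the same applies to any custom method.

```python
import torch
import torch.nn as nn

calls = []

def pre_hook(module, args):
    # Record every time the pre-forward hook actually fires.
    calls.append("pre_hook")

model = nn.Linear(4, 4)
model.register_forward_pre_hook(pre_hook)

x = torch.randn(2, 4)
model(x)          # goes through __call__, so the hook fires
model.forward(x)  # bypasses __call__, so the hook is skipped
print(calls)
```

DDP relies on the same __call__ machinery, so any method invoked without it (including model.module.forward_pretrain(x)) sidesteps the hooks in the same way.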

However, it should be possible to call your custom forward methods from the actual forward method, in case this keeps your code cleaner.
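A minimal sketch of that pattern, using a hypothetical task argument and layer names: the task-specific methods stay separate, but forward remains the single entry point, so wrapping the model in DDP and calling ddp_model(x, task=...) still goes through __call__ and all registered hooks.

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical heads just to make the sketch runnable.
        self.pretrain_head = nn.Linear(8, 8)
        self.finetune_head = nn.Linear(8, 2)

    def forward_pretrain(self, x):
        return self.pretrain_head(x)

    def forward_finetune(self, x):
        return self.finetune_head(x)

    def forward(self, x, task="pretrain"):
        # Dispatch inside forward so DDP's wrapping is not bypassed.
        if task == "pretrain":
            return self.forward_pretrain(x)
        return self.forward_finetune(x)

model = Model()  # in real training: DistributedDataParallel(Model().to(device))
x = torch.randn(4, 8)
out_pre = model(x, task="pretrain")
out_ft = model(x, task="finetune")
print(out_pre.shape, out_ft.shape)
```

Note that DDP forwards extra positional and keyword arguments to the wrapped module's forward, so the task flag can be passed through the DDP wrapper directly.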
