Is it ok to use methods other than 'forward' in DDP?

I am wondering if there is anything special about the forward method that DDP relies on. Sometimes it is convenient to have a model call different methods during training for different tasks, to avoid bundling everything into the forward method with more complicated logic. For example, say I have the following module:

import torch.nn as nn

class Model(nn.Module):
  def forward_pretrain(self, x):
    ...

  def forward_finetune(self, x):
    ...

Is it ok to wrap this in DistributedDataParallel and call model.module.forward_pretrain(x) instead of bundling everything into the forward function? Am I ruining the functionality of DDP in any way by calling the underlying module?

I would not use custom forward methods, as these would skip hooks (the same would happen if you manually called model.forward(x) instead of model(x)).
E.g. this code uses forward_pre_hooks to copy to a lower precision and could easily break.
Also, here self.module is called, which will internally call into __call__ and then forward. Using your custom forward methods would bypass this again.
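
As a quick illustration (a toy example of my own, not the linked DDP code), hooks registered with register_forward_pre_hook only fire when the module is invoked through __call__, i.e. m(x), not when you call a method directly:

import torch
import torch.nn as nn

m = nn.Linear(4, 4)
m.register_forward_pre_hook(lambda module, inputs: print("pre-hook fired"))

x = torch.randn(2, 4)
m(x)          # goes through __call__, so the pre-hook prints
m.forward(x)  # bypasses __call__, so the pre-hook is silently skipped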

However, it should be possible to call your custom forward methods from the actual forward method, in case this keeps your code cleaner.
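
For example, a minimal sketch of this dispatch pattern (the layers, the task flag, and the DDP setup below are placeholders I made up, not a full training script):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 16)   # placeholder layers
        self.head = nn.Linear(16, 4)

    def forward_pretrain(self, x):
        return self.backbone(x)

    def forward_finetune(self, x):
        return self.head(self.backbone(x))

    def forward(self, x, task: str = "pretrain"):
        # Dispatch inside forward, so DDP's __call__ and its hooks still run.
        if task == "pretrain":
            return self.forward_pretrain(x)
        return self.forward_finetune(x)

# After wrapping, always go through the wrapper itself:
# model = DDP(Model().to(device), device_ids=[local_rank])
# output = model(batch, task="finetune")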


I think, if you have a model A wrapped in DDP:

class A(nn.Module):
    ...

model = A()
model = DDP(model)

Then you must call output = model(input) at least once, no matter what this forward function does, or it may cause some problems.
For example, I have the following code:

# A ResNet18 demo for MNIST:
import torch.nn as nn
import torchvision
from torchvision.models import ResNet18_Weights

class ResNet18(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()

        # is_main_process() is a user-defined helper, presumably True only on rank 0.
        if is_main_process():
            self.model = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
        else:
            self.model = torchvision.models.resnet18(weights=None)

        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        # return self.model(x)
        return x.detach()    # this forward does no computation at all

    def fake_forward(self, x):
        return self.model(x)

def build(config: dict):
    return ResNet18(
        num_classes=10
    )

And in my training function, if I call:

outputs = model.module.fake_forward(images.to(device))

then the parameters of my model become inconsistent across the different GPUs (nodes).

But if I call:

images = model(images)    # although it does nothing!
# or images = model.forward(images)
outputs = model.module.fake_forward(images.to(device))

The parameters will be consistent.

Thus, I think some code deep inside the forward() call is the key to DDP parameter sync.
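
If you want to check this yourself, here is a rough helper I would use (my own sketch, assuming torch.distributed is already initialized and the model fits on every rank) that broadcasts rank 0's parameters and compares them with each local copy:

import torch
import torch.distributed as dist

def check_param_consistency(model: torch.nn.Module) -> bool:
    """Return True if this rank's parameters match rank 0's."""
    consistent = True
    for name, param in model.named_parameters():
        reference = param.detach().clone()
        dist.broadcast(reference, src=0)   # overwrite with rank 0's copy
        if not torch.allclose(param.detach(), reference):
            print(f"rank {dist.get_rank()}: parameter {name!r} diverged")
            consistent = False
    return consistent

# e.g. after a few training steps:
# assert check_param_consistency(model.module)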
