Is it ok to use methods other than 'forward' in DDP?

I think that if you wrap a model A in DDP like this:

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class A(nn.Module):
    ...

model = A()
model = DDP(model)  # assumes the process group is already initialized

Then you must call output = model(input) at least once per iteration, no matter what the forward function actually computes. Otherwise it may cause problems.
For example, I have the following code:

# A ResNet18 demo for MNIST:
import torch.nn as nn
import torchvision
from torchvision.models import ResNet18_Weights

class ResNet18(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()

        # is_main_process() is my own helper (definition omitted):
        # only rank 0 downloads the pretrained weights.
        if is_main_process():
            self.model = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
        else:
            self.model = torchvision.models.resnet18(weights=None)

        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        # return self.model(x)
        return x.detach()    # this forward does no computation at all

    def fake_forward(self, x):
        return self.model(x)

def build(config: dict):
    return ResNet18(
        num_classes=10
    )

And in my training function, if I call:

outputs = model.module.fake_forward(images.to(device))

then the parameters of my model become inconsistent across the different GPUs (ranks).
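
(This is how I check it: a quick sketch that compares a parameter checksum across ranks. The helper below is illustrative, not my actual code.)

import torch
import torch.distributed as dist
import torch.nn as nn

def check_param_consistency(model: nn.Module) -> bool:
    # Cheap per-rank fingerprint: the sum of all parameters.
    # Assumes the parameters live on this rank's own device, as under DDP.
    local = torch.stack([p.detach().float().sum() for p in model.parameters()]).sum()
    fingerprints = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(fingerprints, local)
    # Every rank should hold an identical replica, hence identical sums.
    return all(torch.allclose(fingerprints[0], f) for f in fingerprints)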

But if I:

images = model(images)    # although it does nothing!
# or: images = model.forward(images)
outputs = model.module.fake_forward(images.to(device))

then the parameters stay consistent.

Thus, I think some code deep in DDP's own forward() path (in the wrapper's call, not in my module's forward body) is the key to DDP parameter sync.
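
If that's right, a workaround would be to dispatch every entry point through forward(), so that the DDP wrapper's call path always runs. A sketch under that assumption (the real flag is just an illustrative name, not from my code):

class ResNet18(nn.Module):
    # __init__ unchanged from above ...

    def forward(self, x, real: bool = False):
        if real:
            return self.model(x)   # the computation that used to live in fake_forward
        return x.detach()          # the old no-op path

# In the training step, always call through the DDP wrapper:
outputs = model(images.to(device), real=True)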
