Using model functions in DistributedDataParallel

I’ve noticed something peculiar in DDP training, but cannot seem to find any related posts.

I have an ordinary ConvNet, whose forward() is defined in the following manner:

from torch.nn import Module

class MyModel(Module):
    ...
    def forward(self, x):
        mid = self.feature_extractor(x)
        return self.classifier(mid)

    def feature_extractor(self, x):
        # Some layers here
        ...

    def classifier(self, x):
        # Some layers here
        ...

In some cases, I want the mid-level features, and in others, I don’t want them at all. Also, sometimes I have mid-level features from elsewhere, and just want a forward pass through the classifier. Therefore, instead of adding a conditional to my forward method, I’ve split the forward into two sub-methods, so I can use the feature_extractor and classifier methods as needed.

It turns out, in DDP (maybe for DP as well?) there is a difference between training with:

  1. out = model(x), and
  2. mid = model.module.feature_extractor(x); out = model.module.classifier(mid)

In the case of (1), the model trains perfectly as expected. However, in the case of (2), the loss does not converge as efficiently as expected, and may not even converge to the same value as (1).

My question is: What would be the underlying mechanism that makes a difference between (1) and (2)? Could it be that (2) does not perform gradient synchronization across processes? Or perhaps nn.SyncBatchNorm is not working as expected?

Thanks

Yes, the second approach will not work with DDP, since you are calling the underlying module on the current device directly and bypassing the wrapper. DDP relies on its own forward method being run, which only happens when you call the wrapper itself, as in model(x). Calls through model.module skip that step, so DDP will not sync the gradients.
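
One way to check the gradient-synchronization hypothesis empirically is to compare gradients across ranks right after backward(). The helper below is only a sketch: the name grads_in_sync and the tolerance are made up for illustration, and it assumes every rank holds gradients for the same set of parameters (otherwise the collective calls would not line up).

```python
import torch
import torch.distributed as dist


def grads_in_sync(model, atol=1e-6):
    """Return True if every parameter gradient matches across all ranks.

    Hypothetical helper: call it on every rank, after loss.backward()
    and before optimizer.step().
    """
    # Without a process group there is only one copy of the gradients,
    # so they are trivially in sync.
    if not (dist.is_available() and dist.is_initialized()):
        return True
    for p in model.parameters():
        if p.grad is None:
            continue  # assumes the same parameters lack grads on every rank
        g_max = p.grad.detach().clone()
        g_min = p.grad.detach().clone()
        # Element-wise max and min over all ranks; they coincide only if
        # every rank holds the same gradient values.
        dist.all_reduce(g_max, op=dist.ReduceOp.MAX)
        dist.all_reduce(g_min, op=dist.ReduceOp.MIN)
        if not torch.allclose(g_max, g_min, atol=atol):
            return False
    return True
```

If the explanation above holds, this should return True after training step (1) and would be expected to return False after step (2), given that each rank sees different data.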


I see, thanks for the reply. Would this apply to DataParallel as well? (Perhaps gradients are not reduced to gpu:0?)

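Yes, nn.DataParallel also depends on its forward method, since that is where the data-parallel logic (scattering the inputs, replicating the module, running the replicas in parallel, and gathering the outputs) is implemented.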
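
Since both wrappers do their work in forward, one way to keep the two sub-methods and still go through the wrapper is to route every path through forward() with a keyword argument. This is a minimal sketch, not the original model: the layer sizes, the attribute names features and head, and the mode argument are all made up for illustration.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    """Toy stand-in for the ConvNet in the question (hypothetical layers)."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.head = nn.Linear(8, 10)

    def feature_extractor(self, x):
        return self.features(x)

    def classifier(self, mid):
        return self.head(mid)

    def forward(self, x, mode="full"):
        # Every path enters through forward(), so calling the DDP or
        # DataParallel wrapper as wrapper(x, mode=...) runs the wrapper's
        # own forward before reaching this code.
        if mode == "features":
            return self.feature_extractor(x)
        if mode == "classify":
            return self.classifier(x)
        if mode == "full":
            return self.classifier(self.feature_extractor(x))
        raise ValueError(f"unknown mode: {mode!r}")
```

With a wrapped model this would be called as model(x, mode="features") or model(mid, mode="classify") instead of model.module.feature_extractor(x). One caveat: if the loss in an iteration only depends on part of the parameters (for example only the classifier), DDP may need to be constructed with find_unused_parameters=True.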
