Understanding behavior of torch.nn.DataParallel?

Hi everyone, I am trying to understand the behavior of torch.nn.DataParallel. An example code portion is given below for reference. Let's say I am using a batch size of 8 and two GPUs, so each GPU processes 4 data samples. My questions are:

  1. While updating the running means for batch normalization, does this module update the mean on the original model by considering the whole batch (all 8 samples), or only the samples on a specific device? In other words, if GPU:0 estimates a mean 'a' for its batch of 4 data samples and GPU:1 estimates another mean 'b', does PyTorch update the model's batch normalization statistics by averaging 'a' and 'b', or do the two devices update the batch normalization mean independently of each other? I looked at the documentation and source code but did not understand.

  2. I saw a comment where Soumith mentioned 'if you notice the examples, DataParallel is not applied to the entire network + loss. It is only applied to part of the network.' (How to use DataParallel in backward?). How does PyTorch's nn.DataParallel decide which part of the network to send to the GPUs?

  3. The documentation says 'The parallelized module must have its parameters and buffers on device_ids[0] before running this DataParallel module.' I was copying the model 'net' to the GPU device after applying nn.DataParallel to 'net' (as shown below), and the model trains fine. I do not understand why it is compulsory to send the model to the GPU before applying nn.DataParallel (as stated in the source code)?

net = nn.DataParallel(net)
net = net.to(device)
if torch.cuda.is_available():
    net.cuda()
    softMax.cuda()   # softMax and CE_loss are defined earlier in the script
    CE_loss.cuda()
  
  1. nn.DataParallel would update the batchnorm stats on the default device, as seen in this code snippet:
import torch
import torch.nn as nn

# momentum=1.0 makes the running stats equal to the stats of the latest batch
bn = nn.BatchNorm2d(3, momentum=1.0).cuda()

# two chunks with clearly different statistics: a ~ N(7, 12), b ~ N(-9, 5)
a = torch.randn(16, 3, 224, 224) * 12 + 7
b = torch.randn(16, 3, 224, 224) * 5 - 9
x = torch.cat((a, b), dim=0)

model = nn.DataParallel(bn, device_ids=[0, 1])
out = model(x)

# the running stats match only `a`, i.e. the chunk processed on cuda:0
print(model.module.running_mean)
> tensor([6.9984, 7.0203, 7.0044], device='cuda:0')
print(torch.sqrt(model.module.running_var))
> tensor([11.9902, 12.0097, 11.9818], device='cuda:0')

If you are using nn.DistributedDataParallel (which would also be faster than nn.DataParallel), you could use SyncBatchNorm to synchronize these running stats.
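In case it helps, here is a rough sketch of that approach. The model, the layer sizes, and the assumption that the script is launched with torchrun (one process per GPU) are just placeholders for illustration:

import os
import torch
import torch.nn as nn
import torch.distributed as dist

# Assumes a torchrun launch, which sets LOCAL_RANK in the environment.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = nn.Sequential(            # stand-in for the real model
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)
# Replace every BatchNorm layer with SyncBatchNorm so the running stats
# are synchronized across all processes/GPUs.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda()
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

out = model(torch.randn(8, 3, 32, 32, device=f"cuda:{local_rank}"))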

  2. You are directly passing the module to nn.DataParallel, so the whole module is executed on multiple devices. If you only want to pass a submodule to it, you could use:
model = MyModel()
model.submodule = nn.DataParallel(model.submodule)
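
For illustration, a minimal sketch of that pattern; MyModel and its features/classifier submodules are made up for this example:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # only `features` will be replicated across the GPUs
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(16 * 32 * 32, 10)

    def forward(self, x):
        x = self.features(x)       # runs data-parallel after wrapping
        x = x.flatten(1)
        return self.classifier(x)  # runs on the default device only

model = MyModel().cuda()
model.features = nn.DataParallel(model.features)  # wrap only the submodule

out = model(torch.randn(8, 3, 32, 32).cuda())
print(out.shape)  # torch.Size([8, 10])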
  3. Transferring the parameters to the device after the nn.DataParallel creation should also work. I think the documentation just means that you have to transfer the model to the default device at some point before executing the forward pass.
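
E.g. both of these orderings should end up in the same state before the forward pass (net here is just a small placeholder module):

import torch
import torch.nn as nn

device = torch.device("cuda:0")

# order used in the question: wrap first, then move to the default device
net = nn.Sequential(nn.Linear(10, 10))   # placeholder for the real `net`
net = nn.DataParallel(net)
net = net.to(device)                     # parameters/buffers now on cuda:0

# order suggested by the docs: move first, then wrap
net2 = nn.Sequential(nn.Linear(10, 10)).to(device)
net2 = nn.DataParallel(net2)

# either way, the module sits on device_ids[0] before the forward pass
out = net(torch.randn(8, 10, device=device))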

thanks, I understand now