Changing an attribute in forward gives different (and incorrect) results when using DataParallel

I am facing an issue similar to this one:
Pytorch incorrect value of member variable when using Multi-gpu - Stack Overflow

A minimal example of this issue is the following:

## debug.py
import torch
import torch.nn as nn

class Conv2d(nn.Conv2d):
  def forward(self, input):
    # create foo on the same device as the input, one value per sample
    self.foo = torch.ones(input.shape[0], device=input.device)
    print('A: ', self.foo.size())
    return super().forward(input)

def main():
  m = Conv2d(4, 3, 2)
  m = nn.DataParallel(m).cuda()

  m.module(torch.ones(1, 4, 6, 6).cuda())   # this step is necessary for m to have 'foo' as a member

  for bs in range(2, 5): 
    m(torch.ones(bs, 4, 6, 6).cuda())
    print('B: ', m.module.foo.size())

if __name__ == '__main__':
  main()

If I print at ‘A’, the result is always correct every time I run a forward pass. However, if I read the attribute from outside the module, it is only correct when I use a single GPU. For example, running the above code on a single GPU gives

A: torch.Size([1])
A: torch.Size([2])
B: torch.Size([2])
A: torch.Size([3])
B: torch.Size([3])
A: torch.Size([4])
B: torch.Size([4])

while running it with multiple GPUs gives a different and incorrect result, such as

A: torch.Size([1])
A: torch.Size([1])
A: torch.Size([1])
B: torch.Size([1])
A: torch.Size([2])
A: torch.Size([1])
B: torch.Size([1])
A: torch.Size([2])
A: torch.Size([2])
B: torch.Size([1])

I think this is related to some synchronization problem, but I would like to understand the detailed reason for it and the correct way to set member variables on the module.

I am using Python 3.6.9, PyTorch 1.3.0, CUDA 10.0.130.

Thanks a lot.

nn.DataParallel will split the input batch in dim0 and send each chunk of shape [batch_size // nb_gpus] to its device.
Each device will get a model replica, which synchronizes the parameters and buffers, but not plain tensor attributes. The self.foo tensor will thus only be updated temporarily on the replica and will not be reduced back to the original model.
Assuming you are using 2 GPUs, the output corresponds to:

A: torch.Size([1]) # m.module(torch.ones(1, 4, 6, 6).cuda())

# DataParallel call with bs=2
A: torch.Size([1]) # device0
A: torch.Size([1]) # device1
B: torch.Size([1]) # original model

# DataParallel call with bs=3
A: torch.Size([2]) # device0
A: torch.Size([1]) # device1
B: torch.Size([1]) # original model

# DataParallel call with bs=4
A: torch.Size([2]) # device0
A: torch.Size([2]) # device1
B: torch.Size([1]) # original model
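
For reference, here is a rough, simplified sketch of what nn.DataParallel.forward does under the hood, written with the functional calls from torch.nn.parallel and assuming two GPUs; it shows that the replicas are separate module objects, so self.foo only ever lives on them:

import torch
import torch.nn as nn
from torch.nn.parallel import replicate, scatter, parallel_apply, gather
from debug import Conv2d  # the custom Conv2d defined above

m = nn.DataParallel(Conv2d(4, 3, 2)).cuda()
device_ids = [0, 1]
x = torch.ones(4, 4, 6, 6, device='cuda:0')

# 1. split the batch along dim 0 and move each chunk to its device
inputs = scatter(x, device_ids)                          # two chunks of shape [2, 4, 6, 6]

# 2. create one replica of the module per device;
#    the replicas are new module objects, not m.module itself
replicas = replicate(m.module, device_ids[:len(inputs)])

# 3. run each replica on its chunk; self.foo is set on the replicas only
outputs = parallel_apply(replicas, [(inp,) for inp in inputs])
print(replicas[0].foo.size(), replicas[1].foo.size())    # torch.Size([2]) torch.Size([2])

# 4. gather the outputs back to the output device; DataParallel then
#    discards the replicas together with their foo attributes
out = gather(outputs, target_device=0)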

Thanks a lot for your kind and quick reply. So do I have to register a buffer if I want to implement what I described, or is it not possible to implement? Thanks again.

You could register a buffer, but I don’t think that manipulating it in the forward would reflect the changes in the original model.module, since it would be unclear which value and shape (in the second iteration) to store.

Could you explain your use case a bit? Maybe there is another way to achieve what you need.

I am trying to implement an input-adaptive model (a.k.a. a dynamic model), like this one: https://arxiv.org/abs/2003.10401, where each input determines the structure of the model. I need to get the computational cost corresponding to that specific input for the optimization; the computational cost corresponds to foo here.

Unfortunately, I don’t think this is easily doable with the current nn.DataParallel approach, since, to the best of my knowledge, buffers (such as the running stats in batchnorm layers) won’t be synced by default.
You could either use the functional parallel calls and scatter and gather the buffer manually, or alternatively try DistributedDataParallel and reuse the code of SyncBatchNorm.
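
To illustrate the first suggestion, here is a rough, untested sketch that gathers the per-replica attribute manually via the functional parallel calls; data_parallel_with_foo is just a made-up helper name, Conv2d is the custom module from debug.py above, and two GPUs are assumed:

import torch
from torch.nn.parallel import replicate, scatter, parallel_apply, gather
from debug import Conv2d  # the custom Conv2d defined above

def data_parallel_with_foo(module, x, device_ids=(0, 1)):
    device_ids = list(device_ids)
    inputs = scatter(x, device_ids)
    replicas = replicate(module, device_ids[:len(inputs)])
    outputs = parallel_apply(replicas, [(inp,) for inp in inputs])

    # collect the per-replica attribute and concatenate it on the output device
    foo = gather([r.foo for r in replicas], target_device=device_ids[0])
    out = gather(outputs, target_device=device_ids[0])
    return out, foo

m = Conv2d(4, 3, 2).cuda()
out, foo = data_parallel_with_foo(m, torch.ones(4, 4, 6, 6, device='cuda:0'))
print(foo.size())  # torch.Size([4]): one entry per sample of the full batch

Since gather goes through an autograd function, gradients should in principle also flow back through the gathered attribute, although I have not verified this end to end.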

Thanks a lot for your kind reply. I will try that. If I manage to implement it, I will post a basic example for future reference.

Hi ptrblck, I still do not think a buffer is the solution, unless I know the batch size in advance and use it to declare the size of the buffer. After checking the code of the DataParallel module and the associated scatter, gather, replicate, and parallel_apply methods, it seems to me that DataParallel references the same module, so the attribute would be shared across devices, and there cannot be one copy of the attribute per device. So the only reasonable option seems to be a buffer, which causes the problem I mentioned above. Another concern is whether I can compute gradients involving this buffer together with the other parameters. An alternative is to return the attribute I want as part of the output; although simple, this is not a very good solution. I wonder if there are points I missed and a better solution to this. Thanks.

I solved it by returning multiple outputs, one of them being the attribute I desired. It works, although it is a bit ugly and not the elegant solution I expected. I use a dictionary for the output to indicate which entry is the original output tensor and which is the attribute. I am not sure yet whether the gradients will work correctly, but I did not find a reason why they would not. Using a tuple or list might not be possible, since I only change the layers I need the attribute from, such as Conv, and some interface layers, such as Sequential, while the other layers keep their traditional behavior (otherwise I would have to reimplement all the layer types I use so that they all return the desired attribute with some default value). Anyway, this is one possible solution, but probably not the best.
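
For future reference, here is a minimal sketch of this dict-output workaround (the class name CostConv2d and the keys 'out' and 'cost' are just placeholder names). It relies on the fact that DataParallel's gather handles dictionaries by gathering each entry along dim 0:

import torch
import torch.nn as nn

class CostConv2d(nn.Conv2d):
    def forward(self, input):
        out = super().forward(input)
        # placeholder per-sample "computational cost", one value per input sample
        cost = torch.ones(input.shape[0], device=input.device)
        return {'out': out, 'cost': cost}

def main():
    m = nn.DataParallel(CostConv2d(4, 3, 2)).cuda()
    result = m(torch.ones(4, 4, 6, 6).cuda())
    # gather concatenates each dict entry along dim 0,
    # so 'cost' covers the full batch, not just the last replica's chunk
    print(result['out'].size())   # torch.Size([4, 3, 5, 5])
    print(result['cost'].size())  # torch.Size([4])

if __name__ == '__main__':
    main()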