Buffer created with register_buffer is not updated with multiple GPUs

I am creating my own module extending Conv2d. Inside the constructor I register my own buffer, "my_param", which I use to store some information about the weights during forward. I then pass this buffer to F.conv2d in place of self.weight, as in a classic convolution.

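import torch
import torch.nn as nn
import torch.nn.functional as F

class MyConv(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # ... (rest of the constructor elided)
        # Buffer that will hold the transformed weights computed in forward
        self.register_buffer("my_param", torch.zeros(self.weight.shape))

    def forward(self, input):
        # special_op is my own function (not shown)
        self.my_param = special_op(self.weight)

        output = F.conv2d(input, self.my_param, self.bias,
                          self.stride, self.padding,
                          self.dilation, self.groups)
        return output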

When I use this class with 1 GPU (even with DataParallel), self.my_param correctly stores the output of special_op. With multiple GPUs, however, self.my_param remains a tensor of zeros.

Is there any workaround? What am I missing?

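As you observe, by default the buffers are local to each GPU. DataParallel replicates the module onto every device in each forward pass and discards the replicas afterwards, so assigning a new tensor to self.my_param only changes the replica. With a single device, DataParallel calls the original module directly, which is why it works there.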
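One workaround follows from the DataParallel documentation: the replica on device_ids[0] shares storage with the original module's parameters and buffers, so in-place updates made there are kept, whereas rebinding with self.my_param = ... only replaces the replica's attribute. The sketch below writes into the buffer with copy_ and convolves with the freshly computed tensor so gradients still reach self.weight. It is a minimal sketch under assumptions: special_op here is a hypothetical stand-in for the poster's function, and it has only been reasoned about for DataParallel, not DistributedDataParallel.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def special_op(weight: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the real special_op: sign of each weight
    # scaled by the mean absolute weight.
    return torch.sign(weight) * weight.abs().mean()


class MyConv(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Non-trainable state: saved in state_dict and moved by .to()/.cuda().
        self.register_buffer("my_param", torch.zeros_like(self.weight))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        w = special_op(self.weight)
        # Write into the existing buffer storage instead of rebinding the name.
        # Under DataParallel, the device_ids[0] replica shares this storage with
        # the original module, so the write outlives the replica.
        with torch.no_grad():
            self.my_param.copy_(w)
        # Convolve with w (not the buffer) so autograd still reaches self.weight.
        return F.conv2d(input, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
```

Since special_op depends only on the weights, which are identical on every replica, the value recorded from device 0 is the one you want. If the stored value depended on the input batch, only device 0's shard would be recorded.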

If you want to synchronize your buffer across GPUs, you need to provide the structure to do that yourself, which is a bit tricky in general.
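As one illustration of such structure, assuming one process per GPU with torch.distributed (the setting SyncBatchNorm targets) rather than DataParallel, a buffer can be averaged across ranks with an all_reduce. The helper name average_buffers is hypothetical; it is a sketch, and it deliberately does nothing when no process group is initialized so the same module also runs in a single process.

```python
import torch
import torch.distributed as dist
import torch.nn as nn


def average_buffers(module: nn.Module) -> None:
    """Average every floating-point buffer of `module` across all ranks.

    Hypothetical helper. It is a collective call: every rank must call it
    at the same point. It is a no-op when torch.distributed is unavailable
    or no process group has been initialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    world_size = dist.get_world_size()
    with torch.no_grad():
        for buf in module.buffers():
            # Skip integer buffers such as BatchNorm's num_batches_tracked.
            if buf.is_floating_point():
                dist.all_reduce(buf, op=dist.ReduceOp.SUM)  # in-place sum over ranks
                buf.div_(world_size)
```

Averaging is only one choice. DistributedDataParallel, for instance, broadcasts rank 0's buffers at the start of each forward by default (its broadcast_buffers flag), which is simpler when every rank should hold the same value anyway.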

SyncBatchNorm might serve as inspiration (note that its module also imports a companion autograd Function, likewise named SyncBatchNorm).

Best regards

Thomas