I am creating my own module extending Conv2d. Inside the constructor I register my own "my_param" buffer, which I use to store some information on the weights during forward. Then I pass the result of that operation to the F.conv2d function in place of self.weight, as in a classic convolution.
    def forward(self, input):
        self.my_param = special_op(self.weight)
        output = F.conv2d(input, self.my_param, self.bias,
                          self.stride, self.padding, self.dilation, self.groups)
        return output
When using this class with 1 GPU (even with DataParallel), self.my_param correctly stores the output of special_op. With multiple GPUs, however, self.my_param remains a tensor of zeros.
Is there any workaround? What am I missing?
As you observe, by default, the buffers are local to the GPU.
If you want to synchronize your buffer across GPUs, you need to provide the machinery to do that yourself, which is a bit tricky in general.
SyncBatchNorm might be a good source of inspiration (note that it delegates the actual cross-GPU synchronization to a custom autograd Function).
    # If buffers are not to be tracked, ensure that they won't be updated
    running_mean = (
        self.running_mean if not self.training or self.track_running_stats else None
    )
    running_var = (
        self.running_var if not self.training or self.track_running_stats else None
    )

    # Don't sync batchnorm stats in inference mode (model.eval()).
    need_sync = bn_training and self.training
    if need_sync:
        process_group = torch.distributed.group.WORLD
        if self.process_group:
            process_group = self.process_group
        world_size = torch.distributed.get_world_size(process_group)
        need_sync = world_size > 1

    # fallback to framework BN when synchronization is not necessary
    if not need_sync:
        return F.batch_norm(
            input, running_mean, running_var, self.weight, self.bias,
            bn_training, exponential_average_factor, self.eps)
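Following the same idea, here is a minimal sketch of how the custom Conv2d could synchronize its buffer explicitly. It assumes you run under DistributedDataParallel (one process per GPU) rather than DataParallel, and `special_op` is a hypothetical stand-in for your actual operation; the buffer is written in place and then averaged across ranks with all_reduce:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

class MyConv2d(nn.Conv2d):
    """Conv2d variant that records a per-weight statistic in a buffer."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Register as a buffer so it moves with .to()/.cuda() and is
        # included in the state_dict.
        self.register_buffer("my_param", torch.zeros_like(self.weight))

    @staticmethod
    def special_op(weight):
        # Hypothetical stand-in for the poster's operation on the weights.
        return weight.sign() * weight.abs().mean()

    def forward(self, input):
        w = self.special_op(self.weight)
        # Write in place so the registered buffer is updated,
        # not rebound to a fresh tensor.
        self.my_param.copy_(w.detach())
        # Under an initialized process group, average the buffer across
        # ranks so every process sees the same value.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(self.my_param, op=dist.ReduceOp.SUM)
            self.my_param /= dist.get_world_size()
        return F.conv2d(input, w, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
```

Note that this explains the zeros you are seeing: DataParallel replicates the module onto each GPU at every forward call and discards the replicas afterwards, so in-place writes to a replica's buffer never reach the original module. With DistributedDataParallel each process owns its own module instance, so the buffer write sticks and the all_reduce above keeps the ranks consistent.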