I am creating my own module extending Conv2d. Inside the constructor I register my own "my_param" buffer, which I use to store some information about the weights during forward. Then I pass this tensor to the F.conv2d function in place of self.weight, which a classic convolution would use.
class MyConv(nn.Conv2d):
    def __init__(...):
        ...
        self.register_buffer("my_param", torch.zeros(self.weight.shape))

    def forward(self, input):
        self.my_param = special_op(self.weight)
        output = F.conv2d(input, self.my_param, self.bias,
                          self.stride, self.padding,
                          self.dilation, self.groups)
        return output
When using this class on a single GPU (even wrapped in DataParallel), self.my_param correctly stores the output of special_op. With multiple GPUs, however, self.my_param remains a tensor of zeros.
Is there any workaround? What am I missing?
tom (Thomas V), October 20, 2021, 6:26pm:
As you observe, by default, buffers are local to each GPU.
If you want to synchronize your buffer across GPUs, you need to provide the machinery to do that yourself, which is a bit tricky in general.
SyncBatchNorm might be some inspiration (note that it also uses a dedicated SyncBatchNorm autograd Function). Here is the relevant part of its forward:
# If buffers are not to be tracked, ensure that they won't be updated
running_mean = (
    self.running_mean if not self.training or self.track_running_stats else None
)
running_var = (
    self.running_var if not self.training or self.track_running_stats else None
)

# Don't sync batchnorm stats in inference mode (model.eval()).
need_sync = bn_training and self.training
if need_sync:
    process_group = torch.distributed.group.WORLD
    if self.process_group:
        process_group = self.process_group
    world_size = torch.distributed.get_world_size(process_group)
    need_sync = world_size > 1

# fallback to framework BN when synchronization is not necessary
if not need_sync:
    return F.batch_norm(
        input,
        ...  # rest of the call elided in this excerpt
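Concretely, if you run one process per GPU with DistributedDataParallel, the synchronization can be a single all_reduce in forward. This is only a minimal sketch, not the full SyncBatchNorm machinery: special_op here is a hypothetical stand-in for the op in the question, and the statistic is synchronized on a detached copy because torch.distributed.all_reduce is not autograd-aware.

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

def special_op(weight):
    # stand-in for the op from the question; assumed to return a
    # tensor with the same shape as the weight
    return weight.abs()

class MyConv(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_buffer("my_param", torch.zeros_like(self.weight))

    def forward(self, input):
        w = special_op(self.weight)  # stays in the autograd graph for the conv
        # Record the value in the buffer. dist.all_reduce is not
        # autograd-aware, so synchronize a detached copy only.
        stat = w.detach().clone()
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            dist.all_reduce(stat)            # sum over all processes
            stat /= dist.get_world_size()    # -> mean
        self.my_param = stat  # assigning a tensor to a buffer name updates the buffer
        return F.conv2d(input, w, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

Two caveats: the averaging only changes anything if your statistic depends on per-rank data (as in SyncBatchNorm); for a pure function of the weights the values already agree under DDP. And this only helps in the one-process-per-GPU setting: with nn.DataParallel the module is re-replicated onto every device at each forward, so the assignment happens on throwaway replicas and the buffer on the original module keeps its initial zeros, which matches what you observe.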
Best regards
Thomas