Calling all_reduce in forward of Module

I have a question regarding the use of collective functions such as all_reduce().

I am working on a layer in which I would like to synchronize a value across processes. I have seen an implementation of synchronized batch norm that essentially does what I am looking for. In that implementation, all_reduce is called from forward, but there is also an autograd function that defines the backward behavior.

Is that approach necessary for a Module that does not have trainable parameters? Can I just call all_reduce in the forward method of a Module, or do I need to define it in an autograd function?

Btw, the layer I’m working on looks like this:

import torch
from torch.nn import Module


class BatchStandardDeviation(Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        batch_size, _, height, width = x.size()
        out = x - x.mean(dim=0, keepdim=True)  # Shape: B, C, H, W
        out = torch.sqrt(out.pow(2.0).mean(dim=0, keepdim=False) + 1e-8)  # Shape: C, H, W
        out = out.mean().view(1, 1, 1, 1)  # Shape: 1, 1, 1, 1
        out = out.repeat(batch_size, 1, height, width)  # Shape: B, 1, H, W
        return torch.cat([x, out], dim=1)

It concatenates a mini-batch statistic as an extra channel to each sample. I would like to compute that statistic over the batches across all processes rather than just the local batch in each process.
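Concretely, the distributed version I have in mind would look something like the sketch below. It is untested and makes a few assumptions: the default process group is already initialized, averaging the per-process statistic is an acceptable approximation of the statistic over the combined batch, and the reduction runs on a detached copy, so no gradient flows through the collective.

import torch
import torch.distributed as dist
from torch.nn import Module


class DistributedBatchStandardDeviation(Module):

    def forward(self, x):
        batch_size, _, height, width = x.size()
        out = x - x.mean(dim=0, keepdim=True)  # Shape: B, C, H, W
        out = torch.sqrt(out.pow(2.0).mean(dim=0) + 1e-8)  # Shape: C, H, W
        out = out.mean().view(1, 1, 1, 1)  # Shape: 1, 1, 1, 1

        if dist.is_available() and dist.is_initialized():
            # Average the scalar statistic over all processes. The reduction is
            # done on a detached copy, so backward treats the synced value as a
            # constant rather than differentiating through the collective.
            stat = out.detach().clone()
            dist.all_reduce(stat, op=dist.ReduceOp.SUM)
            out = stat / dist.get_world_size()

        out = out.repeat(batch_size, 1, height, width)  # Shape: B, 1, H, W
        return torch.cat([x, out], dim=1)

When torch.distributed is not initialized (e.g. single-process debugging), it just falls back to the local behavior above.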

You can call all_reduce in the forward pass, but beware that if you have multiple of these layers, or other layers that call any torch.distributed functions, the order in which they are called needs to be identical across all workers. If you end up with any learnable parameters, consider the concerns I expressed on this PR adding sync batch norm.

Thanks for the reply! After reading over the PR for synced batch norm (really excited to see this functionality being baked into PyTorch, btw; I think it's a must-have for proper distributed training), I see you point out a potential deadlock when everything runs on the same process group. Is this a factor for my layer, since no learnable parameters are used? I can always spin up a new process group for this layer to use if so.
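For reference, this is roughly what I mean by a separate group (just a sketch; stats_group is my own name for it, and it assumes the default group is already initialized and that every rank calls new_group in the same order):

import torch
import torch.distributed as dist

# Create a dedicated group containing every rank. All ranks must execute this,
# and in the same order relative to any other new_group calls.
stats_group = dist.new_group(ranks=list(range(dist.get_world_size())))

# Collectives in the layer would then target this group instead of the default
# one (with the NCCL backend the tensor would need to live on a CUDA device).
stat = torch.ones(1)
dist.all_reduce(stat, op=dist.ReduceOp.SUM, group=stats_group)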