I have a question regarding the use of collective functions such as all_reduce().
I am working on a layer in which I would like to synchronize a value across processes. I have seen an implementation of synchronized batch norm that essentially does what I am looking for. In those layers, all_reduce appears to be called from forward, but there is also an autograd function that defines the backward behavior.
Is that approach necessary for a Module that does not have trainable parameters? Can I just call all_reduce in the forward method of a Module or do I need to define it in an autograd function?
Btw, the layer I’m working on looks like this:
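import torch
from torch.nn import Module

class BatchStandardDeviation(Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        batch_size, _, height, width = x.size()
        # Center each feature over the batch dimension
        out = x - x.mean(dim=0, keepdim=True)  # Shape: B, C, H, W
        # Per-position standard deviation over the batch
        out = torch.sqrt(out.pow(2.0).mean(dim=0, keepdim=False) + 1e-8)  # Shape: C, H, W
        # Collapse to a single scalar statistic
        out = out.mean().view(1, 1, 1, 1)  # Shape: 1, 1, 1, 1
        out = out.repeat(batch_size, 1, height, width)  # Shape: B, 1, H, W
        return torch.cat([x, out], dim=1)  # Shape: B, C + 1, H, W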
It concatenates mini-batch statistics to each feature map as an extra channel. I would like to compute those statistics from the batches across all processes rather than from just the local batch in one process.
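To make the question concrete, here is a rough sketch of the autograd-function variant I am wondering whether I need. The names AllReduceSum and SyncBatchStandardDeviation are made up for this post, and the sketch assumes every process has the same local batch size and that the process group is already initialized (it falls back to the local batch otherwise). It gets the variance from summed x and x^2 instead of centering first, so it is not numerically identical to the layer above.

```python
import torch
import torch.distributed as dist
from torch.nn import Module


class AllReduceSum(torch.autograd.Function):
    """Hypothetical helper: sum a tensor across processes in forward and backward."""

    @staticmethod
    def forward(ctx, tensor):
        out = tensor.clone()  # all_reduce works in place, so keep the input intact
        dist.all_reduce(out, op=dist.ReduceOp.SUM)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # Every process consumed the summed value, so the gradients are summed too
        grad = grad_output.clone()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        return grad


class SyncBatchStandardDeviation(Module):
    """Sketch only: assumes equal local batch sizes on all processes."""

    def forward(self, x):
        batch_size, _, height, width = x.size()
        # Local sums of x and x^2 over the batch dimension
        sums = torch.stack([x.sum(dim=0), x.pow(2.0).sum(dim=0)])  # Shape: 2, C, H, W
        total = batch_size
        if dist.is_available() and dist.is_initialized():
            sums = AllReduceSum.apply(sums)
            total = batch_size * dist.get_world_size()
        mean = sums[0] / total
        # Biased variance as E[x^2] - E[x]^2, clamped against rounding error
        var = (sums[1] / total - mean.pow(2.0)).clamp(min=0.0)
        std = torch.sqrt(var + 1e-8)  # Shape: C, H, W
        out = std.mean().view(1, 1, 1, 1)
        out = out.repeat(batch_size, 1, height, width)  # Shape: B, 1, H, W
        return torch.cat([x, out], dim=1)
```

My understanding is that a plain dist.all_reduce call inside forward is not tracked by autograd, which is why I suspect the wrapper is needed whenever gradients should flow back through the statistic to x, but I would appreciate confirmation.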