How to synchronize a variable across the entire batch in DDP

I am now using DDP to train my model. In my model I want to calculate the standard deviation across the batch, and I am wondering if I can calculate it across the entire (global) batch instead of within each device. The standard deviation will be part of my computation graph.

I feel that this is similar to synchronized batch norm and should be doable. How would I go about doing this? Here is an example of what I want to do:

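import torch

def forward(self, input):
    feats = self.conv(input)
    batch, channel, height, width = feats.shape
    # Treat the whole (local) batch as one group: (batch, 1, 1, C, H, W)
    stddev = feats.view(batch, -1, 1, channel, height, width)
    # Per-feature stddev over the batch dimension: (1, 1, C, H, W)
    stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
    # Average over C, H, W and drop the singleton dim: (1, 1, 1, 1)
    stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
    # Broadcast back to every sample as one extra channel: (batch, 1, H, W)
    stddev = stddev.repeat(batch, 1, height, width)
    feats = torch.cat([feats, stddev], 1)
    output = self.conv_last(feats)
    return output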

Basically, when I compute stddev, I want it computed over the entire batch across all processes, not just over the local batch on each device.

Hey @hij

By default (broadcast_buffers=True), DDP broadcasts buffers from rank 0 to the other processes at the beginning of every forward pass.

If this is what you need, you can register stddev as a buffer. If you need something different (e.g., squaring stddev and summing it across all processes over the wire), you can call allreduce or allgather in the forward function to do that.


If I understand you correctly, registering it as a buffer only lets rank 0's stddev be distributed to the other processes. What I want is for the tensors in all processes to contribute to the stddev calculation.


Yes.

> What I want to do is allow all tensors in all processes to contribute to calculating stddev.

You can do this by using the collective communication APIs (allreduce/allgather) in the forward pass. One caveat: the collective communication APIs require all processes in the same group to invoke the same API in the same order; otherwise, it might hang or produce wrong results. If you are not sure whether the allreduce/allgather for stddev would interleave with other collectives, you can use the new_group API to create a dedicated process group for collecting/summing stddev.
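A minimal sketch of the allreduce route, assuming the process group is already initialized and that the statistic is the batch stddev from the question. Rather than summing squared stddevs, it all-reduces the per-feature sum, sum of squares and sample count, which yields the exact biased variance of the global batch even when each rank holds only one sample. The function name global_batch_stddev is hypothetical. Note that dist.all_reduce is not tracked by autograd, so the result here is a constant in the backward pass.

```python
import torch
import torch.distributed as dist


def global_batch_stddev(feats, group=None, eps=1e-8):
    """Batch stddev over the global batch (all ranks), averaged to one scalar.

    Mirrors the question's statistic: biased variance over the batch dimension
    per feature, sqrt(var + eps), then the mean over all features.
    Assumes the process group is initialized; the result carries no autograd
    history because dist.all_reduce is not tracked by autograd.
    """
    with torch.no_grad():
        # (local_batch, C*H*W); float64 reduces cancellation in E[x^2] - E[x]^2
        flat = feats.detach().flatten(1).double()
        # Pack per-feature sum, sum of squares and sample count so that a
        # single all_reduce moves every statistic at once: shape (3, C*H*W).
        stats = torch.stack([
            flat.sum(0),
            (flat * flat).sum(0),
            torch.full_like(flat[0], flat.shape[0]),
        ])
        dist.all_reduce(stats, op=dist.ReduceOp.SUM, group=group)
        total, total_sq, count = stats
        mean = total / count
        var = (total_sq / count - mean * mean).clamp(min=0.0)  # biased variance
        return torch.sqrt(var + eps).mean().to(feats.dtype)
```

To keep these collectives from interleaving with others, every rank could create a dedicated group once, e.g. stats_group = dist.new_group(), and pass group=stats_group; new_group itself has to be called by all processes in the same order.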

I have read that all_gather does not retain gradient information. For my application, I want stddev to be part of the computation graph. How would I go about doing this? Could you also point me to an example or tutorial for this usage? I have not done any distributed training, so I am not sure how to use these functions.

I am trying to understand the use case here:

> I have read that all_gather does not retain gradient information. For my application, I want stddev to be part of the computation graph. How would I go about doing this?

IIUC, stddev in this case is an intermediate output of forward and not a model parameter? So you need its gradient during the backward pass to compute parameter gradients, but you don't need to retain its gradient for optimizer.step()? If so, why is the existing parameter gradient synchronization in DDP not sufficient?

> Could you also point me to an example or tutorial for this usage?

Sure, below is the tutorial, and please search for “All-Reduce example.”

https://pytorch.org/tutorials/intermediate/dist_tuto.html

I might have been confused. This code snippet is part of the model I intend to train. However, since I have limited GPU memory, I can only train with a batch size of 1, and there is no point in calculating stddev over a batch size of 1, even with DDP. So in this case, would allgather/allreduce work?

I see. I am not sure the result would still be correct in this case, even if allgather and allreduce could retain gradients. IIUC, if this were trained without DDP (assuming large enough GPU memory), both feats and stddev would be calculated from all inputs. When trained with DDP, feats are derived only from local inputs, while you would like stddev to be based on global inputs. So when you concatenate feats and stddev, the output of forward now represents something different. I am not sure whether the loss function can handle that. Even if it can, what does averaging gradients mean in this case?

If above (local feats + global stddev) is the expected behavior, there might be a few ways to implement this.

  1. Implement a custom autograd function. E.g., its forward function can use an allgather to collect stddev from all processes, and its backward function can use another allgather to collect gradients, then extract the parts that belong to the local stddev and sum them up.

  2. Use torch.distributed.rpc. There can be a master and a few workers, where each worker calculates the feats and stddev for its own input, and then the master gathers all feats and stddev to compute the final loss. Some more tutorials for this:
    a. Getting Started with Distributed RPC Framework (PyTorch tutorial)
    b. Implementing a Parameter Server Using Distributed RPC Framework (PyTorch tutorial)
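A minimal sketch of option 1, assuming every rank passes a tensor of the same shape and the process group is already initialized. AllGatherWithGrad and global_stddev_feature are hypothetical names. One deviation from the description: because the local batch may be a single sample (so a local stddev would be zero), the sketch gathers the feature tensors themselves and computes stddev over the gathered batch; gathering per-rank sums and sums of squares instead would use less memory. The backward follows the recipe above: allgather the gradients, keep the slices that belong to this rank's input, and sum them.

```python
import torch
import torch.distributed as dist


class AllGatherWithGrad(torch.autograd.Function):
    """all_gather whose backward routes gradients back to each rank's input."""

    @staticmethod
    def forward(ctx, tensor, group=None):
        ctx.group = group
        world_size = dist.get_world_size(group)
        tensor = tensor.contiguous()
        gathered = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered, tensor, group=group)
        return torch.stack(gathered)  # (world_size, *tensor.shape)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        rank = dist.get_rank(ctx.group)
        grad_output = grad_output.contiguous()
        # Every rank holds a gradient for *all* gathered slices; collect them.
        grads = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(grads, grad_output, group=ctx.group)
        # From every rank's gradient, keep the slice belonging to this rank's
        # input and sum them: shape (*tensor.shape).
        local_grad = torch.stack(grads)[:, rank].sum(0)
        return local_grad, None  # no gradient for the `group` argument


def global_stddev_feature(feats, group=None, eps=1e-8):
    """The question's stddev scalar, computed over the batch of all ranks."""
    everything = AllGatherWithGrad.apply(feats, group)  # (world, batch, C, H, W)
    everything = everything.flatten(0, 1)  # (world * batch, C, H, W)
    stddev = torch.sqrt(everything.var(0, unbiased=False) + eps)
    return stddev.mean()  # differentiable w.r.t. the local feats
```

In the question's forward, s = global_stddev_feature(feats, group=stats_group) followed by s.view(1, 1, 1, 1).expand(batch, 1, height, width) would replace the local stddev map (stats_group being a hypothetical group from dist.new_group()). Because DDP launches its own allreduce calls during backward, a dedicated group follows the earlier caveat about interleaving collectives.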

In this case, since there is no batch dependency in feats, would local feats + global stddev be any different from global feats + global stddev? Since stddev is concatenated to each feats tensor separately, shouldn't they have the same effect?
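A quick single-process check of the forward-value part of this argument, assuming a plain conv layer (so each sample's features do not depend on the rest of the batch) and simulating two ranks by chunking one batch. The helpers stddev_scalar and concat_stddev are hypothetical. It only compares forward outputs; it says nothing about how DDP averages the resulting gradients, which is the open question raised in the reply above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, 3, padding=1)  # stand-in for the question's conv
x = torch.randn(4, 3, 16, 16)          # the "global" batch


def stddev_scalar(feats, eps=1e-8):
    # Same statistic as the question: batch stddev per feature, averaged.
    return torch.sqrt(feats.var(0, unbiased=False) + eps).mean()


def concat_stddev(feats, s):
    # Append the scalar s to every sample as one extra channel.
    b, _, h, w = feats.shape
    return torch.cat([feats, s.view(1, 1, 1, 1).expand(b, 1, h, w)], 1)


with torch.no_grad():
    s = stddev_scalar(conv(x))  # global stddev
    # "global feats + global stddev": one forward over the whole batch
    full = concat_stddev(conv(x), s)
    # "local feats + global stddev": two simulated ranks, batch size 2 each
    local = torch.cat([concat_stddev(conv(chunk), s) for chunk in x.chunk(2)])

print(torch.allclose(full, local, atol=1e-6))
```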

I will try your suggestions. Thank you!