How to perform all_reduce within DistributedDataParallel?

Hi, I'm currently having trouble performing all_reduce operations on an nn.Module's internal buffers in a multi-node, multi-GPU scenario using DistributedDataParallel:

E.g., I have a module named Foo with a registered buffer buf_bar. During training, I want to average it across all running instances. Say I have 2 servers, each equipped with 8 GPUs; if I don't use DistributedDataParallel and just launch 16 jobs, one per GPU, I can simply write:

import torch
import torch.nn as nn
import torch.distributed as dist

rank = dist.get_rank()
world_size = dist.get_world_size()
local_gpu = rank % torch.cuda.device_count()  # local GPU index on this node

class Foo(nn.Module):
  def __init__(self, size):
    super().__init__()
    # ...
    self.register_buffer("buf_bar", torch.zeros(*size).to(f"cuda:{local_gpu}"))

  # ...
  def forward(self, x):
    # compute `buf_bar` here and divide it by `world_size`,
    # so the all_reduce sum below yields the average
    dist.all_reduce(self.buf_bar)
    # continue training using the reduced `buf_bar`

Since each GPU gets its own process, and therefore its own unique rank, this works just fine.
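
For reference, each of those 16 per-GPU processes sets up the process group roughly like this (the env-var rendezvous is just a placeholder for whatever launcher I actually use):

import torch
import torch.distributed as dist

# one process per GPU: assumes RANK in [0, 15] and WORLD_SIZE=16 are provided by the launcher
dist.init_process_group(backend="nccl", init_method="env://")
# pin this process to its local GPU (ranks 0-7 on node 0, ranks 8-15 on node 1)
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())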

However, if I wrap Foo with DistributedDataParallel and launch 2 jobs, one per node, each seeing all 8 GPUs, as suggested in the official documents, then the above code hangs during training, since the rank becomes node-wise and each tensor's device has to be indexed on two levels: rank and device_id.
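
Concretely, the per-node setup I'm describing is roughly the following (the constructor argument is made up; each of the 2 processes drives all 8 of its local GPUs):

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# one process per node: rank in {0, 1}, world_size = 2
dist.init_process_group(backend="nccl", init_method="env://")
foo = Foo(size).to("cuda:0")  # `size` is a made-up constructor argument
ddp_foo = DistributedDataParallel(foo, device_ids=list(range(8)))
# as far as I can tell, each of the 8 internal replicas then calls
# dist.all_reduce(self.buf_bar) against a process group that only has 2 ranks,
# so the collective never lines up and training hangs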

I think the solution should be to use all_reduce_multigpu outside of Foo.forward. I've briefly gone through DistributedDataParallel, and it seems to use hooks for its reduce ops. However, my all_reduce op sits in the middle of Foo's forward pass, so register_forward/backward_hook may not help.
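
To illustrate the idea (not working code — I don't see how to reach DDP's per-GPU replicas in the middle of forward, which is exactly my problem), replicas below is a hypothetical list of the 8 per-GPU copies of Foo inside one process:

import torch.distributed as dist

# hypothetical: one `buf_bar` per local GPU replica
bufs = [replica.buf_bar for replica in replicas]
# a single call reduces across every GPU of every process (2 nodes x 8 GPUs)
dist.all_reduce_multigpu(bufs)
for buf in bufs:
  buf.div_(dist.get_world_size() * len(bufs))  # turn the global sum into an average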

Can anyone offer some suggestions? Thanks a lot~