Will "dist.all_gather" break the autograd graph?

If we use torch.nn.parallel.gather to collect data from other GPUs and then do some operations on the gathered data A, the gradient flows back to the original tensors (since it comes from the nn package, that's what is supposed to happen). We can check A.grad_fn: it shows CopyBackwards, which is what we expect.

But all_gather from the torch.distributed package seems to behave differently. It simply copies the data into the output tensors, and the results have no grad_fn. Does this mean that if we use this function, we have to compute the gradients for the gathered data manually?



Hi, I ended up writing the backward function manually. I used the all_gather function from NCCL, so of course PyTorch won't take care of that. I hope they add support for it in the future.
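A minimal sketch of such a hand-written backward, assuming torch.distributed with a tensor of the same shape on every rank, gathered along dim 0. AllGatherWithGrad and init_single_process_group are hypothetical names for illustration, and the single-process gloo group exists only so the snippet runs on one CPU machine. Because the gathered output on every rank depends on every rank's input, the backward sums the incoming gradients across ranks with all_reduce and keeps this rank's slice, i.e. it differentiates the sum of the per-rank losses.

```python
import socket

import torch
import torch.distributed as dist


def init_single_process_group() -> None:
    """Hypothetical helper: a 1-process gloo group so the sketch runs on CPU."""
    if dist.is_initialized():
        return
    # Pick a free local port for the TCP rendezvous.
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    dist.init_process_group(
        "gloo", init_method=f"tcp://127.0.0.1:{port}", rank=0, world_size=1
    )


class AllGatherWithGrad(torch.autograd.Function):
    """all_gather along dim 0 with a hand-written backward (illustrative)."""

    @staticmethod
    def forward(ctx, tensor: torch.Tensor) -> torch.Tensor:
        ctx.rank = dist.get_rank()
        ctx.local_size = tensor.shape[0]
        local = tensor.contiguous()
        # Assumes every rank passes a tensor of the same shape.
        gathered = [torch.empty_like(local) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, local)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # Every rank's loss depends on this rank's input, so sum the incoming
        # gradients over all ranks, then keep the slice that came from here.
        grad = grad_output.contiguous().clone()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        start = ctx.rank * ctx.local_size
        return grad[start : start + ctx.local_size]


init_single_process_group()
x = torch.randn(2, 3, requires_grad=True)
full = AllGatherWithGrad.apply(x)  # shape: (world_size * 2, 3)
loss = (full ** 2).sum()
loss.backward()
# With world_size == 1 the gradient is just d/dx sum(x**2) = 2x.
print(torch.allclose(x.grad, 2 * x.detach()))  # True
```

Since the backward itself issues a collective, every rank has to call backward, otherwise the job would hang waiting in all_reduce.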

I also had similar questions, but in the end I found the current design reasonable. With all_gather, the gradient is not propagated back to other devices, but the gradient for the current device can still be calculated correctly. Since each device computes its own loss and backpropagates its own gradient, the final results are correct. For gather, I assume the non-zero devices do not compute a loss and only device 0 does. In that case, the gradient has to be sent back to the non-zero devices.
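A sketch of the per-device pattern described above, under assumptions: equal shapes on every rank, concatenation along dim 0, and hypothetical helper names (all_gather_keep_local_grad, init_single_process_group). dist.all_gather does not record anything for autograd, so one common way to keep the current device's gradient is to put the rank's own live tensor back into its slot of the gathered list. Gradient terms from other ranks' losses with respect to this rank's input are dropped, so whether the result matches the gradient of the total loss depends on how the loss is set up.

```python
import socket

import torch
import torch.distributed as dist


def init_single_process_group() -> None:
    """Hypothetical helper: a 1-process gloo group so the sketch runs on CPU."""
    if dist.is_initialized():
        return
    # Pick a free local port for the TCP rendezvous.
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    dist.init_process_group(
        "gloo", init_method=f"tcp://127.0.0.1:{port}", rank=0, world_size=1
    )


def all_gather_keep_local_grad(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: gather along dim 0, keeping autograd on the local slice."""
    local = tensor.contiguous()
    gathered = [torch.empty_like(local) for _ in range(dist.get_world_size())]
    # Gather detached data; the collective itself is kept out of autograd.
    dist.all_gather(gathered, local.detach())
    # Slices from other ranks are plain data. Put this rank's live tensor back
    # in its slot so gradients still reach it through torch.cat.
    gathered[dist.get_rank()] = tensor
    return torch.cat(gathered, dim=0)


init_single_process_group()
feats = torch.randn(4, 8, requires_grad=True)
all_feats = all_gather_keep_local_grad(feats)
loss = (all_feats ** 2).sum()
loss.backward()
print(all_feats.grad_fn is not None)  # True
```

No custom backward is needed here: the only differentiable operation is torch.cat, and it routes the gradient for this rank's slice straight back to feats.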


Do you have an example of how to do that?
Thanks!

I wrote a blog post about a way to use all_gather without needing to calculate the gradient.

Here is an implementation of all_gather with gradient back-prop: vlkit.ops.distributed.all_gather().


Hi, this link doesn’t work. I wanted to implement something similar. Thanks.


Hi, the link is now open. Alternatively, you can use vlkit.ops.all_gather.