Will "dist.all_gather" break the auto gradient graph?


If we use torch.nn.parallel.gather to collect data from other GPUs and then run some operations on the gathered tensor A, the gradient flows back to the original tensors (since it comes from the nn package, that is what is supposed to happen). Checking A's grad_fn confirms this: it shows CopyBackwards, which is what we expect.
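For context, a minimal sketch of the nn.parallel.gather behaviour described above (it assumes two visible GPUs; the tensor names are just placeholders):

```python
import torch
from torch.nn.parallel import gather

# Sketch assuming at least two GPUs are visible. gather() is the helper that
# DataParallel uses to collect per-device outputs onto a single target device.
a = torch.randn(2, 3, device="cuda:0", requires_grad=True)
b = torch.randn(2, 3, device="cuda:1", requires_grad=True)

A = gather([a, b], target_device=0)   # concatenated along dim 0 on cuda:0
print(A.grad_fn)                      # a backward node, so autograd stays connected

A.sum().backward()
print(a.grad is not None, b.grad is not None)  # True True: grads reach the sources
```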

But all_gather from the torch.distributed package seems to behave differently: it simply copies the data into the output tensors, and the result has no grad_fn at all. Does this mean that if we use this function, we have to manually compute gradients for the gathered data?
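To make the question concrete, here is a minimal sketch of what I observe, plus what "manually computing gradients" might look like. The launcher, the gloo backend, and the AllGatherWithGrad wrapper are my own assumptions for illustration, not an existing API:

```python
import torch
import torch.distributed as dist

# Sketch: launch with e.g. `torchrun --nproc_per_node=2 script.py` (gloo backend).
dist.init_process_group(backend="gloo")
rank, world_size = dist.get_rank(), dist.get_world_size()

x = torch.randn(4, requires_grad=True)
gathered = [torch.zeros(4) for _ in range(world_size)]
dist.all_gather(gathered, x)
print(gathered[rank].grad_fn)   # None -- the copies are invisible to autograd

# Hypothetical wrapper: route gradients back to each rank's input by hand.
class AllGatherWithGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor):
        out = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(out, tensor)
        return tuple(out)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # grad_outputs[i] is this rank's gradient w.r.t. the tensor gathered
        # from rank i; summing slot i across ranks gives rank i's full gradient.
        stacked = torch.stack(grad_outputs)
        dist.all_reduce(stacked)
        return stacked[dist.get_rank()]

ys = AllGatherWithGrad.apply(x)
torch.stack(ys).sum().backward()
print(x.grad)   # gradients now reach the local tensor
```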

related post: