If we use torch.nn.parallel.gather to collect tensors from other GPUs and then perform some operations on the gathered tensor A, the gradient flows back to the original tensors (since gather comes from the `nn` package, that's what is supposed to happen). We can check A.grad_fn: it says CopyBackwards, which is what we expect.
But all_gather from the torch.distributed package seems to be different: it simply copies the data over, and the result has no grad_fn at all. Does this mean that if we use this function, we have to compute gradients for the gathered data manually?
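A minimal single-process check of the behaviour described above (using the `gloo` backend with `world_size=1` just so the collective can be called locally; the address/port values are placeholders):

```python
import os
import torch
import torch.distributed as dist

# Single-process "gloo" group, only so we can invoke the collective.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.ones(3, requires_grad=True)
out = [torch.zeros(3)]
dist.all_gather(out, x)

# The gathered copy is detached from the autograd graph:
print(out[0].grad_fn)        # None, no CopyBackwards here
print(out[0].requires_grad)  # False

dist.destroy_process_group()
```

So any loss computed from `out[0]` alone will not send gradients back through `x`.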
I also had similar questions, but I eventually concluded that the current design is fine. For all_gather, the gradient is not propagated back to the other devices, but the gradient for the current device can still be computed correctly. Since every device computes its own loss and propagates the gradient for its own shard, the final result is correct. For gather, I assume the non-0 devices do not compute the loss and only device 0 does; in that case, the gradient has to be sent back to the non-0 devices.
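One common way to get the "local gradient is still correct" behaviour described above is to wrap all_gather in a custom autograd Function whose backward keeps only this rank's slice of the incoming gradient. This is a sketch under the same single-process `gloo` setup; `AllGatherLocalGrad` is an illustrative name, not a torch API:

```python
import os
import torch
import torch.distributed as dist

class AllGatherLocalGrad(torch.autograd.Function):
    """all_gather whose backward returns the gradient slice for this rank.

    Gradients are not sent to other devices; each rank computes its own
    loss, so only the local shard's gradient needs to flow back.
    """
    @staticmethod
    def forward(ctx, tensor):
        world = dist.get_world_size()
        gathered = [torch.empty_like(tensor) for _ in range(world)]
        dist.all_gather(gathered, tensor)
        return torch.stack(gathered)

    @staticmethod
    def backward(ctx, grad_output):
        # Keep only the slice corresponding to this rank's input.
        return grad_output[dist.get_rank()]

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29502")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = AllGatherLocalGrad.apply(x)  # shape (world_size, 3)
loss = (y * 2).sum()
loss.backward()
print(x.grad)  # tensor([2., 2., 2.]): the local gradient is correct

dist.destroy_process_group()
```

With more than one rank, each process would run the same code, compute its own loss, and receive a correct gradient for its own shard, matching the design described above.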