all_gather triggers an unnecessary in-place change

Hi. For example, with the following code snippet:

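```python
import torch  # assumes torch.distributed.init_process_group(...) ran with world_size=3
a = torch.rand(3).requires_grad_()
b, c = torch.empty_like(a), torch.empty_like(a)  # receive buffers for the other ranks
l = a.sum()
torch.distributed.all_gather([a, b, c], a)
```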

calling l.backward() raises an error saying

one of the variables needed for gradient computation has been modified by an in-place operation
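A minimal local sketch of where this message comes from, with no collective involved: an in-place `copy_` stands in for the write that all_gather makes into `a`. Autograd keeps a version counter on each tensor, bumps it on every in-place write, and at backward time compares it with the version recorded when a value was saved for the graph. The sketch uses `(a * a).sum()` because `mul` saves `a` for its backward pass, so the overwritten value is actually needed; `backward_after_inplace_write` is a hypothetical helper for illustration only.

```python
from typing import Optional

import torch


def backward_after_inplace_write() -> Optional[str]:
    """Return autograd's error message, or None if backward succeeds."""
    a = torch.rand(3).requires_grad_()
    # mul saves `a` for backward (d(a*a)/da = 2a), recording a._version == 0.
    l = (a * a).sum()
    # Stand-in for all_gather writing into `a`: in-place ops bump a._version,
    # even under no_grad.
    with torch.no_grad():
        a.copy_(torch.zeros(3))
    try:
        l.backward()  # saved version (0) no longer matches the current one (1)
    except RuntimeError as err:
        return str(err)
    return None
```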

I assume this is because all_gather does an in-place write into every tensor in the output list, including a.
But why? Since a comes from the current process, there is no need to overwrite it and trigger this error. Is there a particular reason all_gather behaves this way?
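A sketch of a user-side workaround, assuming the default process group is already initialized: give all_gather a fresh buffer for the local rank instead of `a` itself, so the collective only reads `a`. The helper name `gather_without_touching_input` is hypothetical; `dist.get_world_size` and `dist.all_gather` are the standard torch.distributed calls.

```python
from typing import List

import torch
import torch.distributed as dist


def gather_without_touching_input(a: torch.Tensor) -> List[torch.Tensor]:
    """Gather `a` from every rank into freshly allocated buffers.

    Assumes dist.init_process_group(...) has already been called. Unlike
    all_gather([a, b, c], a), the local slot is a new tensor, so `a` is
    only read by the collective and never written.
    """
    world_size = dist.get_world_size()
    # One fresh buffer per rank, including this one (empty_like does not
    # copy requires_grad, so the buffers are plain tensors).
    outputs = [torch.empty_like(a) for _ in range(world_size)]
    # all_gather is not differentiable, so pass it a detached alias of `a`.
    dist.all_gather(outputs, a.detach())
    return outputs
```

The cost is one extra buffer the size of `a`; in exchange, any graph that saved `a` before the gather can still be backpropagated through.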


We could probably add a shortcut that skips writing into a in this case, but I am not sure that is a good idea, because it would make all_gather behave differently depending on the underlying storage. Consider two cases.

Case 1:

```python
x = torch.empty_like(a)  # new storage, same shape as a
torch.distributed.all_gather([x, b, c], a)
```

In this case, we would still need to write data from a to x, right?

Case 2:

```python
x = a.view(...)  # e.g. a view that changes the stride
torch.distributed.all_gather([x, b, c], a)
```

Here x shares storage with a but uses a different element layout, so we would still need to write into x.

To handle both cases, we could skip the in-place write only when x shares the same storage and metadata (storage offset, sizes, strides) as a. However, there are two concerns: 1) is the extra overhead worth it? 2) would the inconsistent behavior of all_gather confuse users?
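A sketch of the check described above, assuming PyTorch 2.x (for `Tensor.untyped_storage()`). `same_storage_and_layout` is a hypothetical helper, not part of torch.distributed; it reads "same meta" as same device, dtype, storage offset, sizes, and strides.

```python
import torch


def same_storage_and_layout(x: torch.Tensor, a: torch.Tensor) -> bool:
    """Hypothetical check: True only if copying `a` into `x` would be a no-op."""
    return (
        x.device == a.device
        and x.dtype == a.dtype
        # Same underlying buffer...
        and x.untyped_storage().data_ptr() == a.untyped_storage().data_ptr()
        # ...and the same view of it: start offset, shape, and strides.
        and x.storage_offset() == a.storage_offset()
        and x.size() == a.size()
        and x.stride() == a.stride()
    )


a = torch.rand(3)
print(same_storage_and_layout(a, a))                      # identical tensor
print(same_storage_and_layout(torch.empty_like(a), a))    # Case 1: new storage
print(same_storage_and_layout(a.as_strided((2,), (2,)), a))  # Case 2: shared storage, new stride
```

Note that `a.view(3)` would also pass even though it is a different Python object, so an identity test (`x is a`) is stricter than necessary. The check only compares metadata and moves no data; whether adding it to every all_gather call is worthwhile is still concern 1) above.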

This PR might be relevant; it tries to avoid a flatten.