import torch
import torch.distributed  # assumes the process group is already initialized with world_size == 3
a = torch.rand(3).requires_grad_()
b, c = torch.empty_like(a), torch.empty_like(a)  # output slots for the other two ranks
l = a.sum()
torch.distributed.all_gather([a, b, c], a)  # a is reused as its own output slot
Then l.backward() will trigger an error saying:
one of the variables needed for gradient computation has been modified by an in-place operation
I assume this is because all_gather does an in-place write into all of a, b, and c.
But why? Since a is the tensor contributed by the current process, there is no need to overwrite it and trigger this issue. Is there some consideration behind this all_gather behavior?
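For context, gathering into fresh buffers (so that a itself is never written to) avoids the error, which is consistent with the in-place explanation above. A minimal sketch, assuming a 3-rank group:

outs = [torch.empty_like(a) for _ in range(3)]
torch.distributed.all_gather(outs, a)  # a is only read, never written, so l.backward() succeeds
# outs[torch.distributed.get_rank()] now holds a copy of this rank's a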
We could probably add a shortcut to avoid writing into a in this case, but I am not sure that is a good idea, because it would make all_gather behave differently depending on the underlying storage. Consider two cases.
Case 1:
x = torch.empty_like(a)  # fresh buffer with its own storage
torch.distributed.all_gather([x,b,c], a)
In this case, we would still need to write data from a to x, right?
Case 2:
x = a.view(...) # say change the stride
torch.distributed.all_gather([x,b,c], a)
In this case, x shares storage with a but uses a different element layout, so we would still need to write into x.
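To make the aliasing difference between the two cases concrete, here is a standalone sketch (no process group needed; the sizes and the view shape are just an assumption for the example):

a = torch.rand(4)
x1 = torch.empty_like(a)                # Case 1: separate storage, same layout
print(x1.data_ptr() == a.data_ptr())    # False -> the copy from a into x1 is genuinely needed
x2 = a.view(2, 2)                       # Case 2: same storage, different layout metadata
print(x2.data_ptr() == a.data_ptr())    # True
print(x2.stride(), a.stride())          # (2, 1) vs (1,): the metadata differs even though storage is shared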
To address the cases above, we could probably detect this and skip the in-place write only when x shares the same storage and metadata as a. However, the concerns are: 1) is the extra overhead worth it? 2) will the disparity in all_gather's behavior confuse users?
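Concretely, the check I have in mind would be something like the following (a hypothetical helper for illustration, not an existing API):

def can_skip_local_write(x, a):
    # Skip the in-place write only when x is exactly the same view of the same memory as a:
    # same data pointer (storage plus offset), same shape, same strides, same dtype.
    return (x.data_ptr() == a.data_ptr()
            and x.size() == a.size()
            and x.stride() == a.stride()
            and x.dtype == a.dtype)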
This PR might be relevant. It’s trying to avoid a flatten.