I believe this is because it’s trying to backpropagate through `torch.distributed.all_reduce`, which is not possible at the moment. Instead, you can replace `torch.distributed.all_reduce` with `torch.distributed.nn.functional.all_reduce`, which has the `backward` method implemented. See the suggestion here for more details and caveats: Do gradients propagate through all_reduce & all_gather? - #2 by wanchaol
Also refer to torch.distributed.nn.all_reduce incorrectly scales the gradient · Issue #58005 · pytorch/pytorch · GitHub for caveats around the (in-)correctness of the gradient with `torch.distributed.nn.functional.all_reduce` when the reduce op is not sum.
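As a rough illustration (a minimal sketch, not your actual code), the swap would look something like this, assuming a standard `env://` rendezvous set up by e.g. `torchrun` and the `gloo` backend. The differentiable variant returns a new tensor that stays in the autograd graph instead of reducing in place:

```python
# Minimal sketch: loss depends on an all-reduced tensor, and gradients
# flow back to each rank's local tensor because the reduction is done
# with the differentiable torch.distributed.nn.functional.all_reduce.
import os

import torch
import torch.distributed as dist
import torch.distributed.nn.functional as dist_nn_F


def run(rank: int, world_size: int) -> None:
    # Assumes RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are set (e.g. by torchrun).
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    x = torch.ones(4, requires_grad=True)

    # Differentiable all-reduce: the result carries a grad_fn, unlike
    # dist.all_reduce, which mutates the tensor in place and has no backward.
    y = dist_nn_F.all_reduce(x, op=dist.ReduceOp.SUM)

    loss = y.sum()
    loss.backward()

    # Each rank now has a gradient on its local x.
    print(f"rank {rank}: grad = {x.grad}")

    dist.destroy_process_group()


if __name__ == "__main__":
    run(int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]))
```

Note the gradient-scaling caveat from the linked GitHub issue still applies here, so double-check the gradients you get against what your math expects.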