Hey @Saurav_Gupta1, the usage of torch.distributed.all_gather looks correct to me. One thing I want to mention is that torch.distributed.all_gather is not an autograd function, so running backward on the gathered tensors (i.e., tmp in your code) won't reach the autograd graph prior to the all_gather operation. If you really need autograd to extend beyond communication ops, you can try this experimental autograd-powered all_gather.
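For reference, a common workaround is to re-insert the local tensor into the gathered list so that the slice owned by the current rank stays attached to the autograd graph. This is just a minimal sketch, assuming the process group is already initialized and x requires grad; the helper name gather_with_grad is hypothetical, not something from your code or the PyTorch API:

```python
import torch
import torch.distributed as dist

def gather_with_grad(x):
    # Plain all_gather: the received tensors are detached from autograd.
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(x) for _ in range(world_size)]
    dist.all_gather(gathered, x)

    # Put the original local tensor back in its slot so the local slice
    # keeps its autograd history; gradients flow only through this rank's
    # own contribution, which is usually fine when the loss is later
    # summed/averaged across ranks.
    gathered[dist.get_rank()] = x
    return torch.cat(gathered, dim=0)
```

Alternatively, if your PyTorch version ships it, torch.distributed.nn.functional.all_gather provides a differentiable all_gather, which I believe is along the lines of the experimental op linked above.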