How to modify all_reduce operator in PyTorch?

Hi all,

I’m suffering from the huge bandwidth requirement when I train my model on multiple nodes, so I’m considering applying this. However, I have no idea what part of the PyTorch code I have to replace. I guess I need to implement another torch.distributed.reduce_op to replace torch.distributed.reduce_op.SUM, but I couldn’t find any guideline on implementing a new operator in the official docs or tutorials. Could you let me know what part of the PyTorch code I should refer to?

Thanks,
Jinserk

I went through the paper quickly, so correct me if I am wrong: they are not implementing a new network communication process or a new communication semantic; they are pre-processing the gradients before reducing them. If you are using DistributedDataParallel, you can pre-process the gradients before the backward pass. If you also want to change the distribution behavior, e.g. skipping gradient sharing every two epochs, you can implement DistributedDataParallel yourself by following the distributed tutorial, using operations like torch.distributed.reduce_op (a minimal sketch of that pattern follows).
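For reference, here is a minimal sketch of that manual pattern from the distributed tutorial: average gradients yourself with all_reduce, with a hook applied first. The `preprocess` callable is a hypothetical stand-in for whatever gradient transformation the paper describes; the modern `dist.ReduceOp.SUM` is used in place of the older `dist.reduce_op.SUM` alias.

```python
import torch
import torch.distributed as dist

def average_gradients(model, preprocess=None):
    """Manually all-reduce and average gradients, optionally pre-processing them first."""
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is None:
            continue
        if preprocess is not None:
            # Modify the local gradient in place before it is shared.
            param.grad.data.copy_(preprocess(param.grad.data))
        # Sum the (pre-processed) gradients across all workers, then average.
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= world_size
```

You would call `average_gradients(model, preprocess=my_hook)` after `loss.backward()` and before `optimizer.step()` on every worker.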

Thanks for the quick and kind reply @enisberk! I wonder: if I use DistributedDataParallel, does it perform the all_reduce automatically or not? I asked the same question here, but I want to ask again in this thread, and I will remove the other one once it is clarified.

Are you asking if you need to call something like this:
dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
to be able to use DistributedDataParallel?
No, you do not need to. You can see the ImageNet example; a minimal sketch of the wrapping is below.
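Roughly, all you do is wrap the model; the gradient all_reduce happens inside the backward pass. This sketch assumes a gloo process group initialized via the usual environment variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE); the model and data here are placeholders.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="gloo")   # assumes env:// rendezvous variables are set

model = DDP(nn.Linear(10, 10))            # wrapping the model is all that is required
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

x, y = torch.randn(8, 10), torch.randn(8, 10)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()     # DistributedDataParallel all-reduces the gradients here
optimizer.step()    # no explicit dist.all_reduce call is needed
```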

Thanks for clarifying. Then how can I pick out only the large gradients to make them sparse, as the paper proposes, if DistributedDataParallel controls all the I/O internally?

Yes, I think so, it handles the I/O internally (Docs).
Check out the tutorial to see how Arnold implemented it, so you can use that approach to modify the gradients before reducing them; a rough sketch of the sparsification step follows.
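As an illustration only, here is what "keep only the large gradients" could look like inside such a manual reduction loop: zero out all but the top-k entries (by magnitude) of each gradient before the all_reduce. This is not the paper's full method (which also accumulates the dropped values locally), and `k_ratio` is an assumed parameter.

```python
import torch
import torch.distributed as dist

def sparsify_and_allreduce(model, k_ratio=0.01):
    """Keep only the top-k gradient entries by magnitude, then all-reduce and average."""
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is None:
            continue
        grad = param.grad.data
        flat = grad.view(-1)
        k = max(1, int(flat.numel() * k_ratio))
        # Indices of the k largest-magnitude gradient entries.
        _, idx = torch.topk(flat.abs(), k)
        sparse = torch.zeros_like(flat)
        sparse[idx] = flat[idx]
        grad.copy_(sparse.view_as(grad))
        # Dense all_reduce of the sparsified gradient; a real implementation would
        # exchange (index, value) pairs to actually save bandwidth.
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        grad /= world_size
```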