Differentiable communication - Distributed Model Parallel

Hello everyone, I created this pip package that includes differentiable versions of scatter/gather/send/recv so that PyTorch's autograd can backpropagate through them. I thought I should share it. I haven't tested it thoroughly, so apologies if something breaks. Contributions are welcome!

There is some example code here: https://github.com/ag14774/diffdist/blob/master/diffdist/testing.py


Very nice, thank you for sharing!
