Hello everyone, I created a pip package that provides differentiable versions of scatter/gather/send/recv, so that PyTorch's autograd can backpropagate through these operations. I thought I should share it. I haven't tested it thoroughly, so apologies if something breaks. Contributions are welcome!
There is some example code here: https://github.com/ag14774/diffdist/blob/master/diffdist/testing.py
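For anyone wondering what "differentiable send/recv" means in practice, here is the rough idea. The backward of a send is a receive of the gradient from the destination rank, and the backward of a receive is a send of the gradient back to the source rank. The same pairing applies to the collectives: the backward of a gather is a scatter, and vice versa. The sketch below is NOT diffdist's code or API. The class names `DiffSend`/`DiffRecv`, the dummy scalar returned by the send, the `anchor` tensor on the receiver, the gloo backend, and the fork-based launcher are all my own assumptions for illustration (Linux, CPU tensors, two local processes).

```python
import multiprocessing as mp
import socket

import torch
import torch.distributed as dist


class DiffSend(torch.autograd.Function):
    """Illustrative only: forward sends `tensor`, backward receives its gradient."""

    @staticmethod
    def forward(ctx, tensor, dst):
        ctx.dst, ctx.shape, ctx.dtype = dst, tensor.shape, tensor.dtype
        dist.send(tensor.detach().contiguous(), dst=dst)
        # Assumption: return a dummy scalar so the op stays in the autograd graph.
        return tensor.new_zeros(())

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient w.r.t. the sent tensor lives on the receiving rank.
        grad = torch.empty(ctx.shape, dtype=ctx.dtype)
        dist.recv(grad, src=ctx.dst)
        return grad, None


class DiffRecv(torch.autograd.Function):
    """Illustrative only: forward receives a tensor, backward sends its gradient back."""

    @staticmethod
    def forward(ctx, anchor, src):
        # Hypothetical `anchor`: gives the shape/dtype and makes the output require grad.
        ctx.src = src
        out = torch.empty_like(anchor)
        dist.recv(out, src=src)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        dist.send(grad_output.contiguous(), dst=ctx.src)
        return None, None


def worker(rank, world_size, port, results):
    dist.init_process_group("gloo", init_method=f"tcp://127.0.0.1:{port}",
                            rank=rank, world_size=world_size)
    try:
        if rank == 0:
            x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
            dummy = DiffSend.apply(x * 2, 1)  # sends [2, 4, 6] to rank 1
            dummy.backward()  # blocks until rank 1 sends the gradient back
            results.put((rank, x.grad.tolist()))
        else:
            anchor = torch.zeros(3, requires_grad=True)
            z = DiffRecv.apply(anchor, 0)
            (z ** 2).sum().backward()  # sends dL/dz = 2z back to rank 0
            results.put((rank, z.tolist()))
    finally:
        dist.destroy_process_group()


def run_demo():
    # Pick a free local port for the rendezvous.
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    ctx = mp.get_context("fork")  # assumption: Linux-style fork start method
    results = ctx.Queue()
    procs = [ctx.Process(target=worker, args=(r, 2, port, results)) for r in range(2)]
    for p in procs:
        p.start()
    out = dict(results.get(timeout=60) for _ in range(2))
    for p in procs:
        p.join()
    return out


if __name__ == "__main__":
    print(run_demo())
```

Here rank 1 computes `L = sum(z**2)` on the received `z = 2x`, so the gradient that travels back is `2z`, and rank 0 ends up with `x.grad = 4 * 2x = [8, 16, 24]` even though the loss was only ever computed on rank 1.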