Decentralized asynchronous stochastic gradient descent implementation

Hi, I want to try out decentralized stochastic gradient descent in an asynchronous setting (AD-PSGD). The idea is: given some distributed worker nodes, I want to split a dataset across them, with the workers connected to each other through some communication topology (papers often use a ring network as the example). The training algorithm is as follows:

  1. Each worker asynchronously computes the gradient on its own shard of the data while, at the same time, fetching the model weights from one random worker it is connected to and averaging them with its own current weights. Call this W_avg.

  2. After computing the gradient, the worker calculates new weights as: W_new = W_avg - (learning rate) * (gradient).

  3. Set the worker's weights to W_new.

  4. Repeat steps 1-3 until the whole dataset has been processed.

All of this should happen asynchronously, meaning the workers never wait for each other: they keep repeating the algorithm until they are finished, and it does not matter if a neighbor's weights are stale. A rough sketch of what I mean by one worker's update is below.
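To make the update concrete, here is a minimal single-process simulation of what I think one AD-PSGD step looks like. Everything in it (the ring topology, the toy least-squares gradient, and names like `n_workers`, `local_gradient`, `ad_psgd_step`) is a placeholder I made up for illustration, not real torch.distributed code:

```python
import random
import torch

# Hypothetical setup: n workers on a ring, each holding a flat parameter
# vector and its own shard of the data. This is a single-process simulation
# of the update rule only, not a distributed implementation.
n_workers = 4
dim = 10
lr = 0.01

# Ring topology: worker i can talk to (i - 1) % n and (i + 1) % n.
neighbors = {i: [(i - 1) % n_workers, (i + 1) % n_workers] for i in range(n_workers)}

# Every worker starts from the same initial weights.
weights = [torch.zeros(dim) for _ in range(n_workers)]

def local_gradient(w, batch):
    """Placeholder for the worker's gradient on its own mini-batch
    (here: gradient of a simple least-squares loss)."""
    x, y = batch
    return x.t() @ (x @ w - y) / len(y)

def ad_psgd_step(i, batch):
    # 1. Pull the weights of one random neighbor and average with our own.
    j = random.choice(neighbors[i])
    w_avg = 0.5 * (weights[i] + weights[j])
    # 2. The gradient is computed on the old local weights (in the real
    #    algorithm this happens concurrently with the averaging).
    grad = local_gradient(weights[i], batch)
    # 3. Apply the gradient step on top of the averaged weights.
    weights[i] = w_avg - lr * grad

# One simulated "asynchronous" run: workers step in a random interleaving.
for _ in range(100):
    i = random.randrange(n_workers)
    batch = (torch.randn(8, dim), torch.randn(8))
    ad_psgd_step(i, batch)
```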

I would also like to try regular D-PSGD, where each worker averages the weights of all of its neighbors (plus its own) instead of a single random one; a sketch of that variant follows.
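Reusing the same placeholder names from the sketch above, I imagine the D-PSGD variant would only change the averaging step (uniform weights here; as far as I understand, the papers use a doubly stochastic mixing matrix more generally):

```python
def d_psgd_step(i, batch):
    # Average the worker's own weights with all of its neighbors' weights.
    group = [weights[i]] + [weights[j] for j in neighbors[i]]
    w_avg = torch.stack(group).mean(dim=0)
    grad = local_gradient(weights[i], batch)
    weights[i] = w_avg - lr * grad
```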

Is there anything like this already in PyTorch FSDP or DDP? If not, what should I look into to implement it? I am not very familiar with torch.distributed implementations.

The algorithm comes from: https://arxiv.org/pdf/1710.06952v3.pdf