[solved] How to use dist.broadcast correctly in a distributed setting?

I want to implement a parameter server architecture with PyTorch. I use dist.reduce to gather the parameters from all workers, average them on the parameter server, and then use dist.broadcast to push the updated parameters back to each worker. My code is:

############### for the parameter server
def update_ps(model, group):
    size = float(dist.get_world_size() - 1)
    for param in model.parameters():
        new_data = param.data
        dist.reduce(new_data, dst=0, op=dist.reduce_op.SUM, group=group)
        param.data = new_data / size
        dist.broadcast(param.data, src=0, group=group)
    print("receive and send")

############### for the worker
def update_worker(model):
    for param in model.parameters():
        p = param.data.cpu()
        dist.recv(p, src=0)
        param.data = p.cuda()
    print("received")

The code runs without errors, but the parameters don't update, and there seems to be no communication between the workers and the parameter server.

Is the code right?


I fixed the problem.
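The key point is that dist.broadcast is a collective, so every rank in the group has to call it; pairing dist.broadcast on the server with dist.recv on the workers does not match up, which is why nothing is actually exchanged. A minimal sketch of one consistent version (the function name update_params and the rank argument are illustrative, assuming rank 0 is the parameter server and every rank joins the same group and calls this function) could look like:

import torch
import torch.distributed as dist

def update_params(model, group, rank):
    # number of workers (every rank except the parameter server at rank 0)
    num_workers = float(dist.get_world_size() - 1)
    for param in model.parameters():
        if rank == 0:
            # the server contributes zeros, so the reduce sums only worker parameters
            tensor = torch.zeros_like(param.data)
        else:
            tensor = param.data.clone()
        # sum the tensors from every rank onto rank 0
        # (dist.reduce_op.SUM as in the post; dist.ReduceOp.SUM in newer PyTorch)
        dist.reduce(tensor, dst=0, op=dist.reduce_op.SUM, group=group)
        if rank == 0:
            tensor /= num_workers  # average on the parameter server
        # broadcast the averaged parameters back; every rank must call this
        dist.broadcast(tensor, src=0, group=group)
        param.data = tensor

Because all ranks call the same reduce and broadcast in the same order, the collectives line up and the averaged parameters actually reach the workers.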