Dear @mrshenli,
I noticed that your team recently released a new tutorial on implementing a parameter server with the RPC framework (rpc_param_server_tutorial). I really appreciate the example and its detailed, helpful explanations, and as far as I can tell it supports multiple trainers accessing the same parameter server. I believe the code below ensures that the trainers can only ever create one parameter server:
# The global parameter server instance.
param_server = None
# A lock to ensure we only have one parameter server.
global_lock = Lock()

def get_parameter_server(num_gpus=0):
    """
    Returns a singleton parameter server to all trainer processes
    """
    global param_server
    # Ensure that we get only one handle to the ParameterServer.
    with global_lock:
        if not param_server:
            # construct it once
            param_server = ParameterServer(num_gpus=num_gpus)
        return param_server

def run_parameter_server(rank, world_size):
    # The parameter server just acts as a host for the model and responds to
    # requests from trainers.
    # rpc.shutdown() will wait for all workers to complete by default, which
    # in this case means that the parameter server will wait for all trainers
    # to complete, and then exit.
    print("PS master initializing RPC")
    rpc.init_rpc(name="parameter_server", rank=rank, world_size=world_size)
    print("RPC initialized! Running parameter server...")
    rpc.shutdown()
    print("RPC shutdown on parameter server.")
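To check my understanding of the singleton part, here is a minimal standalone sketch of the same lock-guarded pattern. It uses plain Python threads in place of RPC callers, and `DummyServer`, `get_server`, and `collect_handles` are hypothetical names I made up for illustration; none of them come from the tutorial.

```python
import threading


class DummyServer:
    """Hypothetical stand-in for the tutorial's ParameterServer."""

    instances_created = 0  # counts how many times the constructor ran

    def __init__(self):
        DummyServer.instances_created += 1


_server = None
_server_lock = threading.Lock()


def get_server():
    """Return the single shared DummyServer, creating it on first use."""
    global _server
    # The check and the construction happen under the same lock, so two
    # concurrent callers cannot both see None and both construct a server.
    with _server_lock:
        if _server is None:
            _server = DummyServer()
        return _server


def collect_handles(num_callers=8):
    """Call get_server() from several threads and return every handle."""
    handles = []

    def worker():
        handles.append(get_server())

    threads = [threading.Thread(target=worker) for _ in range(num_callers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return handles


if __name__ == "__main__":
    handles = collect_handles()
    print("all handles identical:", all(h is handles[0] for h in handles))
    print("instances created:", DummyServer.instances_created)
```

If this matches what the tutorial does, the lock only protects the creation of the server object. It does not by itself serialize what the trainers later do with that object.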
However, I am less sure about the distributed autograd forward and backward passes in the training loop below:
def run_training_loop(rank, num_gpus, train_loader, test_loader):
    ...
    for i, (data, target) in enumerate(train_loader):
        with dist_autograd.context() as cid:
            model_output = net(data)
            target = target.to(model_output.device)
            loss = F.nll_loss(model_output, target)
            if i % 5 == 0:
                print(f"Rank {rank} training batch {i} loss {loss.item()}")
            dist_autograd.backward(cid, [loss])
            # Ensure that dist autograd ran successfully and gradients were
            # returned.
            assert remote_method(
                ParameterServer.get_dist_gradients,
                net.param_server_rref,
                cid) != {}
            opt.step(cid)

    print("Training complete!")
    print("Getting accuracy....")
    get_accuracy(test_loader, net)
How can I make sure there are no concurrency issues here? For example, with two trainers, one may be running its forward pass on the parameter server while the other is running its backward pass or optimizer step. How do I make sure the two processes do not conflict with each other?
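To make the scenario concrete, here is a toy sketch of the kind of guard I have in mind. It uses plain Python threads rather than RPC, and `ToyParameterServer`, `forward`, `apply_gradient`, and `run_trainers` are hypothetical names of my own. It only shows one lock serializing reads and updates of a shared weight; I am not claiming this is how the tutorial, `torch.distributed.rpc`, or distributed autograd handle it.

```python
import threading


class ToyParameterServer:
    """Hypothetical toy server that holds a single scalar weight."""

    def __init__(self):
        self.weight = 0.0
        self._lock = threading.Lock()

    def forward(self, x):
        # Read the weight under the lock so a concurrent update cannot
        # interleave with this read.
        with self._lock:
            return self.weight * x

    def apply_gradient(self, grad, lr=1.0):
        # Update the weight under the same lock.
        with self._lock:
            self.weight -= lr * grad


def run_trainers(num_trainers=2, steps=1000):
    """Run several trainer threads against one shared toy server."""
    server = ToyParameterServer()

    def trainer():
        for _ in range(steps):
            server.forward(1.0)          # "forward pass"
            server.apply_gradient(-1.0)  # constant gradient: weight += 1

    threads = [threading.Thread(target=trainer) for _ in range(num_trainers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return server.weight


if __name__ == "__main__":
    # With the lock, every update is applied exactly once.
    print(run_trainers(num_trainers=2, steps=1000))
```

Note that this only makes each individual call atomic. One trainer's forward pass can still see weights that another trainer updated a moment earlier, which is really the part I am asking about.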
Thanks,