I am trying to use distributed RPC to implement a parameter server. I would like to use the `zero_grad` and lr scheduling features in the trainer, but it seems that `DistributedOptimizer` does not support them. Is there any workaround?
Hey @Kunlin_Yang
zero_grad
The reason `DistributedOptimizer` does not provide a `zero_grad` API is that the gradients of each backward pass are stored in their own dedicated context (instead of in `param.grad`), and that context is cleared when exiting the `with dist_autograd.context() as context_id:` scope. So `zero_grad` is not needed here. We can certainly add it if necessary, and it won't be too hard to implement in application code either. E.g.,
```python
def zero_grad(pr, context_id):
    # Zero the gradient stored for this parameter in the given
    # distributed autograd context on the parameter's owner.
    dist_autograd.get_gradients(context_id)[pr.local_value()].zero_()

with dist_autograd.context() as context_id:
    # omitting forward-backward-optstep here
    futs = []
    for pr in param_rrefs:
        futs.append(rpc.rpc_async(pr.owner(), zero_grad, args=(pr, context_id)))
    [fut.wait() for fut in futs]
```
Or is there a different reason you would like to use the `zero_grad` API?
lr_scheduling
Currently, there is no distributed implementation of lr scheduling yet. I created an issue to track this: https://github.com/pytorch/pytorch/issues/38548
For now, you will need to do that using the raw RPC API. You can access the RRefs of the remote optimizers through `DistributedOptimizer().remote_optimizers`, so it can be something like:
```python
def create_lr_scheduler(opt_rref):
    # create and return lr_scheduler
    ...

def lrs_step(lrs_rref):
    # Step the lr_scheduler held locally by this RRef.
    lrs_rref.local_value().step()

opt = DistributedOptimizer(...)

# Create one lr_scheduler on each optimizer's owner.
lrs_rrefs = []
for opt_rref in opt.remote_optimizers:
    lrs_rrefs.append(rpc.remote(opt_rref.owner(), create_lr_scheduler, args=(opt_rref,)))

with dist_autograd.context() as context_id:
    # omitting forward-backward-optstep here
    futs = []
    for lrs_rref in lrs_rrefs:
        futs.append(rpc.rpc_async(lrs_rref.owner(), lrs_step, args=(lrs_rref,)))
    [fut.wait() for fut in futs]
```
If you are using the master branch, the above code can be simplified with the `RRef.rpc_async()` API.