Does DistributedOptimizer support zero_grad and lr_scheduling?

I'm trying to use distributed RPC to implement a parameter server, and I would like to use the zero_grad and lr scheduling features in the trainer. But it seems that DistributedOptimizer does not support them. Is there any workaround?

Hey @Kunlin_Yang

zero_grad

The reason DistributedOptimizer does not provide a zero_grad API is that the gradients of each backward pass are stored in their own dedicated autograd context (instead of in param.grad), and that context is cleared when exiting the with dist_autograd.context() as context_id: scope. So zero_grad is not needed here. We can certainly add it if necessary, and it won't be too hard to implement in application code either. E.g.,

import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

def zero_grad(pr, context_id):
    # zero the gradient stored for this parameter in the given autograd context
    dist_autograd.get_gradients(context_id)[pr.local_value()].zero_()


with dist_autograd.context() as context_id:
    # omitting forward-backward-optstep here
    # param_rrefs is the list of parameter RRefs passed to DistributedOptimizer
    futs = []
    for pr in param_rrefs:
        futs.append(rpc.rpc_async(pr.owner(), zero_grad, args=(pr, context_id)))
    [fut.wait() for fut in futs]
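
To illustrate the per-context point more concretely, here is a minimal sketch of a trainer loop (forward(), dist_opt, and num_iterations are placeholders, not names from your code): since each iteration opens a fresh dist_autograd context, the gradients of the previous pass are released when the with block exits, so there is nothing left to zero.

# minimal sketch; forward(), dist_opt, and num_iterations are placeholders
for _ in range(num_iterations):
    with dist_autograd.context() as context_id:
        loss = forward()                            # forward pass on the trainer
        dist_autograd.backward(context_id, [loss])  # grads stored under context_id
        dist_opt.step(context_id)                   # DistributedOptimizer step
    # exiting the context releases the gradients of this pass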

Or is there a different reason you would like to use the zero_grad API?

lr_scheduling

Currently, there is no distributed implementation of lr scheduling yet. I created an issue to track this: https://github.com/pytorch/pytorch/issues/38548

For now, you will need to do that using the raw RPC API. You can access the RRefs of the remote optimizers through the remote_optimizers attribute of the DistributedOptimizer instance, so it can be something like:

def create_lr_scheduler(opt_rref):
    # create and return an lr_scheduler wrapping the local optimizer
    ...

def lrs_step(lrs_rref):
    lrs_rref.local_value().step()

opt = DistributedOptimizer(...)
lrs_rrefs = []
for opt_rref in opt.remote_optimizers:
    lrs_rrefs.append(rpc.remote(opt_rref.owner(), create_lr_scheduler, args=(opt_rref,)))

with dist_autograd.context() as context_id:
    # omitting forward-backward-optstep here
    futs = []
    for lrs_rref in lrs_rrefs:
        futs.append(rpc.rpc_async(lrs_rref.owner(), lrs_step, args=(lrs_rref,)))
    [fut.wait() for fut in futs]

If you are using the master branch, the above code can be simplified with the RRef.rpc_async() API.
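
A minimal sketch of that simplification, assuming the same lrs_rrefs list as above: RRef.rpc_async() returns a proxy that invokes methods directly on the object the RRef points to, so the lrs_step helper is no longer needed.

with dist_autograd.context() as context_id:
    # omitting forward-backward-optstep here
    # call step() on each remote lr_scheduler via the RRef proxy
    futs = [lrs_rref.rpc_async().step() for lrs_rref in lrs_rrefs]
    [fut.wait() for fut in futs]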
