How could I update attributes of an RRef?

I am studying the parameter server tutorial. I would like to update the gradients of the parameters during training (i.e., between dist_autograd.backward(cid, [loss]) and opt.step(cid)). How can I achieve this?

Hey @Kunlin_Yang, say param_rref is the RRef of the corresponding parameter. You can do something like the following:

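import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

def update_grad(param_rref, cid):
    # runs on the owner of param_rref, so local_value() returns the parameter tensor itself;
    # get_gradients(cid) maps each parameter to its gradient in this autograd context
    grad = dist_autograd.get_gradients(cid)[param_rref.local_value()]
    # do inplace update on the grad; as an illustration, add Gaussian noise (0.01 is an arbitrary scale)
    with torch.no_grad():
        grad.add_(torch.randn_like(grad) * 0.01)

# following is a toy training step (model, inputs, and the DistributedOptimizer opt
# are assumed to be set up as in the tutorial)
with dist_autograd.context() as cid:
    loss = model(inputs).sum()
    dist_autograd.backward(cid, [loss])
    rpc.rpc_sync(param_rref.owner(), update_grad, args=(param_rref, cid))
    # opt.step(cid) applies the gradients stored in this context, including the in-place change
    opt.step(cid)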

Out of curiosity, why do you need to update the grad manually instead of using the grad produced by the backward pass?

Thanks! That’s very helpful. I am doing federated learning research, so I need to add noise to the gradients on each local worker.
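For the noise itself, one common choice is to clip the gradient's L2 norm and then add Gaussian noise, loosely in the style of DP-SGD (which clips per-example gradients; this simplified sketch clips the whole tensor). The helper add_gaussian_noise_ and its clip_norm/noise_std parameters are hypothetical, and choosing their values depends on your own privacy analysis. It could be called on grad in the `# do inplace update on the grad` step of update_grad.

```python
import torch


def add_gaussian_noise_(grad, clip_norm, noise_std, generator=None):
    """Hypothetical helper: clip grad's L2 norm to clip_norm, then add
    N(0, noise_std^2) noise to every element, modifying grad in place."""
    with torch.no_grad():
        norm = grad.norm(p=2)
        if norm > clip_norm:
            # rescale so the L2 norm equals clip_norm
            grad.mul_(clip_norm / norm)
        noise = torch.randn(
            grad.shape, generator=generator, dtype=grad.dtype, device=grad.device
        )
        grad.add_(noise, alpha=noise_std)
    return grad
```

Passing a seeded torch.Generator makes the noise reproducible, which is handy when comparing runs.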