I am studying the parameter server tutorial. I would like to update the gradients of the parameters during training (i.e. between dist_autograd.backward(cid, [loss]) and opt.step(cid)). How could I achieve this?
Hey @Kunlin_Yang, say param_rref is the RRef of the corresponding parameter; you can do something like the following:
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
...

def update_grad(param_rref, cid):
    # fetch the gradient accumulated for this parameter in the given
    # distributed autograd context
    grad = dist_autograd.get_gradients(cid)[param_rref.local_value()]
    # do an in-place update on the grad
    with torch.no_grad():
        grad.add_(1)

# the following is a toy training step
with dist_autograd.context() as cid:
    loss = model(inputs).sum()
    dist_autograd.backward(cid, [loss])
    # run update_grad on the parameter's owner before the optimizer step
    rpc.rpc_sync(param_rref.owner(), update_grad, args=(param_rref, cid))
    opt.step(cid)
Curious, what is the reason to manually update the grad instead of using the grad given by the backward pass?
Thanks! That’s very helpful. I am doing federated learning research, so I need to add noise to the gradients on the local workers.
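For that use case, here is a minimal sketch that follows the same pattern, replacing the toy grad.add_(1) with zero-mean Gaussian noise. The noise_std value and the param_rrefs list of parameter RRefs are assumptions for illustration, not part of the tutorial:

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

def add_noise_to_grad(param_rref, cid, noise_std=0.01):
    # noise_std is a hypothetical hyperparameter chosen for illustration
    grad = dist_autograd.get_gradients(cid)[param_rref.local_value()]
    with torch.no_grad():
        # add zero-mean Gaussian noise to the accumulated gradient in place
        grad.add_(torch.randn_like(grad) * noise_std)

with dist_autograd.context() as cid:
    loss = model(inputs).sum()
    dist_autograd.backward(cid, [loss])
    # apply noise on each parameter's owner before the optimizer step;
    # param_rrefs is an assumed list of RRefs to the model's parameters
    for param_rref in param_rrefs:
        rpc.rpc_sync(param_rref.owner(), add_noise_to_grad, args=(param_rref, cid))
    opt.step(cid)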