I want to manually update the gradients of a model from within the comm_hook. Inside the hook, we can access the bucket's parameters and gradients. However, I don't know how to map the GradBucket's parameters and gradients back to the corresponding parameters of the model.
import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel as DDP

def comm_hook(state: ModuleNet, bucket: torch.distributed.GradBucket) -> torch.futures.Future[torch.Tensor]:
    gradient_tensors = bucket.gradients()
    model_params = bucket.parameters()
    # Update ModuleNet model gradients from the GradBucket.
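From the GradBucket docs, bucket.gradients() and bucket.parameters() appear to line up index-for-index, and the parameters look like the very tensor objects the model owns, so an identity-keyed dict built from model.named_parameters() should recover each parameter's name inside the hook. A minimal sketch of that mapping, assuming the identity holds (param_to_name and the 0.5 scaling are my own placeholders, not DDP API):

# Built once, after constructing the model:
param_to_name = {p: name for name, p in model.named_parameters()}

# Inside comm_hook, before returning the future:
for param, grad in zip(bucket.parameters(), bucket.gradients()):
    name = param_to_name.get(param, "<unknown>")  # identity-based lookup
    # `grad` seems to be a view into bucket.buffer(), so in-place edits here
    # are what DDP eventually copies back into param.grad after the hook.
    grad.mul_(0.5)  # hypothetical manual gradient update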
# Main Logic
model = ModuleNet()
ddp_model = DDP(model)
# Pass the model itself as the hook's `state` argument.
ddp_model.register_comm_hook(model, comm_hook)
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)  # e.g.
output = ddp_model(input)
loss = ...
loss.backward()  # the registered comm_hook fires once per bucket here
optimizer.step()
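For completeness, here is a self-contained sketch of the whole flow as I understand it, using a single-process gloo group just so it runs on one CPU; ModuleNet's single linear layer, the learning rate, and the 0.5 scaling are placeholders of mine. Note the hook has to return a torch.futures.Future over the flat bucket buffer, so it ends with the usual async all_reduce pattern:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

class ModuleNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(10, 1)  # placeholder architecture

    def forward(self, x):
        return self.fc(x)

def comm_hook(state, bucket):
    param_to_name = state  # the dict registered as hook state below
    for param, grad in zip(bucket.parameters(), bucket.gradients()):
        print(f"updating {param_to_name.get(param, '<unknown>')}")
        grad.mul_(0.5)  # hypothetical manual update, visible in the buffer
    # Average the (already modified) flat buffer across ranks and hand DDP
    # back a Future, mirroring the built-in allreduce hook.
    buf = bucket.buffer()
    buf.div_(dist.get_world_size())
    work = dist.all_reduce(buf, async_op=True)
    return work.get_future().then(lambda fut: fut.value()[0])

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    model = ModuleNet()
    ddp_model = DDP(model)
    param_to_name = {p: name for name, p in model.named_parameters()}
    ddp_model.register_comm_hook(param_to_name, comm_hook)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
    output = ddp_model(torch.randn(4, 10))
    loss = output.sum()
    loss.backward()  # comm_hook runs once per gradient bucket here
    optimizer.step()
    dist.destroy_process_group()

With this wiring the scaled gradients are what optimizer.step() sees, which is the behavior I'm after. Does the identity-based mapping hold in general, or is there a supported way to get parameter names from the bucket?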