Get per-GPU gradients for a DDP model

Hi, for models wrapped with DDP, I suspect gradients across all GPUs are accumulated automatically after the backward() call and before the optimizer.step() call.

How do I get the per-device gradients before they are synchronized across devices?

Say there are N GPUs, and each of the N copies of the model runs

model_out = ddp_model(model_input)
loss = loss_fn(model_out)

before calling


is there a way I can get the per-GPU gradients via something like


before the gradients are synchronized? Thanks.

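A DDP communication hook should be useful for implementing what you need. A hook registered with register_comm_hook receives each gradient bucket on every rank before it is all-reduced. See: DistributedDataParallel — PyTorch 2.1 documentation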
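To make the suggestion concrete, here is a minimal sketch of such a hook. It copies each rank's local gradient bucket before synchronization and then hands off to the stock all-reduce hook so training behaves as usual. The names capture_local_grads_hook, run_demo, and the "local_grads" state key are hypothetical, and the single-process gloo setup on CPU is only an assumption so the snippet can run without multiple GPUs; a real job would normally be launched with torchrun and one process per GPU.

```python
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
from torch.nn.parallel import DistributedDataParallel as DDP


def capture_local_grads_hook(state, bucket):
    # `state` is whatever object was passed to register_comm_hook;
    # here it is a plain dict (an assumption made for this sketch).
    # bucket.buffer() is a flat tensor holding this rank's gradients
    # for the bucket, not yet all-reduced across ranks.
    state["local_grads"].append(bucket.buffer().detach().clone())
    # Delegate to the built-in hook, which averages across ranks.
    # Passing None selects the default process group.
    return default_hooks.allreduce_hook(None, bucket)


def run_demo():
    # Single-process group on CPU (assumption, for illustration only).
    with tempfile.TemporaryDirectory() as tmp:
        dist.init_process_group(
            "gloo", init_method=f"file://{tmp}/store", rank=0, world_size=1
        )
        try:
            torch.manual_seed(0)
            ddp_model = DDP(nn.Linear(4, 2))
            state = {"local_grads": []}
            ddp_model.register_comm_hook(state, capture_local_grads_hook)

            loss = ddp_model(torch.randn(8, 4)).pow(2).mean()
            loss.backward()  # the hook fires once per gradient bucket

            synced = [p.grad.clone() for p in ddp_model.parameters()]
            return state["local_grads"], synced
        finally:
            dist.destroy_process_group()
```

Calling run_demo() returns the captured per-rank buckets alongside the synchronized param.grad values. The captured tensors are flat buckets rather than per-parameter gradients; bucket.gradients() and bucket.parameters() can be used inside the hook if a per-parameter view is needed.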