Prevent DDP from reducing gradients across processes in .backward()

I have a use case where I need to compute the gradients of a neural network's parameters per sample, i.e., with a batch size of 1, and read params.grad after loss.backward(). I want to speed this up with DistributedDataParallel, so each process takes a batch of 1 and performs the calculation.

The problem is that loss.backward() with DDP automatically reduces the gradients across processes, so the gradients I read from params.grad are identical in every process. What I need are the individual, unreduced gradients (one per sample, rather than the average over the N x 1 data points on N GPUs).

Is there a way to stop DDP from automatically reducing the gradients during the backward pass? Note that I am not optimizing the loss. I am only calculating the first-order derivatives for a trained model (to compute the Fisher Information Matrix, if anyone is wondering).


I think DistributedDataParallel.no_sync() should work for your use case:

A context manager to disable gradient synchronizations across DDP processes. Within this context, gradients will be accumulated on module variables, which will later be synchronized in the first forward-backward pass exiting the context.
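Here is a minimal sketch of how that might look for per-sample gradients. The toy nn.Linear model, the MSE loss, the random data, the gloo backend, and the helper names per_sample_grads and main are all assumptions for illustration, not anything from your setup. Since gradients accumulate inside no_sync(), the sketch clears them before every sample.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def per_sample_grads(ddp_model, samples, targets, loss_fn):
    """Return one flattened gradient vector per sample, with no all-reduce."""
    grads = []
    for x, y in zip(samples, targets):
        # Gradients accumulate under no_sync(), so clear them for each sample.
        ddp_model.zero_grad(set_to_none=True)
        # Both forward and backward run inside the context manager.
        with ddp_model.no_sync():
            loss = loss_fn(ddp_model(x.unsqueeze(0)), y.unsqueeze(0))
            loss.backward()
        flat = torch.cat([p.grad.reshape(-1) for p in ddp_model.parameters()])
        grads.append(flat.clone())
    return grads


def main():
    # torchrun sets these variables; the defaults allow a single-process run.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    torch.manual_seed(0)               # same toy weights on every rank
    model = DDP(nn.Linear(4, 2))       # stand-in for the trained model (assumption)

    torch.manual_seed(100 + rank)      # different toy data on each rank
    xs, ys = torch.randn(3, 4), torch.randn(3, 2)

    grads = per_sample_grads(model, xs, ys, nn.MSELoss())
    dist.destroy_process_group()
    return grads
```

To try it, call main() at the bottom of a script and launch it with torchrun, one process per GPU or CPU worker. The docs also warn that the forward pass should be inside the context manager, otherwise gradients will still be synchronized, which is why the model call sits inside the with block. Because nothing is being optimized here, running the plain unwrapped model in each process would probably work just as well, since without DDP there is no all-reduce to disable.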

Thank you, this looks like exactly what I need.