I have a use case where I need to calculate the gradients of the parameters of a NN on a per-sample basis, i.e., with a batch size of 1, and I want `params.grad` after `loss.backward()`. I want to speed this up with `DistributedDataParallel`, so I have each process take in a batch of 1 and perform the calculation.

The problem is that `loss.backward()` with DDP automatically reduces the gradients among processes, so the gradients I get from `params.grad` are the same in every process. What I need are the individual gradients without reduction (each from its batch of size 1, rather than the average over the N × 1 data points across the N GPUs).
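For context, here is a minimal sketch of the flow I mean. It uses a single process with `world_size=1` and the `gloo` backend just so it runs standalone; in my real setup each of the N processes would get its own batch of size 1, and after `backward()` every process would see the same averaged `params.grad` instead of its own local one. The model and addresses here are just placeholders:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Placeholder rendezvous settings for a standalone single-process run.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 1)   # stand-in for the trained NN
ddp_model = DDP(model)

x = torch.randn(1, 4)           # batch size 1 in this process
loss = ddp_model(x).sum()
loss.backward()                 # DDP all-reduces (averages) grads here

grad = model.weight.grad        # identical across processes after the reduce
print(grad.shape)

dist.destroy_process_group()
```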

Is there a way to stop DDP from automatically reducing gradients during the backward pass? Note that I am not actually optimizing the loss or anything; I am just calculating the first-order derivatives for a trained model (to compute the Fisher Information Matrix, if anyone is wondering).

Thanks!