I have a use case where I need to calculate the gradients of the parameters of a NN on a per-sample basis, i.e., with a batch size of 1, via loss.backward(). I want to speed this up with DistributedDataParallel, so I have each process take in a batch of 1 and perform the calculation.
The problem is that loss.backward() under DDP automatically reduces the gradients across processes, so the gradients I read from param.grad are identical on every process. What I need are the individual un-reduced gradients (each from its own batch of 1, rather than averaged over the N x 1 data points across N GPUs).
Is there a way to prevent DDP from automatically reducing gradients during the backward pass? Note that I am not actually optimizing the loss; I am just computing first-order derivatives of a trained model (to calculate the Fisher Information Matrix, if anyone is wondering).
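For context, one thing I came across is DDP's `no_sync()` context manager, which is documented for gradient accumulation but appears to skip the all-reduce for backward passes run inside it, leaving each rank's local gradients untouched. A minimal sketch of what I have in mind (using a single-process gloo group and a toy `Linear` model purely for illustration):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process group for illustration; in practice each rank runs this
# with its own rank / world_size.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DDP(torch.nn.Linear(4, 1))
x = torch.randn(1, 4)          # per-sample batch of size 1
loss = model(x).sum()

# no_sync() suppresses the gradient all-reduce for backward passes
# executed inside the context, so each rank keeps its local gradients.
with model.no_sync():
    loss.backward()

# Local, un-reduced gradient for this rank's single sample.
per_sample_grad = model.module.weight.grad

dist.destroy_process_group()
```

Since I never call an optimizer step or a backward pass outside the context, the deferred synchronization should simply never happen, but I'm not sure if this is the intended way to use it.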