I have a use case where I need to calculate the gradients of a NN's parameters on a per-sample basis, i.e., with a batch size of 1, reading `params.grad` after `loss.backward()`. I want to speed this up with `DistributedDataParallel`, so I have each process take in a batch of 1 and perform the calculation.
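For concreteness, here's a minimal sketch of the setup I mean (the model, data, and loss below are just placeholders for my actual ones):

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# One process per GPU (launched with e.g. torchrun); each rank handles one sample.
dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

model = torch.nn.Linear(10, 1).cuda()  # placeholder for my trained model
ddp_model = DDP(model, device_ids=[rank])

x = torch.randn(1, 10).cuda()  # batch of size 1, a different sample on each rank
y = torch.randn(1, 1).cuda()

loss = torch.nn.functional.mse_loss(ddp_model(x), y)
loss.backward()

# This is where I want to read off each rank's per-sample gradients:
for p in ddp_model.parameters():
    print(rank, p.grad)
```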
The problem is that `loss.backward()` with DDP automatically reduces the gradients across processes, so the gradients I get from `params.grad` are the same on every process. What I need are the individual gradients, without reduction (one gradient per batch of size 1, rather than the average over the N x 1 data points across the N GPUs).
Is there a way to stop DDP from automatically reducing gradients in the backward pass? Note that I am not actually optimizing the loss; I am just computing first-order derivatives of a trained model (to calculate the Fisher Information Matrix, if anyone is wondering).
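To make the question concrete: would wrapping the backward pass in DDP's `no_sync()` context manager, as in the sketch below (continuing the placeholder example from above), give each rank its own unreduced gradient, or is there a better approach for this?

```python
# Continuing the placeholder example above.
ddp_model.zero_grad()

# no_sync() is documented to skip DDP's gradient all-reduce for backward
# passes run inside the context, so each rank should keep its own gradient.
with ddp_model.no_sync():
    loss = torch.nn.functional.mse_loss(ddp_model(x), y)
    loss.backward()

# Hoping p.grad now differs across ranks (the true per-sample gradients).
for p in ddp_model.parameters():
    print(rank, p.grad)
```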
Thanks!