Calculate individual gradients for each example in a mini-batch?

Say I have the output y = model(x), where y.size() is [batch_dimensions, 1]. I would then like to calculate the individual gradient of each example in the output with respect to model.parameters(). Naively, I could do:

grads = [torch.autograd.grad(y[batch], model.parameters(), retain_graph=True) for batch in range(batch_dimensions)]

but this is serial and slow.
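For reference, here is a self-contained, runnable version of that naive loop. It is only a sketch: the toy `nn.Linear(3, 1)` model, the random input and `batch_dimensions = 4` are hypothetical stand-ins for the real model and data. Each iteration runs one backward pass, which is why it scales linearly with the batch size.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model and data, for illustration only.
batch_dimensions = 4
model = nn.Linear(3, 1)
x = torch.randn(batch_dimensions, 3)
y = model(x)  # shape [batch_dimensions, 1]

params = list(model.parameters())

# One backward pass per example: correct, but serial.
# y[batch] has a single element, so autograd can create its grad implicitly.
grads = [
    torch.autograd.grad(y[batch], params, retain_graph=True)
    for batch in range(batch_dimensions)
]

# grads[i] is a tuple with one tensor per parameter (weight, bias).
print(len(grads), [g.shape for g in grads[0]])
```

For a linear layer the per-example gradient with respect to the weight is just that example's input row, which makes the result easy to sanity-check.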

I know I can, however, do:

grads = torch.autograd.grad(y, model.parameters(), grad_outputs=torch.ones_like(y))

which is an awful lot faster. However, this sums the gradients over all examples in the batch (it is equivalent to differentiating y.sum()), which is not what I want. How can I compute the per-example gradients efficiently?
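The sketch below checks the summing behaviour and shows one vectorized alternative. It uses `torch.func` (`grad`, `vmap`, `functional_call`), which assumes PyTorch 2.0 or later. It is not necessarily the approach the answer below takes. The toy model and data are hypothetical. The idea is to write a function that returns the scalar output for a single example, differentiate it with `grad`, and let `vmap` map that over the batch dimension in one call.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap  # assumes PyTorch >= 2.0

torch.manual_seed(0)

# Hypothetical toy model and data, for illustration only.
batch_dimensions = 4
model = nn.Linear(3, 1)
x = torch.randn(batch_dimensions, 3)
y = model(x)  # shape [batch_dimensions, 1]

# grad_outputs of ones gives the gradient of y.sum(), i.e. summed over the batch.
summed = torch.autograd.grad(y, model.parameters(), grad_outputs=torch.ones_like(y))

# Parameters as a plain dict so functional_call can use them explicitly.
params = {k: v.detach() for k, v in model.named_parameters()}

def output_for_one(params, xi):
    # Run the model on one example (add a batch dim of size 1) and
    # return a scalar so that grad() can differentiate it.
    return functional_call(model, params, (xi.unsqueeze(0),)).squeeze()

# grad() differentiates w.r.t. params; vmap maps over dim 0 of x only.
per_example = vmap(grad(output_for_one), in_dims=(None, 0))(params, x)
# per_example["weight"]: [batch_dimensions, 1, 3], per_example["bias"]: [batch_dimensions, 1]

# Summing the per-example gradients recovers the grad_outputs=ones result.
for (name, _), g_sum in zip(model.named_parameters(), summed):
    assert torch.allclose(per_example[name].sum(dim=0), g_sum, atol=1e-6)
```

The leading dimension of each entry in `per_example` indexes the example, so `per_example["weight"][i]` plays the same role as `grads[i][0]` from the naive loop.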


Try this: