Can I find grad of all parameters before batch avg


I can access gradients of the parameters of my network like so:

my_param = next(model.parameters())
my_grad = my_param.grad

Those gradients are, if I understand correctly, averaged over all the inputs, i.e., if we have 10 images in our input batch, those gradients are averaged across those 10 input images. Would it be possible to get the gradients for each input separately? For example, if my layer has 6x3x5x5 parameters, grad normally has the shape 6x3x5x5. However, if my batch size is 4, would it be possible to obtain gradients of size 4x6x3x5x5, i.e., the gradients before averaging?

A poor solution might be reducing the batch size to 1 and calculating the gradients for each input separately, but I assume there is a proper way of doing it.
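For concreteness, the batch-size-1 approach could look something like this; the small Linear model, batch of 4, and CrossEntropyLoss here are made-up stand-ins for illustration, not the actual model from the question:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup (model, batch, and loss are assumptions for illustration).
model = nn.Linear(3, 2)
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(4, 3)           # batch of 4 samples
labels = torch.tensor([0, 1, 0, 1])

# Naive approach: run each sample through as its own batch of size 1,
# so .grad holds that single sample's gradient after each backward pass.
per_sample_grads = []
for x, y in zip(inputs, labels):
    model.zero_grad()
    loss = criterion(model(x.unsqueeze(0)), y.unsqueeze(0))
    loss.backward()
    per_sample_grads.append(model.weight.grad.clone())

# Stacking gives the batch-indexed gradient tensor asked about above,
# e.g. (4, 2, 3) for this layer instead of (2, 3).
grads = torch.stack(per_sample_grads)
print(grads.shape)
```

Averaging `grads` over its first dimension reproduces the usual full-batch gradient that a mean-reduced loss would give.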

I think it's all about the loss I am using, because the loss averages everything, not the grad… But then if I do:

criterion = nn.CrossEntropyLoss(reduction='none')  # reduce=False is the old, deprecated spelling
loss = criterion(outputs, labels)

then I get an error:

RuntimeError: grad can be implicitly created only for scalar outputs

How do I get the gradients for each input separately then?

Hi Matt!

This is, in essence, the right approach. Autograd performs "vector-Jacobian
products" to link together the steps in the chain rule for differentiation,
which is to say that a given backward pass computes the gradient of a single
scalar.

So if you want the gradient of the loss for each batch element separately,
you have to perform a separate backward pass for each batch element.

This error message is autograd telling you that it will only compute the
gradient for a single scalar at a time (or, more precisely, compute the
vector-Jacobian product for a single specific vector at a time).
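A minimal sketch of the per-element backward loop described above: keep the per-sample losses with reduction='none', then call backward on each scalar element separately (the toy Linear model and batch of 4 are assumptions for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup (model, batch, and loss are assumptions for illustration).
model = nn.Linear(3, 2)
criterion = nn.CrossEntropyLoss(reduction='none')  # keep per-sample losses
inputs = torch.randn(4, 3)
labels = torch.tensor([0, 1, 0, 1])

losses = criterion(model(inputs), labels)  # shape (4,): one scalar per sample

# One backward pass per batch element.  Each losses[i] is a scalar, which
# is what autograd requires; retain_graph=True keeps the graph alive so
# the next iteration can backpropagate through it again.
per_sample_grads = []
for i in range(losses.shape[0]):
    model.zero_grad()
    losses[i].backward(retain_graph=True)
    per_sample_grads.append(model.weight.grad.clone())

grads = torch.stack(per_sample_grads)  # shape (4, 2, 3): batch-indexed gradients
```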


K. Frank

I believe you are looking for per-sample gradients of the model.
Unfortunately, torch.autograd does not enable this behavior directly.
You will have to use functorch.

functorch provides JAX-like composable function transforms for PyTorch.
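For reference, functorch's transforms have since been merged into core PyTorch as torch.func, so in recent PyTorch versions a per-sample-gradient sketch along these lines should work without a separate install (the toy Linear model and batch here are assumptions for illustration):

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap  # functorch's successor in core PyTorch

torch.manual_seed(0)

# Hypothetical toy setup, for illustration only.
model = nn.Linear(3, 2)
inputs = torch.randn(4, 3)
labels = torch.tensor([0, 1, 0, 1])

# Extract the parameters so the model can be treated as a pure function.
params = {name: p.detach() for name, p in model.named_parameters()}

def loss_fn(params, x, y):
    # Loss of the model on a single sample, as a function of the parameters.
    logits = functional_call(model, params, (x.unsqueeze(0),))
    return nn.functional.cross_entropy(logits, y.unsqueeze(0))

# grad differentiates w.r.t. params (the first argument);
# vmap maps that gradient function over the batch dimension of x and y.
per_sample_grad_fn = vmap(grad(loss_fn), in_dims=(None, 0, 0))
grads = per_sample_grad_fn(params, inputs, labels)

# grads is a dict keyed like params; grads['weight'] has shape (4, 2, 3),
# i.e. one gradient per batch element, in a single vectorized pass.
print(grads['weight'].shape)
```

The vectorized pass avoids the Python-level loop over the batch, which is where the speedup over the naive approach comes from.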

I guess you will find this helpful; I am working on something similar myself.
You can also try using JAX for per-sample gradients; I found it to be faster than functorch.

Thanks a lot, guys. I will use the naive approach for now and switch to functorch or JAX if I run into speed issues.