Compute squared gradient

Hi. I have a network f that takes a batch of input pairs of images and labels (x, y) and computes a cross-entropy loss. I understand that if I use nn.CrossEntropyLoss(), the loss is averaged over every element of the batch. If I access any weight (say, module.conv1.weight), I can obtain the gradient of the batch-averaged cross-entropy loss with respect to that parameter through module.conv1.weight.grad.
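For concreteness, here is a minimal sketch of that setup (the toy model, shapes, and all names other than module.conv1 are made up for illustration, not my actual network f):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)   # the weight discussed below
        self.fc = nn.Linear(8 * 30 * 30, 10)

    def forward(self, x):
        h = torch.relu(self.conv1(x)).flatten(1)
        return self.fc(h)

module = Net()
criterion = nn.CrossEntropyLoss()        # averages the loss over the batch by default

x = torch.randn(16, 3, 32, 32)           # a batch of 16 images
y = torch.randint(0, 10, (16,))          # a batch of 16 labels

loss = criterion(module(x), y)           # mean cross entropy over the batch
loss.backward()
print(module.conv1.weight.grad.shape)    # gradient of the averaged loss, shape (8, 3, 3, 3)
```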

My question is the following: instead of the mean of the gradient itself, I would like to compute the mean over the batch of the squared per-example gradients of the loss with respect to the parameter.

In other words, instead of module.conv1.weight.grad, which contains the mean over the batch of the gradients of the per-example losses with respect to the weight, I would like to extract a tensor g of the same shape whose entries are the means of the squared per-example gradients, i.e. g = mean_n (grad_n)^2, where grad_n is the gradient of the n-th example's loss with respect to module.conv1.weight.

Is it possible to do that?

Thanks

Hi,

PyTorch computes the gradients with respect to the loss you computed. So if your loss is the average over the batch, the gradients will be averaged as well.

If you don’t average the loss, you will get as many losses as the batch size, and you will need to run as many backward passes as the batch size to get all of these per-example gradients, as sketched below.
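A minimal sketch of that naive loop, reusing module, x, and y from the sketch in the question above (reduction='none' keeps one loss per example; this is slow but matches the description exactly):

```python
criterion = nn.CrossEntropyLoss(reduction='none')
losses = criterion(module(x), y)              # shape: (batch_size,), one loss per example

sq_grad = torch.zeros_like(module.conv1.weight)
for i in range(losses.shape[0]):
    module.zero_grad()
    losses[i].backward(retain_graph=True)     # keep the graph for the remaining examples
    sq_grad += module.conv1.weight.grad ** 2

mean_sq_grad = sq_grad / losses.shape[0]      # mean over the batch of squared per-example gradients
```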

There’s a trick from Ian Goodfellow to do this in a single pass: https://arxiv.org/abs/1510.01799

This post gives a PyTorch recipe: Efficient Per-Example Gradient Computations
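To give the flavor of the trick, here is a minimal sketch for a single nn.Linear layer (a toy example of the idea, not the full recipe from the linked post; conv layers need the extra work described there). Hooks record the layer's input and its output gradient, and the per-example weight gradients are rebuilt as outer products:

```python
import torch
import torch.nn as nn

lin = nn.Linear(5, 3)
x = torch.randn(16, 5)
y = torch.randint(0, 3, (16,))

saved = {}
lin.register_forward_hook(lambda m, inp, out: saved.update(inp=inp[0].detach()))
lin.register_full_backward_hook(lambda m, gin, gout: saved.update(gout=gout[0].detach()))

# Use reduction='sum' so the hooked output gradients are not divided by the batch size.
loss = nn.functional.cross_entropy(lin(x), y, reduction='sum')
loss.backward()

# Per-example weight gradient: outer product of output grad and input, one per sample.
per_example_gw = torch.einsum('no,ni->noi', saved['gout'], saved['inp'])  # (16, 3, 5)
mean_sq_gw = (per_example_gw ** 2).mean(dim=0)   # mean of squared per-example gradients

# Sanity check: the per-example gradients sum back to the usual .grad.
assert torch.allclose(per_example_gw.sum(dim=0), lin.weight.grad, atol=1e-5)
```

The key point is that one backward pass already produces the per-example output gradients; the per-example parameter gradients are then recovered with cheap batched outer products instead of one backward per example.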


Thanks a lot for this answer!