Hi. I have a network f that takes a batch of (image, label) input pairs (x, y) and computes a cross-entropy loss. I understand that if I use nn.CrossEntropyLoss() (with its default reduction='mean'), the loss is averaged over every sample in the batch. If I access any weight (let's say module.conv1.weight), I can read module.conv1.weight.grad to get the gradient, with respect to module.conv1.weight, of this batch-averaged cross-entropy loss.
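A minimal sketch of the point above, assuming PyTorch and a hypothetical two-layer `TinyNet` standing in for f (the class, shapes, and random data are illustrative only). Because the loss is a mean over the batch, `conv1.weight.grad` after one backward pass should equal the average of the gradients obtained from each sample on its own:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the network f: one conv layer named conv1 plus a classifier.
class TinyNet(nn.Module):
    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.fc = nn.Linear(4 * 8 * 8, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        return self.fc(x.flatten(1))

model = TinyNet()
x = torch.randn(5, 1, 8, 8)        # batch of B = 5 fake 8x8 grayscale images
y = torch.randint(0, 3, (5,))      # fake labels

# Default reduction='mean': one backward pass gives the batch-averaged gradient.
model.zero_grad()
nn.CrossEntropyLoss()(model(x), y).backward()
mean_grad = model.conv1.weight.grad.clone()

# Same quantity built by hand from per-sample gradients (one backward per sample).
per_sample = []
for xi, yi in zip(x, y):
    model.zero_grad()
    nn.CrossEntropyLoss()(model(xi.unsqueeze(0)), yi.unsqueeze(0)).backward()
    per_sample.append(model.conv1.weight.grad.clone())
per_sample = torch.stack(per_sample)   # shape (B, *conv1.weight.shape)

print(torch.allclose(mean_grad, per_sample.mean(dim=0), atol=1e-6))
```

This equivalence only holds when samples do not interact in the forward pass; BatchNorm in training mode, for example, couples them.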
My question is the following: for a given parameter, I would like to compute the mean over the batch of the squared per-sample gradient of the loss, instead of the mean of the gradient itself.
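To pin the two quantities down (notation introduced here for illustration: ℓ_n is the cross-entropy of the n-th pair, B the batch size, w_i one entry of module.conv1.weight), the requested average exceeds the square of the mean gradient by the per-entry batch variance, which is why squaring .grad alone cannot recover it:

```latex
\ell_n(w) = \mathrm{CE}\big(f(x_n; w),\, y_n\big), \qquad n = 1, \dots, B

\bar g_i = \frac{1}{B} \sum_{n=1}^{B} \frac{\partial \ell_n}{\partial w_i}
  \qquad \text{(the entry stored in \texttt{module.conv1.weight.grad})}

g_i = \frac{1}{B} \sum_{n=1}^{B} \left( \frac{\partial \ell_n}{\partial w_i} \right)^{2}
    = \bar g_i^{\,2} + \frac{1}{B} \sum_{n=1}^{B} \left( \frac{\partial \ell_n}{\partial w_i} - \bar g_i \right)^{2}
```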
In other words, module.conv1.weight.grad contains the mean over the batch of the per-sample gradients of the loss. Instead, I would like to extract a tensor g of the same shape whose entry g_i is the mean over the batch of the squared per-sample gradients, (∂ℓ_n/∂w_i)^2, where ℓ_n is the loss of the n-th sample. This is not the same as squaring module.conv1.weight.grad element-wise, which gives the square of the mean.
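One way to get g is to compute per-sample gradients explicitly. The sketch below assumes PyTorch 2.0 or later, where `torch.func` provides `functional_call`, `grad` and `vmap`, and uses a hypothetical `TinyNet` in place of f. It vectorizes the per-sample backward pass with `vmap`, squares the per-sample gradients of `conv1.weight`, and averages them over the batch; a plain loop over the batch is included as a slower cross-check:

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Hypothetical stand-in for f (no BatchNorm/dropout, so each sample's loss
# depends only on that sample).
class TinyNet(nn.Module):
    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.fc = nn.Linear(4 * 8 * 8, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        return self.fc(x.flatten(1))

model = TinyNet()
x = torch.randn(5, 1, 8, 8)        # batch of B = 5 fake images
y = torch.randint(0, 3, (5,))      # fake labels
loss_fn = nn.CrossEntropyLoss()

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

def sample_loss(params, buffers, xi, yi):
    # Under vmap, xi and yi arrive without the batch dimension; add it back.
    logits = functional_call(model, (params, buffers), (xi.unsqueeze(0),))
    return loss_fn(logits, yi.unsqueeze(0))

# grad() differentiates w.r.t. the first argument (params);
# vmap maps over the batch dimension of x and y only.
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, None, 0, 0))(
    params, buffers, x, y
)

w_grads = per_sample_grads["conv1.weight"]   # shape (B, 4, 1, 3, 3)
g = w_grads.pow(2).mean(dim=0)               # mean of squared per-sample gradients
mean_grad = w_grads.mean(dim=0)              # what conv1.weight.grad would hold

# Cross-check with a plain Python loop over the batch.
loop_grads = []
for xi, yi in zip(x, y):
    model.zero_grad()
    loss_fn(model(xi.unsqueeze(0)), yi.unsqueeze(0)).backward()
    loop_grads.append(model.conv1.weight.grad.clone())
g_loop = torch.stack(loop_grads).pow(2).mean(dim=0)
```

The vmap version materializes one gradient per sample for every parameter, so memory grows linearly with the batch size; for large models it may be worth restricting the computation to the parameters of interest. Since each per-sample loss is a mean over a batch of one, it equals that sample's cross-entropy, so g matches the definition in the question.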
Is it possible to do that?