Can I get gradients of network for each sample in the batch?

Hi! I am implementing a simple algorithm to compute squared variance of gradients. For this, I need to compute gradients of the network for each sample in my batch. As far as I understand, loss.backward() computes gradients for each sample and then averages them. Is there any flag which outputs gradients for each sample before their averaging? Thanks!

As far as I know, you cannot get per-sample gradients unless you backpropagate each loss term separately before averaging.
You can get element-wise losses by passing reduce=False.
Check the docs of any loss function to see.

So, in short,

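criterion = nn.Loss(reduce=False)  # nn.Loss stands in for any concrete loss class

losses = criterion(pred, gt)
for loss in losses:
    # One backward pass per loss term. param.grad accumulates across calls,
    # so read it and zero it between iterations.
    loss.backward(retain_graph=True)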

Hi @JuanFMontesinos! Thanks for the reply! Yes, I am aware of the loss “reduction” option. However, I need the gradients of the network itself. At the moment, I am running a forward and a backward pass for each sample, which is very slow.

As I mentioned, you can backpropagate each element before reduction. At least this will speed up the forward pass.
The backward pass has to be done one by one.
I think the whole system is built around the idea that you access gradients through param.grad, which is why it would be a mess if you could filter out the gradients of any single element. Besides, the system is not aware of what a “loss” or a batch is. It just follows the chain rule.
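A minimal sketch of the approach described above: one forward pass over the batch, then one backward pass per sample. The model, the data, and the variable names are placeholders chosen for illustration, and torch.autograd.grad is used instead of reading param.grad so that nothing accumulates between iterations.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model and batch; both are placeholders for illustration.
model = nn.Linear(4, 3)
inputs = torch.randn(5, 4)
targets = torch.randint(0, 3, (5,))

# One forward pass for the whole batch, keeping one loss value per sample.
criterion = nn.CrossEntropyLoss(reduction='none')
losses = criterion(model(inputs), targets)  # shape: (5,)

params = list(model.parameters())
per_sample_grads = []
for loss in losses:
    # One backward pass per sample; retain_graph keeps the shared graph alive.
    grads = torch.autograd.grad(loss, params, retain_graph=True)
    per_sample_grads.append(torch.cat([g.reshape(-1) for g in grads]))

per_sample_grads = torch.stack(per_sample_grads)  # shape: (batch, n_params)

# Per-parameter variance of the gradient across the batch.
grad_variance = per_sample_grads.var(dim=0, unbiased=False)
```

Averaging per_sample_grads over the batch dimension should reproduce what a single backward pass on the mean loss gives, which is a handy sanity check.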

Just a minor note: reduce has been deprecated. Use reduction instead:

  • reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction. Default: 'mean'

You are probably looking for the option 'none'.
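To make the three reduction modes concrete, here is a small sketch using nn.CrossEntropyLoss. The tensors pred and gt are made-up values for illustration only.

```python
import torch
import torch.nn as nn

# Made-up logits for 3 samples and 2 classes, plus their class labels.
pred = torch.tensor([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
gt = torch.tensor([0, 0, 1])

per_sample = nn.CrossEntropyLoss(reduction='none')(pred, gt)  # one value per sample
mean_loss = nn.CrossEntropyLoss(reduction='mean')(pred, gt)   # scalar
sum_loss = nn.CrossEntropyLoss(reduction='sum')(pred, gt)     # scalar
```

With 'none' the result keeps the batch dimension, so each entry can be backpropagated on its own.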

If you use a simple NN, you can use tricks like the one mentioned here to reuse computations.
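The linked post is not reproduced in this thread, so the following is only an assumption about the kind of trick meant. For a single linear layer, the per-sample weight gradient is the outer product of the gradient with respect to that sample's output and that sample's input, so one backward pass is enough. The layer, data, and names below are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(4, 3)
x = torch.randn(5, 4)
targets = torch.randint(0, 3, (5,))

out = layer(x)
out.retain_grad()  # keep d(loss)/d(out), since out is not a leaf tensor
# Summing (not averaging) keeps each row of out.grad equal to
# that sample's own gradient, without a 1/batch factor.
loss = nn.CrossEntropyLoss(reduction='sum')(out, targets)
loss.backward()

# For y = x W^T + b, sample i contributes outer(dL_i/dy_i, x_i) to dL/dW
# and dL_i/dy_i to dL/db.
per_sample_weight_grads = torch.einsum('bo,bi->boi', out.grad, x)  # (batch, out, in)
per_sample_bias_grads = out.grad                                    # (batch, out)
```

This formula is specific to linear layers. A custom layer would need its own per-sample expression, or a fallback to one backward pass per sample.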

Hi @albanD,
I have a customised layer as well in my NN. So is there any way to still use the trick that you mentioned, please?