How do I modify the aggregation function of a batch in backpropagation?

In batch gradient descent, the parameters are updated by simply averaging the gradients over the batch dimension, but in my particular situation I need to change the aggregation function (average → weighted sum, etc.).

How can I make this happen? As far as I can tell, it’s not possible with register_full_backward_hook.


After calling .backward(), I think you can iterate over the model parameters and collect their gradients. Then you can modify them and call optimizer.step() as usual. That's what comes to mind, at least.

And if you’re simply weighting training data instances, you can multiply the loss function rather than directly modifying the gradients.
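Something like this minimal sketch of the first idea (the model, optimizer, and the halving of the gradients are all placeholders, standing in for whatever aggregation you actually need):

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 1)                       # placeholder model
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    x, y = torch.randn(32, 10), torch.randn(32, 1)

    loss = nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    for p in model.parameters():                   # gradients are already aggregated here
        p.grad.mul_(0.5)                           # any in-place edit of p.grad works
    opt.step()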

Thanks for the reply, TinfoilHat0!

Is there a way to get the gradient for each element of the batch that way? As I understand it, weight.grad contains the values after the aggregation function has already been applied over the batch.

Also, I didn't explain it well enough: the aggregation function I want to use is more complex and cannot be achieved by simply multiplying the loss.

Yeah, I think you’re right. weight.grad contains what’s computed after batch gradients are averaged.

I think you can explicitly get the gradient with respect to a particular input by

torch.autograd.grad(out, x)[0]

where out is the model’s output and x is the particular input you’re interested in.
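As a hedged sketch of that call (the model and shapes are made up for illustration):

    import torch

    model = torch.nn.Linear(4, 1)                  # placeholder model
    x = torch.randn(1, 4, requires_grad=True)      # the input of interest
    out = model(x).sum()                           # grad() needs a scalar output
    grad_x = torch.autograd.grad(out, x)[0]        # gradient w.r.t. x, same shape as x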

Thank you, TinfoilHat0!

That method seems to be no different from using .backward() and weight.grad.
Of course, it would work with the batch size set to 1, but that would be far too inefficient.

What I want is to run the forward pass batch by batch and get the gradient of the weights for each element in the batch.

Yeah, I was thinking of the case of computing the gradient w.r.t. each input in the batch individually. Sadly, I can't think of an efficient way of doing this.


Averaging of losses is avoidable with the reduction='none' argument, followed by a manual reduction (weighted sum etc.; see the sketch below). At the level of individual parameters it is exotic and more complex. For example, if you have a linear layer x.matmul(W), the shapes are (batch, in) @ (in, out): W has no batch dimension, so the gradient summation is implicit in the backprop tensor formulas (another matmul in this case).
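A minimal sketch of the reduction='none' route; the model, data, and the weights w are placeholders, not anything from your setup:

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 1)
    x, y = torch.randn(32, 10), torch.randn(32, 1)
    w = torch.rand(32, 1)              # hypothetical per-sample weights

    per_sample = nn.functional.mse_loss(model(x), y, reduction='none')  # (32, 1)
    loss = (w * per_sample).sum()      # manual weighted-sum reduction
    loss.backward()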

If I can’t find a more efficient way, I’ll use that one.

Thanks, TinfoilHat0!

Thanks for the reply, googlebot!

So it looks like there is no way to get the gradient of W for each element of the batch directly; instead, we need to recover it from the input and the output gradient.
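For a linear layer, I imagine that recovery would look something like this sketch (the shapes and the loss are made up for illustration):

    import torch

    b, n_in, n_out = 32, 10, 4
    x = torch.randn(b, n_in)
    W = torch.randn(n_in, n_out, requires_grad=True)

    y = x.matmul(W)
    grads = []
    y.register_hook(grads.append)       # capture dL/dy during backward

    y.sum().backward()                  # placeholder loss

    grad_y = grads[0]                                            # (b, out)
    per_sample_grad_W = torch.einsum('bi,bo->bio', x, grad_y)    # (b, in, out)
    assert torch.allclose(per_sample_grad_W.sum(0), W.grad)      # sums back to W.grad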

If you have some intermediate layer

layer(x : Tensor[b,in], p : Tensor[*]) → y : Tensor[b,out]

you don't have to collect gradients w.r.t. p per batch element in order to rescale p.grad: you can rescale the output gradient (a 2-D view in this case) rowwise with scalar weights (see the sketch after the list below). But note:

  1. the loss reduction='none' + weighted-sum-of-losses approach has the same effect, applied to the whole network.
  2. I have doubts about mixing different sample weights in one backpropagation, but I'll assume you know what you're doing.
  3. this reweighting propagates to earlier layers, i.e. for the function above the gradient w.r.t. x is affected too. This can be counteracted with another reweighting (register_hook or a similar approach).
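Putting the rowwise rescaling and the counter-reweighting from point 3 together, a sketch might look like this (the layer, shapes, and weights w are all made up):

    import torch
    import torch.nn as nn

    layer = nn.Linear(10, 4)                        # placeholder intermediate layer
    x = torch.randn(32, 10, requires_grad=True)
    w = torch.rand(32)                              # hypothetical per-sample weights

    y = layer(x)
    y.register_hook(lambda g: g * w.view(-1, 1))    # rescale dL/dy rowwise; this
                                                    # reweights layer.weight.grad
    x.register_hook(lambda g: g / w.view(-1, 1))    # undo the scaling so earlier
                                                    # layers see unweighted grads

    y.sum().backward()                              # placeholder loss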

Thank you, googlebot!

Sorry, my explanation wasn't clear: the aggregation function I want to apply is actually more complex than weighting, so it can't be achieved by weighting the losses.