When doing gradient accumulation, we typically reduce the loss terms via a sum or mean:
for _ in range(N):
    loss_i = my_loss(x, y) / batch_size
    loss_i.backward()  # gradients accumulate across iterations
optimizer.step()
optimizer.zero_grad()
Here we have a mean over the loss terms, so the “final” loss is implicitly
(1/batch_size) * sum(loss_i)
Similarly, a sum reduction would just require removing the division by batch_size.
What I’m wondering is: is it possible to use a less trivial reduction with gradient accumulation? I would like to have, say, a final loss of
sqrt(sum(loss_i ** 2))
i.e. an L2 reduction of the loss terms.
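For context, here is what that reduction looks like if I give up the memory benefit of accumulation and materialize every loss term before reducing. The model, my_loss, and data here are just toy stand-ins for my actual setup:

```python
import torch

# Toy stand-ins (assumptions, not my real setup): a linear model
# and a mean squared-error loss.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def my_loss(x, y):
    return ((model(x) - y) ** 2).mean()

N = 8
optimizer.zero_grad()

# Materialize all per-step losses (this is exactly what accumulation
# is supposed to avoid).
losses = []
for _ in range(N):
    x, y = torch.randn(2, 4), torch.randn(2, 1)
    losses.append(my_loss(x, y))

# L2 reduction over the loss terms: sqrt(sum(loss_i ** 2))
final_loss = torch.stack(losses).norm(p=2)
final_loss.backward()
optimizer.step()
```

The question is whether the same gradients can be obtained while still calling backward() once per step, without keeping the whole graph for all N terms alive.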
Thanks in advance.