I think the disadvantage of using the `sum` reduction would also be that the loss scale (and the gradients) depends on the batch size, so you would probably need to adjust the learning rate whenever you change the batch size. While this is certainly possible, a `mean` reduction would make this unnecessary.
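To make the scaling effect concrete, here is a minimal sketch (assuming PyTorch and a hypothetical toy linear model) that compares the gradient norm under each reduction as the batch size grows:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
# Hypothetical target: sum of the input features
target_fn = lambda x: x.sum(dim=1, keepdim=True)

for reduction in ("sum", "mean"):
    loss_fn = nn.MSELoss(reduction=reduction)
    for batch_size in (1, 10, 100):
        model.zero_grad()  # clear accumulated gradients before each backward
        x = torch.randn(batch_size, 4)
        loss = loss_fn(model(x), target_fn(x))
        loss.backward()
        grad_norm = model.weight.grad.norm().item()
        print(f"reduction={reduction:4s} batch_size={batch_size:3d} "
              f"grad norm: {grad_norm:.3f}")
```

With `sum` the gradient norm grows roughly linearly with the batch size, while with `mean` it stays on the same scale, so a fixed learning rate keeps working across batch sizes.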
But if your dataset has 10 elements, then with batch size 10, one epoch is a single optimizer step using the average gradient over your entire dataset, whereas with batch size 1 it is 10 optimizer steps, one from each element of your dataset. With `mean` reduction you would need to train for 10 epochs to take the same number of similarly sized optimizer steps; you are effectively training on a dataset that is 1 / batch size times as big.
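The step counting works out as in this small sketch (the 10-element dataset is the example from above; the step budget of 10 is an assumption for illustration):

```python
dataset_size = 10
target_steps = 10  # how many optimizer steps we want in total

for batch_size in (1, 10):
    steps_per_epoch = dataset_size // batch_size
    epochs_needed = target_steps // steps_per_epoch
    print(f"batch_size={batch_size:2d}: {steps_per_epoch:2d} steps/epoch, "
          f"{epochs_needed:2d} epoch(s) for {target_steps} optimizer steps")

# batch_size= 1: 10 steps/epoch,  1 epoch(s) for 10 optimizer steps
# batch_size=10:  1 steps/epoch, 10 epoch(s) for 10 optimizer steps
```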