How to decrease the weight of a mini-batch?

The problem comes from this paper: Decoupling “when to update” from “how to update”.
Regardless of the paper’s details, its approach results in mini-batches of different sizes.
For example, if we set the mini-batch size to 128, some batches could have only 10 samples, or even zero.
If a batch has zero samples, we can simply skip the update at that iteration.
However, if a batch is much smaller, say 10% of 128, then we have to scale down the gradient update on that iteration.
The idea is that a less important batch shouldn’t have the same impact as a full-size batch.
If we reduce the learning rate tenfold on this particular batch, and use the normal learning rate on full-size batches, we can achieve this effect.
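One way to get the same effect without touching the optimizer is to down-weight the loss of a small batch. Below is a minimal sketch, assuming a classification model trained with nn.CrossEntropyLoss; NOMINAL_BATCH_SIZE and train_step are hypothetical names, not part of PyTorch. Multiplying the mean loss by n / 128 is the same as summing the per-sample losses and dividing by 128, and for plain SGD (no momentum, no weight decay) it gives exactly the update you would get by scaling the learning rate by n / 128.

```python
import torch
import torch.nn as nn

NOMINAL_BATCH_SIZE = 128  # hypothetical: the configured (full) mini-batch size
criterion = nn.CrossEntropyLoss()  # default reduction='mean'


def train_step(model, optimizer, inputs, targets):
    """Run one update whose size is scaled by how full the batch is.

    Returns the scale factor that was applied (0.0 if the step was skipped).
    """
    n = inputs.shape[0]
    if n == 0:
        # Nothing was selected for this iteration: skip the update entirely.
        return 0.0
    scale = n / NOMINAL_BATCH_SIZE
    optimizer.zero_grad()
    # Mean loss down-weighted by the fill ratio; equivalent to sum-loss / 128.
    loss = criterion(model(inputs), targets) * scale
    loss.backward()
    optimizer.step()
    return scale
```

A full batch gets scale 1.0 and behaves exactly like normal training; a batch of 13 samples contributes about a tenth of a normal step.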
But how can we achieve this in PyTorch?
I hope my description is clear. Thank you so much for any suggestions.

You can scale the output of the loss function yourself, or use reduction='sum' instead of reduction='mean'. This way your overall loss is proportional to the number of samples in each batch.

  • reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the weighted mean of the output is taken, 'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction. Default: 'mean'
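A quick check of what the three reduction modes return, assuming no class weights are passed (with weight=..., 'mean' divides by the sum of the target classes’ weights rather than by the sample count):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(10, 5)            # 10 samples, 5 classes
targets = torch.randint(0, 5, (10,))

loss_mean = nn.CrossEntropyLoss(reduction="mean")(logits, targets)
loss_sum = nn.CrossEntropyLoss(reduction="sum")(logits, targets)
loss_none = nn.CrossEntropyLoss(reduction="none")(logits, targets)  # one loss per sample

# Without class weights, 'mean' is 'sum' divided by the number of samples,
# so a 'sum'-reduced loss (and its gradient) grows linearly with batch size.
print(loss_none.shape)                                       # torch.Size([10])
print(torch.allclose(loss_sum, loss_mean * len(targets)))    # True
```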

That’s a very interesting solution to the problem. I had only noticed the weight parameter of the loss function, and I don’t think it fits this situation. If this works, it will scale down nicely. I’ll give it a try and let you know the result. Thank you.

When nn.CrossEntropyLoss(reduction='sum') is used, the model simply does not converge with torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9).
Once I switch back to mean, the same model converges within 5 epochs on MNIST.
It seems the sum option is only used during evaluation to compute the overall loss, and nobody uses it for training.

@sunfishcc It’s because the learning rate is too large.
When you use mean, you are effectively dividing the learning rate by the batch size, so with sum the same lr produces steps that are batch-size times larger.
Use a smaller learning rate.
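The learning-rate relationship can be checked directly. This sketch, on random data with a hypothetical run helper, trains two copies of the same linear model for a few steps with SGD and momentum=0.9: one with reduction='mean' and lr=0.01, the other with reduction='sum' and lr=0.01 divided by the batch size. Because the SGD update, including the momentum buffer, is linear in the gradient, both should end up with the same weights up to floating-point error.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(32, 4)
y = torch.randint(0, 3, (32,))
base = nn.Linear(4, 3)  # shared initial weights for both runs


def run(reduction, lr, steps=3):
    """Train a fresh copy of `base` for a few full-batch SGD steps."""
    model = copy.deepcopy(base)
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss(reduction=reduction)
    for _ in range(steps):
        opt.zero_grad()
        loss_fn(model(x), y).backward()
        opt.step()
    return model


batch = x.shape[0]
m_mean = run("mean", lr=0.01)
m_sum = run("sum", lr=0.01 / batch)  # 'sum' needs lr divided by batch size to match
print(all(torch.allclose(a, b, atol=1e-6)
          for a, b in zip(m_mean.parameters(), m_sum.parameters())))  # True
```

For a full batch of 128, 'sum' with lr=0.01 therefore behaves like 'mean' with lr=1.28, and 'sum' with lr=0.0001 like 'mean' with lr=0.0128.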

@mMagmer I see. It works once I changed lr to 0.0001.

I also tried manually adjusting lr when the batch size is small: optimizer.param_groups[0]['lr'] *= 0.1, and then resetting it after calling step(). This seems like a hack, and I’m not sure whether it has any side effects.
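If you keep the learning-rate approach, a context manager makes the temporary change and the reset harder to get wrong. This is a sketch; scaled_lr is a hypothetical helper, not a PyTorch API, and it scales every param group rather than only the first. One side effect worth knowing: with momentum, PyTorch’s SGD multiplies the whole momentum buffer by lr, so shrinking lr for one step also shrinks the contribution of earlier full-size batches carried in that buffer, whereas scaling the loss only shrinks the current batch’s gradient.

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def scaled_lr(optimizer, factor):
    """Temporarily multiply every param group's lr by `factor`, then restore it."""
    saved = [group["lr"] for group in optimizer.param_groups]
    for group in optimizer.param_groups:
        group["lr"] = group["lr"] * factor
    try:
        yield optimizer
    finally:
        # Restore the original values even if step() raises.
        for group, lr in zip(optimizer.param_groups, saved):
            group["lr"] = lr


# Usage: a small batch gets a tenth of the normal step.
model = nn.Linear(4, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
model(torch.randn(10, 4)).sum().backward()
with scaled_lr(optimizer, 0.1):
    optimizer.step()
print(optimizer.param_groups[0]["lr"])  # 0.01
```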

Thank you so much for your suggestion.