I was going through how weight_decay is implemented in optimizers, and it seems to be applied per batch, with a constant that ideally should apply to the whole loss. For instance, when training a model with CE loss, one would expect the objective to be:

mean(CE, X, Y) + weight_decay * norm(W, 2)
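Written out per batch, and assuming the penalty is the squared L2 norm scaled by 1/2 (the form whose gradient is exactly weight_decay * W, which is what PyTorch's `torch.optim.SGD` adds to each parameter's gradient when `weight_decay` is set), the objective for a batch B and one SGD step with learning rate η and λ = weight_decay look roughly like this:

$$
L_B(W) = \frac{1}{|B|}\sum_{(x,y)\in B} \mathrm{CE}\big(f_W(x), y\big) + \frac{\lambda}{2}\lVert W\rVert_2^2,
\qquad
W \leftarrow W - \eta\left(\nabla_W \frac{1}{|B|}\sum_{(x,y)\in B} \mathrm{CE}\big(f_W(x), y\big) + \lambda W\right)
$$

The λW term enters every step exactly once, whatever |B| is, which is the per-batch application being asked about.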

Since models are trained in batches, the first term is a per-batch loss. However, the second term should then also be scaled by the number of batches. For instance, if training with batch size 64 applies the weight_decay penalty X times per epoch, training with batch size 256 applies it only X/4 times, which seems a bit odd.

Is this an implementation bug, or is the developer expected to scale weight_decay before passing it to the optimizer? That would make sense, since the optimizer does not know how many batches there are, but it would be nice if the documentation mentioned it.
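To put numbers on the counting argument, here is a minimal sketch assuming plain SGD (no momentum), where each step does w <- w - lr * (grad + weight_decay * w). It isolates the decay part, w <- (1 - lr * weight_decay) * w, and computes how much that alone shrinks a weight over one epoch. The dataset size, learning rate and the `scaled_weight_decay` helper are all hypothetical; the helper shows one way a developer could scale weight_decay with batch size before passing it to the optimizer.

```python
import math


def epoch_decay_factor(num_examples, batch_size, lr, weight_decay):
    """Factor by which the decay term alone shrinks a weight over one epoch.

    Assumes plain SGD: the decay-only part of each step is
    w <- (1 - lr * weight_decay) * w, applied once per batch.
    """
    steps = math.ceil(num_examples / batch_size)
    return (1.0 - lr * weight_decay) ** steps


def scaled_weight_decay(base_wd, batch_size, ref_batch_size):
    """Hypothetical helper: scale weight_decay linearly with batch size.

    With a fixed lr, the per-epoch shrink is about exp(-lr * wd * N / batch_size),
    so multiplying wd by batch_size / ref_batch_size keeps it roughly constant.
    """
    return base_wd * batch_size / ref_batch_size


N, LR, BASE_WD, REF_BS = 51_200, 0.1, 1e-4, 64  # hypothetical settings

for bs in (64, 256):
    fixed = epoch_decay_factor(N, bs, LR, BASE_WD)
    scaled = epoch_decay_factor(N, bs, LR, scaled_weight_decay(BASE_WD, bs, REF_BS))
    print(f"batch_size={bs:3d}  steps/epoch={math.ceil(N / bs):3d}  "
          f"fixed wd: {fixed:.6f}  scaled wd: {scaled:.6f}")
```

With a fixed weight_decay, batch size 64 takes 800 steps per epoch against 200 for batch size 256, so the decay factor compounds four times as often. The scaled value could be passed as `weight_decay=` to e.g. `torch.optim.SGD`; note this bookkeeping only holds at a fixed learning rate, and changes if the learning rate is also adjusted with batch size.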

Hey! The CE loss here is the mean over the examples in a batch, not the sum, so the scale of the first term stays the same as the batch size changes. The second term therefore does not need to be scaled by the number of batches.
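A small sketch of the mean-versus-sum point. PyTorch's `nn.CrossEntropyLoss` defaults to `reduction='mean'`; the `reduce_losses` function below is a hypothetical stand-in that mimics that argument on plain Python lists, and the per-example loss values are made up for illustration.

```python
def reduce_losses(per_example, reduction):
    """Combine per-example losses the way a loss's `reduction` argument would."""
    total = sum(per_example)
    if reduction == "sum":
        return total
    if reduction == "mean":
        return total / len(per_example)
    raise ValueError(f"unknown reduction: {reduction!r}")


# Hypothetical per-example losses, repeated to fill each batch.
pattern = [0.5, 1.0, 1.5, 2.0]
for bs in (16, 128):
    batch = pattern * (bs // len(pattern))
    print(bs, reduce_losses(batch, "mean"), reduce_losses(batch, "sum"))
# prints:
# 16 1.25 20.0
# 128 1.25 160.0
```

The mean stays at 1.25 for both batch sizes while the sum grows eightfold, so with a mean-reduced loss the data term and the λW term keep the same relative size at every step.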

Correct, but in effect the model is penalized for large weights much more often when the batch size is small. Over one epoch, a model trained with BS=16 might be penalized 32 times, whereas with BS=128 it would be only 4 times. I get that the loss is a mean and that we are doing mini-batch gradient descent after all, but shouldn't there be some sort of scaling for the regularization as well?

Oops, sorry for misunderstanding your point earlier. Does what you are raising tie into the difference between weight decay and L2 regularization, i.e. that PyTorch currently implements L2 regularization but still labels it “weight decay”?
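For plain SGD the two coincide: adding weight_decay * w to the gradient (L2 regularization) and shrinking the weight directly (decoupled weight decay) produce the same update. With an adaptive optimizer like Adam they differ, which is why PyTorch has `torch.optim.AdamW` alongside the `weight_decay` argument of `torch.optim.Adam` (which adds the term to the gradient). The sketch below works on a single scalar weight; the function names and numbers are illustrative, and the Adam step is only the first one (t = 1, zero-initialized moments).

```python
import math


def sgd_l2_step(w, grad, lr, wd):
    # L2 regularization: fold wd * w into the gradient, then step.
    return w - lr * (grad + wd * w)


def sgd_decoupled_step(w, grad, lr, wd):
    # Decoupled weight decay: shrink the weight, then take the plain gradient step.
    return w * (1 - lr * wd) - lr * grad


def adam_first_step(w, grad, lr, wd, decoupled, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step from zero-initialized moments (t = 1) on a scalar weight."""
    if decoupled:
        w = w * (1 - lr * wd)  # AdamW-style: decay applied directly to the weight
        g = grad
    else:
        g = grad + wd * w      # Adam-with-L2: decay folded into the gradient
    m = (1 - beta1) * g
    v = (1 - beta2) * g * g
    m_hat = m / (1 - beta1)    # bias correction at t = 1
    v_hat = v / (1 - beta2)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps)


w, grad, lr, wd = 1.0, 0.5, 0.1, 0.1
print("SGD   L2:", sgd_l2_step(w, grad, lr, wd), " decoupled:", sgd_decoupled_step(w, grad, lr, wd))
print("Adam  L2:", adam_first_step(w, grad, lr, wd, decoupled=False),
      " decoupled:", adam_first_step(w, grad, lr, wd, decoupled=True))
```

Both SGD variants give 0.94. For Adam, the L2 version lands at about 0.9, essentially the same as with no decay at all, because dividing by sqrt(v_hat) rescales the combined gradient; the decoupled version lands at about 0.89, still shrinking the weight by lr * wd.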

No. I think what you said earlier holds: each update only really sees one batch of data, so it makes sense for the weight decay constant not to vary with batch size.