I am trying to train a binary classifier with PyTorch v1.3.1 on time-series data, using an LSTM with CrossEntropyLoss as the loss function. Since the dataset is very unbalanced (~98% class 0, ~2% class 1), I want to apply the weights [1., 55.] in the loss function. Training is done in batches of 24. Using a random example, where one of the 24 samples is from class 1 and the remaining 23 from class 0, I get the following losses:
reduction = ‘mean’, no weights: 0.7963
reduction = ‘sum’, no weights: 19.11
reduction = ‘none’, no weights, manually summing losses: 19.11
I would expect the average loss times the batch size to equal the summed loss, and indeed 24 × 0.7963 = 19.11 matches the summed loss.
reduction = ‘mean’, weights in loss fct: 0.658
reduction = ‘sum’, weights in loss fct: 51.48
reduction = ‘none’, weights in loss fct, manually summing losses: 51.48
Now 0.658 × 24 (= 15.79) is totally different from 51.48. Is this the intended behavior?
How does PyTorch compute the average? And is there any difference in training behavior when using the sum instead of the mean reduction?
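Here is a short script, a sketch mirroring the setup in the question (24 samples, one from class 1, weights [1., 55.]), that compares the three reductions. The logits are random, so the numbers will differ from yours, but the relationships between the reductions hold:

```python
import torch

torch.manual_seed(0)

# 24 samples, 2 classes; one sample from class 1, 23 from class 0
logits = torch.randn(24, 2)
targets = torch.zeros(24, dtype=torch.long)
targets[0] = 1
weight = torch.tensor([1.0, 55.0])

loss_none = torch.nn.functional.cross_entropy(
    logits, targets, weight=weight, reduction='none')
loss_sum = torch.nn.functional.cross_entropy(
    logits, targets, weight=weight, reduction='sum')
loss_mean = torch.nn.functional.cross_entropy(
    logits, targets, weight=weight, reduction='mean')

# 'sum' equals the sum of the per-sample weighted losses
print(torch.isclose(loss_none.sum(), loss_sum))

# 'mean' divides by the sum of the per-sample weights
# (here 23 * 1.0 + 1 * 55.0 = 78.0), not by the batch size 24
print(torch.isclose(loss_none.sum() / weight[targets].sum(), loss_mean))
```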
As you can see, the results of the various "reductions" all tie out once the average is understood as a weighted average: PyTorch divides the sum of the per-sample weighted losses by the sum of the per-sample weights (here 23 × 1.0 + 1 × 55.0 = 78), not by the batch size, so 51.48 / 78 ≈ 0.66, matching your "mean" result.
Not a lot. In the unweighted case, the loss using "sum" will be nBatch
times bigger (24, in your example) than with "mean." In effect, this
means that you would be using a learning rate 24 times bigger.
Learning rate does matter, so you would probably want to use a
proportionately smaller learning rate if you use "sum" rather than
"mean."
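Concretely, here is a sketch of that equivalence (it assumes plain SGD; with momentum, weight decay, or Adam the correspondence is no longer exact). One SGD step with reduction='mean' at learning rate lr produces the same update as one step with reduction='sum' at lr / nBatch:

```python
import torch

torch.manual_seed(0)
x = torch.randn(24, 10)
y = torch.randint(0, 2, (24,))

def one_step(reduction, lr):
    torch.manual_seed(1)                      # identical initial weights
    model = torch.nn.Linear(10, 2)
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    loss = torch.nn.functional.cross_entropy(model(x), y, reduction=reduction)
    loss.backward()
    opt.step()
    return model.weight

w_mean = one_step('mean', 0.1)
w_sum = one_step('sum', 0.1 / 24)             # lr scaled down by batch size

# the two updates coincide (up to floating-point noise)
print(torch.allclose(w_mean, w_sum, atol=1e-6))
```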
In the weighted case, the ratio between "sum" and "mean" will depend
on the number of "1" target values in your batch (as the numbers
above illustrate). So your effective learning rate would vary from batch
to batch based on the number of "1"s. But I think this batch-to-batch
variation would largely average away over multiple batches.
One last note: Because you’re training a binary classifier, you would
probably be modestly better off using BCEWithLogitsLoss and its
pos_weight argument. (To do this you would change the last layer
of your network to emit only a single prediction per sample in the
batch, rather than two.)
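A minimal sketch of that setup (the pos_weight value of 55 is taken from your weighting; you might instead set it to the actual ratio of negative to positive samples):

```python
import torch

# single-logit binary head: last layer emits one prediction per sample
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([55.0]))

logits = torch.randn(24, 1)      # one logit per sample, not two
targets = torch.zeros(24, 1)     # BCE expects float targets
targets[0] = 1.0                 # one positive sample, as in your example

loss = criterion(logits, targets)  # scalar (default reduction='mean')
```

Note that pos_weight scales only the positive term of the loss, and the default "mean" here divides by the number of elements rather than by the summed weights, so the "sum"/"mean" ratio stays fixed at the batch size instead of varying batch to batch.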