Say I have a training setup with two loss functions defined, and I want the backward pass to run over both. If I simply specify:
err1 = criterion1(...)
err2 = criterion2(...)
err = err1 + err2
Is it possible that one gradient update will overpower the other so that, in effect, only one loss really gets optimized since the other gradient update is so much smaller?
If so, can I normalize the gradients of the two before combining them and make sure they have similar impact on my update?
The short answer is yes.
One way to balance the contribution of each loss is to weight them, which lets you control which loss you want to prioritise. Something loosely similar happens when you set weight_decay in your optimiser: it adds an L2 penalty on the model parameters, scaled by the given value, on top of your loss.
Just a weighted sum, like:
k = 0.2
err = k * err1 + (1 - k) * err2?
How do I determine a good weight_decay value? Is this based on the norm of the gradient updates of one compared to that of the other?
Actually, I’m realizing that weight_decay might not work because it will decay the weights for both losses, no?
Yeah, so weight_decay adds a regularization penalty on the model weights, keeping them small on average to help prevent overfitting. In answer to your second question: yes, weight decay acts on the model parameters regardless of which loss produced the gradient, so it affects both losses unless it is implemented in a custom manner to apply only to one. The important point is that weight_decay shrinks the model's weights during training; it does not set the relative weighting of two loss terms.
(see How does SGD weight_decay work?)
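To make that concrete, here is a minimal sketch, assuming plain torch.optim.SGD without momentum. The parameter values, learning rate, and decay value are made up for illustration. The loss is built so that its gradient is exactly zero, which means any change in w after the step comes from weight_decay alone:

```python
import torch

# Toy parameter; the values are arbitrary.
w = torch.nn.Parameter(torch.tensor([1.0, -2.0]))
opt = torch.optim.SGD([w], lr=0.1, weight_decay=0.5)

# A loss whose gradient w.r.t. w is exactly zero.
loss = (w * 0.0).sum()

opt.zero_grad()
loss.backward()
opt.step()

# SGD adds weight_decay * w to the gradient before the update:
# w_new = w - lr * (grad + weight_decay * w) = w * (1 - 0.1 * 0.5)
print(w.detach())  # approximately [0.95, -1.90]
```

The parameters shrink even though the loss contributed nothing, so weight_decay cannot be used to favour one loss term over another.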
What you are likely looking for is instead a weighting function that combines two scalar values, probably the same way at all points during training.
err = k * err1 + (1-k)* err2
One way to do this is a simple weighted sum, as shown above. In practice, the weighting value k has to be chosen by trying a number of candidate values. There are other ways to combine these values as well, such as multiplication:
err = err1*err2
or the harmonic mean (this is the technique used to combine the metrics of precision and recall into a single metric called the F1 score):
err = (2* err1 *err2) / (err1 + err2)
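These combinations do not treat the two losses the same way during the backward pass. The sketch below is only an illustration: it uses two made-up scalar loss values (4.0 and 1.0) and a hypothetical helper grads_for to print how much gradient each combination sends back to each loss.

```python
import torch

def grads_for(combine):
    """Return d(err)/d(err1) and d(err)/d(err2) for a combining function."""
    # Made-up loss values: err1 is four times larger than err2.
    err1 = torch.tensor(4.0, requires_grad=True)
    err2 = torch.tensor(1.0, requires_grad=True)
    combine(err1, err2).backward()
    return err1.grad.item(), err2.grad.item()

k = 0.2
weighted = grads_for(lambda a, b: k * a + (1 - k) * b)
product = grads_for(lambda a, b: a * b)
harmonic = grads_for(lambda a, b: (2 * a * b) / (a + b))

print("weighted sum:", weighted)   # about (0.2, 0.8): fixed by k
print("product:", product)         # (1.0, 4.0): each scaled by the other loss
print("harmonic mean:", harmonic)  # about (0.08, 1.28): favours the smaller loss
```

With the weighted sum the scaling is constant and set by k. With the product or the harmonic mean the scaling depends on the current loss values, so the balance shifts as training progresses.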
The short of it is, there’s not an easy way to see how different weightings of the different loss components will translate into performance gains, since the relationship is a function of your model. Try a range of values or functions to see what yields satisfactory results.
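On the original question of whether one loss overpowers the other: you can measure this directly by comparing the gradient norm that each loss produces on the model parameters. The following is a rough sketch and not a method recommended anywhere in this thread. The model, the data, the two criteria, and the helper grad_norm are all placeholders, and dividing each loss by its gradient norm is just one possible way to equalise their impact.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Placeholder model and data.
model = torch.nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
pred = model(x)

err1 = F.mse_loss(pred, y)
err2 = 100.0 * F.l1_loss(pred, y)  # deliberately on a much larger scale

params = [p for p in model.parameters() if p.requires_grad]

def grad_norm(loss):
    """L2 norm of the gradient of `loss` over all parameters."""
    grads = torch.autograd.grad(loss, params, retain_graph=True)
    return torch.sqrt(sum((g ** 2).sum() for g in grads))

n1, n2 = grad_norm(err1), grad_norm(err2)
print("gradient norms:", n1.item(), n2.item())

# Rescale each term so that it contributes a unit-norm gradient.
# n1 and n2 carry no graph, so they act as constants here.
eps = 1e-12
s1, s2 = err1 / (n1 + eps), err2 / (n2 + eps)
m1, m2 = grad_norm(s1), grad_norm(s2)  # both close to 1.0

err = s1 + s2
err.backward()
```

If the two printed norms differ by orders of magnitude, the larger one dominates a plain err1 + err2. Keep in mind that renormalising at every step changes the effective learning rate, so a fixed k chosen from a few such measurements may be the simpler option.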