Yea I understand if min(loss1, loss2) is loss1, loss.backward() will only backprop through model1 and vice versa. Please suggest a better or alternative way to solve this. I need to backprop min of loss1 and loss2 through both models.
In order for the gradients to flow through both branches of your
“min” function, you need a smooth version of min that transitions
from one branch to the other sufficiently gradually.
LogSumExp is a smooth version of the max function, and is
implemented in pytorch as torch.logsumexp(). You can turn it into
a smooth minimum by adding minus signs:
smooth_min = -torch.logsumexp(-arg, dim)
You can make it transition more or less abruptly by introducing a
sharpness parameter q:
smooth_min = -torch.logsumexp(-q * arg, dim) / q
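Here is a minimal sketch of this soft minimum in use, with two toy one-parameter “models” standing in for your model1 and model2 (the names w1, w2 and the quadratic losses are illustrative, not from your code):

```python
import torch

def smooth_min(losses, q=1.0):
    # Soft minimum via negated LogSumExp; larger q -> closer to a hard min.
    return -torch.logsumexp(-q * losses, dim=0) / q

# Two toy "models", each a single learnable parameter.
w1 = torch.tensor(0.0, requires_grad=True)
w2 = torch.tensor(1.0, requires_grad=True)

loss1 = (w1 - 2.0) ** 2   # 4.0 -- the larger loss
loss2 = (w2 - 2.0) ** 2   # 1.0 -- the smaller loss

loss = smooth_min(torch.stack([loss1, loss2]), q=1.0)
loss.backward()

# Unlike torch.min(loss1, loss2).backward(), both parameters receive
# a gradient, with the smaller-loss branch dominating.
print(w1.grad, w2.grad)
```

With a hard min, w1.grad would be exactly zero here; with the soft min, both gradients are nonzero, and the smaller-loss branch gets the larger share.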
Note that if loss1 is much larger than loss2 (or if you transition
abruptly), then the gradients will flow only weakly back through
model1 (which may well be what you want).
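To see how q controls this: the gradient each branch receives is weighted by softmax(-q * losses), so raising q shrinks the share of the larger loss. A small illustration (the specific loss values are made up):

```python
import torch

# The softmin weight on each branch is softmax(-q * losses); the
# weight on the larger loss shrinks as q grows.
losses = torch.tensor([2.0, 1.0])  # loss1 larger than loss2
for q in (0.5, 1.0, 5.0):
    weights = torch.softmax(-q * losses, dim=0)
    print(q, weights)
```

At large q the weight on the larger loss is nearly zero, recovering the hard-min behavior you started from.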
[Edit: However, upon further consideration, this scheme might
not do what you want. Let’s say that, by happenstance, model2
starts doing better than model1. Then gradients will flow more
weakly through model1, so it will learn more slowly. So model2
will get better still while model1 stagnates.
If model2 gets sufficiently better than model1, then model1
will effectively drop out of the picture, regardless of whether,
given the chance and adequate training, it might have ended up
doing better than model2.]