Hi m!
It’s not entirely clear to me what you want to do here. I assume that you want to optimize the parameters of model1 so as to minimize loss1, thus, in part, minimizing shared_loss, and to optimize the parameters of model2 so as to minimize loss2, which includes -shared_loss, and therefore amounts to maximizing shared_loss.
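Concretely, the setup I am assuming looks something like the sketch below. This is only my guess about your code: target1, target2, and the exact arguments of loss_fn1 and loss_fn2 are placeholders, and the two separate optimizers are my assumption.

opt1 = torch.optim.SGD (model1.parameters(), lr = learning_rate)
opt2 = torch.optim.SGD (model2.parameters(), lr = learning_rate)
...
m1 = model1 (x)
m2 = model2 (x)
# model1's objective includes +shared_loss, so model1 is pushed to make it small
loss1 = loss_fn1 (m1, target1) + shared_loss (m1, m2)
# model2's objective includes -shared_loss, so model2 is pushed to make it large
loss2 = loss_fn2 (m2, target2) - shared_loss (m1, m2)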
That is, leaving aside loss_fn1 and loss_fn2, as they don’t change the basic issue, my understanding is that you want model1 to minimize shared_loss while model2 maximizes shared_loss.
If so, the cleanest approach will be to use a single backward pass and a single optimizer, but to flip the gradient of shared_loss when you backpropagate through model2. So something like:
# one optimizer that updates the parameters of both models
opt = torch.optim.SGD (list (model1.parameters()) + list (model2.parameters()), lr = learning_rate)
...
m1 = model1 (x)
m2 = model2 (x)
# reverse the gradient that flows back into model2 through shared_loss
loss = shared_loss (m1, GradientReversalFunction.apply (m2))
opt.zero_grad()
loss.backward()
opt.step()   # steps model1 downhill and model2 uphill on shared_loss
Here GradientReversalFunction is a custom autograd function that passes the value of m2 through unchanged during the forward pass, but changes the sign of the gradient of m2 during the backward pass.
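For illustration, a minimal version of such a function might look like the following sketch (it uses a custom torch.autograd.Function; the quick check at the end just shows the sign flip):

import torch

class GradientReversalFunction (torch.autograd.Function):
    @staticmethod
    def forward (ctx, x):
        # forward pass:  return the value of x unchanged
        return x.view_as (x)
    @staticmethod
    def backward (ctx, grad_output):
        # backward pass:  flip the sign of the incoming gradient
        return -grad_output

# quick check:  the gradient comes back with its sign flipped
t = torch.ones (3, requires_grad = True)
GradientReversalFunction.apply (t).sum().backward()
print (t.grad)   # tensor([-1., -1., -1.])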
This thread discusses a couple of ways to effect such a gradient reversal:
Best.
K. Frank