I currently have a problem where I have two models, model_1 and model_2, and I would like to optimise one based on the performance of the other.

Both models are of different types.

I would like to optimise model_2 based on the performance of model_1. The loss function for model_1 is well defined (CrossEntropyLoss), but for model_2 it is not.

model_2 produces an output, which implicitly affects the performance of model_1. The output of model_2 is not directly given to model_1. Overall, its objective is to minimise the loss of model_1.

The problem I have is that I cannot backpropagate the loss of model_1 (a scalar value) to model_2. Is there a way to perform this operation?

Yes, you should be able to do something like this.

However, you say that “the output of model_2 is not directly given to model_1,” but that it implicitly affects the performance of model_1.

Unless you can tell us in concrete terms (perhaps with a fully self-contained,
runnable example script) how model_2 is connected to model_1, we can’t
really offer advice as to how (or whether) you might backpropagate through
that connection.

I really appreciate the response. Apologies for leaving out some details; I thought it would simplify the question overall. I’m working in a Federated Learning setting, so in fact there are many models in play: model_1, model_2 and k client models {model_client_1, model_client_2, …, model_client_k}.

The output of model_2 is used in a regularisation term for each client model’s loss function.

Each client model trains on its own local private dataset.

model_1 is the average (all weights averaged via parameter matching) of all the client models.

model_2 then needs to be optimised so that the loss of model_1 is minimised. However, as the output of model_2 is only used as part of a regularisation term for the client models, backpropagation isn’t possible: there is no direct computation graph between model_1 and model_2.

Hopefully this makes sense. I can provide more details if anything is still unclear.

As I understand it, you train your client models – that is, you update the
parameters of your client models by taking optimization steps. And the
loss you use to train your client models consists of some normal sort of
loss term plus a regulator that depends on the output of model_2.

model_1 depends on the trained client models (because it’s their average).
It therefore depends on model_2, because the optimization steps used to
train the client models depend on model_2 (model_2 was used for the
regulator term in the client models’ training losses).

The issue is that PyTorch doesn’t, by default, let you backpropagate
model_1’s performance back through the optimization step that depended
on model_2: optimizer.step() normally runs in a torch.no_grad() context,
so autograd does not track it.

If I understand your use case properly, you would like to compute the
gradient of model_1’s “performance” with respect to model_2 (and
then use that gradient to optimize model_2 so that its output has the
desired influence on the training of the client models via the regulator
term).

I believe that you can do this under certain conditions (that hopefully
will be consistent with your use case).

First, we will compute the gradient with respect to model_2 for only
a single optimization step of the client models at a time (rather than
for a chain of optimization steps all at once).

Second, we will use plain-vanilla SGD to take the client-model optimization
steps. This is because the SGD optimization step:

    param -= lr * param.grad

is easy to differentiate: d_param / d_param.grad = -lr. This manual
differentiation of the optimization step is the “glue” that will let us
backpropagate model_1’s performance through the optimization steps of
the client models to model_2’s parameters.
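To make the chain rule concrete, here is a minimal scalar sketch in plain Python (no autograd). It is only an illustration under assumed toy definitions: a single client "model" with one parameter p, a regulated loss (p - a)**2 + r * p**2 where r stands in for model_2's output, and an outer loss (p_new - a)**2. All of these names and formulas are hypothetical stand-ins, not part of your setup.

```python
# Toy scalar illustration of differentiating through one SGD step.
# Assumed (hypothetical) setup:
#   regulated client loss:  (p - a)**2 + r * p**2   (r plays model_2's output)
#   outer loss after step:  (p_new - a)**2


def reg_grad(p, r, a):
    # gradient of the regulated client loss with respect to p
    return 2.0 * (p - a) + 2.0 * r * p


def sgd_step(p, r, a, lr):
    # plain SGD update: p_new = p - lr * grad
    return p - lr * reg_grad(p, r, a)


def outer_loss_after_step(r, p, a, lr):
    # loss of the updated parameter, viewed as a function of r
    p_new = sgd_step(p, r, a, lr)
    return (p_new - a) ** 2


def chain_rule_grad(r, p, a, lr):
    # d(outer) / d(r) = d(outer)/d(p_new) * d(p_new)/d(grad) * d(grad)/d(r)
    p_new = sgd_step(p, r, a, lr)
    d_outer_d_pnew = 2.0 * (p_new - a)
    d_pnew_d_grad = -lr          # the manual "glue" from the SGD step
    d_grad_d_r = 2.0 * p         # derivative of reg_grad with respect to r
    return d_outer_d_pnew * d_pnew_d_grad * d_grad_d_r


def finite_difference_grad(r, p, a, lr, h=1.0e-5):
    # two-sided finite-difference estimate, used only as a sanity check
    plus = outer_loss_after_step(r + h, p, a, lr)
    minus = outer_loss_after_step(r - h, p, a, lr)
    return (plus - minus) / (2.0 * h)


p, a, r, lr = 1.5, 0.5, 0.3, 0.1
analytic = chain_rule_grad(r, p, a, lr)
numeric = finite_difference_grad(r, p, a, lr)
print('chain rule:       ', analytic)
print('finite difference:', numeric)
```

The factor -lr in chain_rule_grad() is the only piece that autograd would not supply on its own; the other two factors are ordinary derivatives.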

I’ve written a toy version of what I understand to be your use case. It
uses this differentiation of the SGD optimization step, together with
autograd (including computing the gradient of a gradient), to compute
the gradient of model_1’s loss (after the client models have each undergone
one regulated optimization step) with respect to model_2’s parameters.

Start by looking at the function fDGrad() in the example script. This
finite-difference estimate of model_1’s gradient with respect to model_2
can be understood as a concrete operational definition of what we mean
by that gradient. The function autogradGrad() then shows how to
compute that gradient using autograd (and our “manual” differentiation
of the SGD optimization step).

Here is the full script:

import torch

print (torch.__version__)

torch.random.manual_seed (2024)

nDim = 4          # length of parameters, inputs, etc.
nClient = 3       # number of client models
doDouble = True   # use double precision to improve finite-difference stability

print ('nDim: ', nDim)
print ('nClient: ', nClient)
print ('doDouble: ', doDouble)

# some helper functions

def applyParams (p, x):   # apply (warped) model parameters to input x -- returns scalar "loss"
    return torch.nn.functional.mse_loss (x, torch.nn.functional.tanhshrink (p))

def getParams (m):   # detached copy of a model's parameter vector
    return m.p.detach().clone()

def setParams (m, params, gradToNone = True):   # overwrite a model's parameters in place
    with torch.no_grad():
        m.p.copy_ (params)
    if gradToNone:
        m.p.grad = None

def applyAverage (mList, x):   # apply average of models to x
    pList = [m.p for m in mList]
    pAvg = torch.stack (pList).mean (dim = 0)
    return applyParams (pAvg, x)

class Mod (torch.nn.Module):   # simple "model"
    def __init__ (self):
        super().__init__()
        self.p = torch.nn.Parameter (torch.randn (nDim))
    def forward (self, x):
        return applyParams (self.p, x)

# instantiate models
# model_1 -- the average of the client models -- is never explicitly instantiated

modCList = [Mod() for _ in range (nClient)]
mod2 = Mod()

x = torch.randn (nDim)   # use the same fixed input for all models

# lossReg will be the loss of a single client model together with its mod2 regulator
# lossOpt will be the loss of the average of the client models after they have each taken an optimization step

if doDouble:
    mod2.double()
    for m in modCList:
        m.double()
    x = x.double()

# save model parameters so that models can be reset to original state

modCParams = [getParams (m) for m in modCList]
mod2Params = getParams (mod2)

# instantiate client-model optimizers

lr = 0.1
modCOpts = [torch.optim.SGD (m.parameters(), lr = lr) for m in modCList]

# note that the following functions depend on global variables such as modCList, mod2, etc.

def lossOptParams (params):   # helper function that computes lossOpt for given values of mod2 parameters
    setParams (mod2, params)
    for m, p in zip (modCList, modCParams):
        setParams (m, p)
    m2Reg = mod2 (x).detach()
    for m in modCList:
        loss = m (x)
        lossReg = loss + m2Reg * (m.p**2).sum()
        m.zero_grad()
        lossReg.backward()
    for o in modCOpts:
        o.step()
    lossOpt = applyAverage (modCList, x)
    return lossOpt.detach()   # only the value is needed for the finite differences

# helper function that computes the two-sided finite-difference estimate of the
# directional derivative of lossOpt wrt mod2's parameters in the direction dDir

def fDDir (dDir):
    eps = torch.finfo (dDir.dtype).eps   # machine epsilon, so the step size adapts to the precision
    h = eps**(1.0 / 3.0)                 # step size of order eps**(1/3), suited to a two-sided difference
    lossPlus = lossOptParams (mod2Params + h * dDir)
    lossMinus = lossOptParams (mod2Params - h * dDir)
    return (lossPlus - lossMinus) / (2.0 * h)

def fDGrad():   # function that computes finite-difference estimate gradient of lossOpt wrt mod2's parameters
    dDirs = torch.eye (len (mod2Params), dtype = mod2Params.dtype)
    result = torch.empty_like (mod2Params)
    for i, dDir in enumerate (dDirs):
        result[i] = fDDir (dDir)
    return result

def autogradGrad():   # function that uses autograd to compute the gradient of lossOpt wrt mod2's parameters
    # reset models
    setParams (mod2, mod2Params)
    for m, p in zip (modCList, modCParams):
        setParams (m, p)
    # take a single optimization step for each client model with m2Reg as regulator
    m2Reg = mod2 (x)
    lossReg = 0.0
    for m in modCList:
        lossReg += m (x) + m2Reg * (m.p**2).sum()
        m.zero_grad()
    # gradients are linear so perform a single backward for the sum of losses
    lossReg.backward (retain_graph = True)   # retain graph from m2Reg to mod2 parameters
    for o in modCOpts:
        o.step()
    # compute lossOpt and its gradient wrt client models
    lossOpt = applyAverage (modCList, x)
    modCPList = [p for m in modCList for p in m.parameters()]
    lossOptGrad = torch.autograd.grad (lossOpt, modCPList)
    # reset client models to their pre-optimization-step state
    for m, p in zip (modCList, modCParams):
        setParams (m, p)
    # compute lossReg and its gradient wrt client models
    lossReg = 0.0
    for m in modCList:
        lossReg += m (x) + m2Reg * (m.p**2).sum()   # lossReg depends on both modCList and mod2
    lossRegGrad = torch.autograd.grad (lossReg, modCPList, create_graph = True)
    # compute the vector-jacobian-product of lossOptGrad with the derivative of lossRegGrad wrt mod2
    # this -- with a factor of the optimizer lr -- gives the gradient of lossOpt wrt mod2
    # chaining these gradients together effects backpropagation through the SGD optimization step
    lossOptGradVjp = -lr * torch.autograd.grad (lossRegGrad, list (mod2.parameters()), lossOptGrad)[0]
    return lossOptGradVjp

mod2GradFD = fDGrad()
mod2GradAuto = autogradGrad()

# compare autograd result with finite-difference estimate

torch.set_printoptions (precision = 15)
print ('mod2GradFD = ...')
print (mod2GradFD)
print ('mod2GradAuto = ...')
print (mod2GradAuto)
print ('mod2GradAuto / mod2GradFD = ...')
print (mod2GradAuto / mod2GradFD)
print ('torch.allclose (mod2GradAuto, mod2GradFD, rtol = 1.e-7, atol = 1.e-12) =',
       torch.allclose (mod2GradAuto, mod2GradFD, rtol = 1.e-7, atol = 1.e-12))

Let me add a potentially useful bit of information:

It appears that (at least some of) PyTorch’s optimizers offer a feature that
may (or may not) do this for you.

As of PyTorch version 1.13, SGD (and at least also Adam) offers a differentiable constructor argument:

differentiable (bool, optional) – whether autograd should occur through the optimizer step in training. Otherwise, the step() function runs in a torch.no_grad() context. Setting to True can impair performance, so leave it False if you don’t intend to run autograd through this instance (default: False)

I’ve never tried using it and I don’t know how (or whether) it works, but it
might be worth looking into for your use case.
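Separately from that flag (which the sketch below does not use), another way to get autograd to track the step is to write the SGD update out of place, using a gradient computed with create_graph = True. The following is only a sketch under assumed toy definitions: one client parameter vector w, a vector q standing in for model_2's parameters, and the same kind of mse / tanhshrink toy loss as in the script. The names w, q, loss_fn and outer_loss are all hypothetical.

```python
import torch

torch.manual_seed(0)

lr = 0.1
x = torch.randn(4, dtype=torch.float64)   # fixed toy input

# hypothetical stand-ins: w plays one client model, q plays model_2
w = torch.randn(4, dtype=torch.float64, requires_grad=True)
q = torch.randn(4, dtype=torch.float64, requires_grad=True)


def loss_fn(params):
    # toy scalar "loss" of a parameter vector
    return torch.nn.functional.mse_loss(x, torch.nn.functional.tanhshrink(params))


def outer_loss(q):
    # scalar output of "model_2", used as the regulator weight
    reg = loss_fn(q)
    # regulated client loss
    inner = loss_fn(w) + reg * (w ** 2).sum()
    # create_graph=True keeps the gradient itself differentiable wrt q
    (g,) = torch.autograd.grad(inner, w, create_graph=True)
    # out-of-place SGD step: a new tensor, so autograd records the update
    w_new = w - lr * g
    # loss of the updated client parameters
    return loss_fn(w_new)


# gradient of the post-step loss with respect to the model_2 stand-in
(q_grad,) = torch.autograd.grad(outer_loss(q), q)
print(q_grad)
```

The trade-off is that the updated parameters are new tensors rather than in-place updates of the original ones, so this pattern does not go through a torch.optim optimizer, and the graph of the inner gradient has to be kept in memory until the outer gradient has been computed.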