Training two deep learning models simultaneously

So I have to train two models simultaneously, where the input of the 2nd model is the output of the 1st model,
and the output of my 2nd model should be the grad of the 1st model’s output wrt its input.

Previously I was training these 2 models independently:

  1. getting the outputs from the 1st model, then calculating their grad wrt the input;
  2. then feeding these outputs of the 1st model to my 2nd model and checking, using MSE loss, whether the grad produced by the 2nd model equals the original grad computed in the 1st step.
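The independent two-step setup described above might look roughly like the sketch below. It is a minimal illustration under assumptions: the tiny models (`model1`, `model2`), the tensor shapes, the learning rate and the number of steps are all hypothetical, and "grad of the output wrt the input" is taken to mean the gradient of `out1.sum()` with respect to the input, computed with `torch.autograd.grad`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy models: model1 maps 4 features to 3 outputs, model2 maps
# those 3 outputs back to 4 values (the shape of the input gradient).
model1 = nn.Sequential(nn.Linear(4, 3), nn.Tanh())
model2 = nn.Linear(3, 4)
opt2 = torch.optim.Adam(model2.parameters(), lr=1e-2)

x = torch.randn(8, 4, requires_grad=True)

# Step 1: model1's outputs and the gradient of their sum w.r.t. the input.
out1 = model1(x)
(grad_x,) = torch.autograd.grad(out1.sum(), x)  # shape (8, 4), does not require grad

# Step 2: train model2 to reproduce that gradient from model1's detached outputs.
for _ in range(100):
    opt2.zero_grad()
    pred = model2(out1.detach())
    loss2 = F.mse_loss(pred, grad_x)
    loss2.backward()
    opt2.step()

final_loss = loss2.item()
print(grad_x.shape, final_loss)
```

Note that `torch.autograd.grad` only returns the gradient; it does not populate `.grad` on `x` or on model1's parameters.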

Now I want to train all of this simultaneously, and I can’t think of how to do it.

I was calculating the grads like this:

input.requires_grad = True 
opx = model(input)  
dx1 =

Help will be much appreciated.

You could just do

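import torch
opt = torch.optim.Adam(list(model1.parameters()) + list(model2.parameters()), lr=3e-4)
out1 = model1(input)
out2 = model2(out1)
loss = loss_fn(out2, target)  # don't overwrite the loss function with its result
opt.zero_grad()
loss.backward()
opt.step()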

(Or even stick the two models in a torch.nn.Sequential if you want…)

This will cause loss.backward() to backpropagate through both model2 and model1’s computation.
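As a runnable illustration of the single-optimizer suggestion (and of the `nn.Sequential` variant), here is a minimal sketch. The toy `nn.Linear` models, the random data and the use of `F.mse_loss` as the loss are assumptions made only for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy models and data.
model1 = nn.Linear(4, 3)
model2 = nn.Linear(3, 2)
inputs = torch.randn(8, 4)
target = torch.randn(8, 2)

# One optimizer over the parameters of both models.
opt = torch.optim.Adam(list(model1.parameters()) + list(model2.parameters()), lr=3e-4)

opt.zero_grad()
out1 = model1(inputs)
out2 = model2(out1)
loss = F.mse_loss(out2, target)
loss.backward()  # gradients flow through model2 and then through model1
opt.step()

# The same chain expressed as a single module.
combined = nn.Sequential(model1, model2)
same = torch.allclose(combined(inputs), model2(model1(inputs)))
print(same)  # True
```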

Best regards



Thanks, Thomas.
Sorry for another question, but the problem is that the target for out2 is the grad of out1 wrt its input.
So how should I proceed then?
(I know you are wondering why I am doing this weird thing; right now I can’t share it, but if I succeed in implementing it, I would be glad to.)

Ah, sorry, I got that wrong here.
To avoid giving more bad advice: do you have a separate loss function for model1’s output, and is it only model2 that should be optimized to minimize the MSE between the gradients and its outputs? (It vaguely rings a bell that people did something like that a few years ago, but I cannot find any references now.)

Best regards



Yes, I have separate loss functions.

So do you have a separate loss function for model1’s output and it is only model2 that should be optimized to minimize the MSE between gradients and its outputs?

Yes, that’s roughly what I want to do, but I want to optimize both models simultaneously.

Then I would try something along the lines of what you suggested.

Usually I advise against using .backward for anything except model parameters (and recommend torch.autograd.grad instead), but this seems to be the exception to the rule, because we don’t want to duplicate the work of doing both grad and backward.
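To make the distinction concrete, here is a small scalar sketch (the values are arbitrary): `torch.autograd.grad` returns the requested gradient without touching `.grad`, while `.backward()` accumulates into `.grad` of every leaf that requires grad. That is why a single `loss1.backward()` can serve both model1’s parameters and the input.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
x = torch.tensor(3.0, requires_grad=True)
y = w * x ** 2

# torch.autograd.grad returns dy/dx = 2 * w * x and leaves x.grad untouched.
dy_dx, = torch.autograd.grad(y, x, retain_graph=True)
grad_before = x.grad
print(dy_dx, grad_before)  # tensor(12.) None

# .backward() fills .grad of every leaf that requires grad.
y.backward()
print(x.grad, w.grad)  # tensor(12.) tensor(9.)
```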

So here I would do something like

import torch
import torch.nn.functional as F

input.requires_grad_()  # I prefer this to input.requires_grad = True
# if you ever re-use inputs, clear the gradient left over from the previous pass
if input.grad is not None:
    with torch.no_grad():
        input.grad.zero_()
out1 = model1(input)
loss1 = loss_fn1(out1, target)
loss1.backward()  # populates grad of model1.parameters() and input

out2 = model2(out1.detach())  # model2 takes model1's output; detach so loss2 does not reach model1
# no need to use .data: input.grad will not require grad
# unless you use create_graph=True (which you should not)
loss2 = F.mse_loss(out2, input.grad)

loss2.backward()  # will populate model2.parameters() gradients

Then you’d want 1 or 2 optimizers…
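Putting the pieces together with two optimizers, one training step could look like the sketch below. Assumptions: the toy models and shapes are hypothetical, `F.mse_loss` against a random `target1` stands in for `loss_fn1`, and model2 is fed model1’s detached output, as in the original question. Each call clones the batch into a fresh leaf tensor, so `inputs.grad` starts out as `None` every time.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical shapes: 4 input features, model1 produces 3 outputs, and
# model2 maps those 3 outputs to 4 values (one per input feature).
model1 = nn.Sequential(nn.Linear(4, 3), nn.Tanh())
model2 = nn.Linear(3, 4)
opt1 = torch.optim.Adam(model1.parameters(), lr=1e-3)
opt2 = torch.optim.Adam(model2.parameters(), lr=1e-3)


def train_step(batch, target1):
    # Fresh leaf tensor, so inputs.grad starts out as None on every call.
    inputs = batch.clone().requires_grad_()

    opt1.zero_grad()
    out1 = model1(inputs)
    loss1 = F.mse_loss(out1, target1)  # stand-in for loss_fn1
    loss1.backward()  # fills model1's .grad and inputs.grad
    opt1.step()

    opt2.zero_grad()
    out2 = model2(out1.detach())  # loss2 cannot reach model1's parameters
    loss2 = F.mse_loss(out2, inputs.grad)  # target: d(loss1)/d(inputs)
    loss2.backward()
    opt2.step()
    return loss1.item(), loss2.item()


for epoch in range(3):
    l1, l2 = train_step(torch.randn(8, 4), torch.randn(8, 3))
    print(epoch, l1, l2)
```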

Now, if you indeed want the gradient of out1.sum() w.r.t. input rather than that of loss1, you would want to use gr_in, = torch.autograd.grad(out1.sum(), input, retain_graph=True) before loss1.backward().
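A short sketch of that variant, with a hypothetical model, input and target: `retain_graph=True` keeps the graph alive so that `loss1.backward()` can still run afterwards, and `gr_in` (not `inputs.grad`) is then the gradient of `out1.sum()` that model2 would be trained to predict.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model1 = nn.Sequential(nn.Linear(4, 3), nn.Tanh())  # hypothetical model
inputs = torch.randn(8, 4, requires_grad=True)
target1 = torch.randn(8, 3)  # hypothetical target for loss1

out1 = model1(inputs)
loss1 = F.mse_loss(out1, target1)

# Gradient of out1.sum() w.r.t. the input; retain_graph=True keeps the graph
# so that the backward call below can reuse it.
gr_in, = torch.autograd.grad(out1.sum(), inputs, retain_graph=True)

loss1.backward()  # still works because the graph was retained

print(gr_in.shape)  # torch.Size([8, 4])
print(inputs.grad is not None)  # True: inputs.grad holds d(loss1)/d(inputs)
```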

Is this approximately what you wanted?

Best regards



Thanks, I think this is what I was searching for.
Still, I have some silly doubts (I am still learning):

  1. For the optimizer part, I am using 2 optimizers, so should I call optimizer.step() and optimizer.zero_grad() right after calling loss.backward() for each model, or do it at the end?

  2. When we do input.requires_grad_(), what’s the difference from input.requires_grad = True?

  3. Suppose that in the 2nd epoch input.grad is still populated; if we compute the grad again, won’t it be a grad of grads?

  4. Can you explain in more detail why we are doing this?

if input.grad is not None:
    with torch.no_grad():
        input.grad.zero_()

Thanks, I am very grateful.

I don’t think it matters much, as long as for each model you keep the
zero_grad → backward → step order.
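A quick way to convince yourself of this is to run both orderings from the same initialization and compare the results. In the sketch below (toy models, data and SGD settings are hypothetical), stepping model1’s optimizer right after its own backward, or stepping both optimizers at the end, gives identical parameters, because model2 only sees model1’s detached output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def run(step_model1_early):
    torch.manual_seed(0)  # identical initialization and data for both runs
    model1 = nn.Sequential(nn.Linear(4, 3), nn.Tanh())
    model2 = nn.Linear(3, 4)
    opt1 = torch.optim.SGD(model1.parameters(), lr=0.1)
    opt2 = torch.optim.SGD(model2.parameters(), lr=0.1)
    inputs = torch.randn(8, 4, requires_grad=True)
    target1 = torch.randn(8, 3)

    opt1.zero_grad()
    opt2.zero_grad()
    out1 = model1(inputs)
    F.mse_loss(out1, target1).backward()
    if step_model1_early:
        opt1.step()  # step model1 right after its own backward
    out2 = model2(out1.detach())
    F.mse_loss(out2, inputs.grad).backward()
    if not step_model1_early:
        opt1.step()  # or step it at the very end
    opt2.step()
    params = list(model1.parameters()) + list(model2.parameters())
    return [p.detach().clone() for p in params]


early, late = run(True), run(False)
identical = all(torch.equal(a, b) for a, b in zip(early, late))
print(identical)  # True
```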

Mainly the looks, and that it throws an error when you misspell requires_grad_, rather than silently assigning some useless field.
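For example (the misspelling `requires_gard` is deliberate): assigning to a misspelled attribute silently succeeds and leaves autograd off, while the method form fails immediately.

```python
import torch

t = torch.randn(3)
returned = t.requires_grad_()  # in-place; returns the same tensor
print(t.requires_grad, returned is t)  # True True

u = torch.randn(3)
u.requires_gard = True  # typo: silently creates a new, useless attribute
print(u.requires_grad)  # False

raised = False
try:
    u.requires_gard_()  # the same typo in method form fails loudly
except AttributeError:
    raised = True
print(raised)  # True
```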

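Typically you get (batches of) your inputs from a dataloader. In that case the input tensors are re-created in every epoch, so there is no need to worry about it. (And even if an input tensor is re-used, a second backward adds the new gradient to the existing input.grad; it does not compute a grad of grads.)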
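This can be checked with a small sketch (the random `TensorDataset` below is hypothetical): every batch that comes out of a `DataLoader` is a newly collated tensor, so its `.grad` is `None` at the start of every epoch, even though the previous epoch called `backward()` on a batch with the same values.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical dataset of 6 samples with 4 features each.
dataset = TensorDataset(torch.randn(6, 4))
loader = DataLoader(dataset, batch_size=3)

grads_at_start = []
for epoch in range(2):
    for (batch,) in loader:
        batch.requires_grad_()
        grads_at_start.append(batch.grad)  # None: a freshly collated tensor
        batch.sum().backward()  # populates batch.grad for this batch only

print(grads_at_start)  # [None, None, None, None]
```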

I don’t think you need this in most cases (you could add a print statement to see if it is ever the case).
If you were to re-use inputs for whatever reason, this does the equivalent of optim.zero_grad(), but for the input (because the input isn’t in any optimizer).
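Here is a tiny scalar sketch of what happens when the same input tensor is re-used (the one-parameter "model" is hypothetical): a second `backward()` adds to the existing `x.grad` rather than computing a gradient of a gradient, and resetting `x.grad` in place before each pass, as in the `if input.grad is not None:` snippet, keeps each pass’s gradient separate.

```python
import torch

w = torch.tensor(3.0, requires_grad=True)  # hypothetical one-parameter "model"
x = torch.tensor(2.0, requires_grad=True)  # the same input, re-used in every pass

# Without zeroing, gradients from successive backward() calls are summed.
accumulated = []
for epoch in range(2):
    (w * x).backward()
    accumulated.append(x.grad.item())
print(accumulated)  # [3.0, 6.0]

# With zeroing before each pass, every pass sees only its own gradient.
x.grad = None
reset = []
for epoch in range(2):
    if x.grad is not None:
        with torch.no_grad():
            x.grad.zero_()
    (w * x).backward()
    reset.append(x.grad.item())
print(reset)  # [3.0, 3.0]
```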

Best regards

