How to calculate a gradient that depends on another gradient?

Suppose we have two networks, with parameters f and g respectively.

Both take the same input x, so that y1, y2 = f(x) (a multi-task network with two outputs) and z = g(x).

What I want to do is this: writing y1 = f1(x) and y2 = f2(x), first update f with two losses, one of which depends on the prediction from g(x); then, after that gradient update, update g based on the performance of the updated f1(x):

f' = f - \alpha * Grad(L(f1(x), y1_gt) + L(f2(x), g(x)))

g' = g - \beta * Grad(L(f1'(x), y1_gt))  # use the updated f'

\alpha, \beta are learning rates.
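(To spell out why this needs second-order gradients, my own reading of the setup: f' depends on g through the loss L(f2(x), g(x)). By the chain rule,

Grad_g(L(f1'(x), y1_gt)) = Grad_{f'}(L(f1'(x), y1_gt)) * d f' / d g,

and since f' = f - \alpha * Grad_f(L(f1(x), y1_gt) + L(f2(x), g(x))), the factor d f' / d g contains the mixed second derivative -\alpha * d^2 L(f2(x), g(x)) / (d f d g). So the computational graph of the first update has to stay alive through the second forward pass.)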

The current issue is: I don't know how to implement Grad(L(f1'(x), y1_gt)) w.r.t. g.

My current code (which turned out to be wrong) is:

# define my model1 with two outputs
class Model1(nn.Module):
    ...
    def forward(self, x):
        ...
        return y1, y2

# define my model2 with one output
class Model2(nn.Module):
    ...
    def forward(self, x):
        ...
        return z

model1 = Model1()
model2 = Model2()

def model_fit(pred, gt):
    # return a scalar loss, not a one-element list,
    # so that train_loss1 + train_loss2 below adds tensors
    return cross_entropy_loss(pred, gt)

# define optimizers:
optimizer1 = optim.SGD(model1.parameters(), lr=1e-5)
optimizer2 = optim.SGD(model2.parameters(), lr=1e-5)

# run one iteration as a simple example:
y1, y2 = model1(train_data)  # y1,y2=f(x)
z = model2(train_data)   # z=g(x)

optimizer1.zero_grad()
optimizer2.zero_grad()

train_loss1 = model_fit(y1, y1_gt)  # compute loss1 L(f1(x), y1_gt)
train_loss2 = model_fit(y2, z)  # compute loss2 L(f2(x), g(x))

(train_loss1 + train_loss2).backward(create_graph=True)  # Grad(L(f1(x), y1_gt) + L(f2(x), g(x)))
optimizer1.step()  # f' = f - \alpha * Grad(L(f1(x), y1_gt) + L(f2(x), g(x)))

# 2nd forward pass based on the updated network 1.
y1, y2 = model1(train_data)  # y1, y2 = f'(x)
train_loss1 = model_fit(y1, y1_gt)  # L(f1'(x), y1_gt)
train_loss1.backward()  # Grad(L(f1'(x), y1_gt))
optimizer2.step()  # g' = g - \beta * Grad(L(f1'(x), y1_gt))

However, the performance was not as expected: model2 only ever receives the gradient of train_loss2 from the first backward(), and the second step() just applies that same, unchanged gradient.

Update: I think I now understand what my code is doing. The first backward() computes gradients w.r.t. both model1 and model2, and optimizer1.step() updates model1 with Grad(L(f1(x), y1_gt) + L(f2(x), g(x))) w.r.t. f. However, optimizer1.step() modifies the parameters in-place, outside of autograd, and the second forward pass builds a brand-new computational graph that never touches model2; the previous graph is not retained. The gradient of the new train_loss1 w.r.t. model2 is therefore simply zero. And since gradients accumulate in .grad by default and I never call optimizer2.zero_grad() between the two backward passes, model2 is left with exactly the gradient it got in the first backward pass (Grad(L(f2(x), g(x))) w.r.t. g, plus the zero contribution of Grad(L(f1'(x), y1_gt)) w.r.t. g).
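A tiny standalone illustration of that accumulation behaviour, with a hypothetical scalar w instead of the models above:

import torch

w = torch.tensor(1.0, requires_grad=True)
(2 * w).backward()  # w.grad is now 2
(3 * w).backward()  # without zeroing, gradients accumulate: w.grad is now 5, not 3
print(w.grad)  # tensor(5.)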

Now the problem becomes: how do I keep the second forward pass connected to the graph of the first update, so that the higher-order gradients can flow back to g?
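For what it's worth, here is a minimal sketch of one way I believe this can be done, in the style of MAML-type meta-learning. It is an assumption on my part rather than a verified fix: it needs a PyTorch version that provides torch.func.functional_call (>= 2.0), and it reuses model1, model2, model_fit, train_data and y1_gt from the code above. The key idea is to avoid optimizer1.step() for the inner update, because in-place parameter updates are not differentiable; instead, build the updated parameters f' as tensors that still depend on g, and run the second forward pass through them:

import torch
from torch.func import functional_call

alpha = 1e-5  # inner learning rate (assumed to match optimizer1's lr)

# first forward pass: y1, y2 = f(x), z = g(x)
y1, y2 = model1(train_data)
z = model2(train_data)
inner_loss = model_fit(y1, y1_gt) + model_fit(y2, z)

# gradients of the inner loss w.r.t. f; create_graph=True keeps the
# dependence on g (through z) alive for the second-order term
names, params = zip(*model1.named_parameters())
grads = torch.autograd.grad(inner_loss, params, create_graph=True)

# f' = f - \alpha * grad, built out-of-place so it stays differentiable
fast_weights = {n: p - alpha * dp for n, p, dp in zip(names, params, grads)}

# 2nd forward pass through the updated network f'
y1_new, _ = functional_call(model1, fast_weights, (train_data,))
outer_loss = model_fit(y1_new, y1_gt)  # L(f1'(x), y1_gt)

# this backward now reaches model2's parameters through fast_weights
optimizer2.zero_grad()
outer_loss.backward()  # Grad(L(f1'(x), y1_gt)) w.r.t. g
optimizer2.step()  # g' = g - \beta * Grad(L(f1'(x), y1_gt))

If f itself should also end up at f' afterwards, the fast_weights can then be copied into model1 under torch.no_grad(), or a normal optimizer1.step() can be taken with the already-computed (detached) gradients.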

Any comments are welcome. Thanks!

Hi @lorenmt, have you solved this problem?