How to compute the gradient of gradient if I have two models?

Hi, I am working on a problem with two models: a teacher model (A) and a student model (B).

Phase 1

The teacher network is used to generate pseudo-labels for an unlabelled training set X1. The pseudo-labels are used as ground truth to train the student network: the student is updated with the loss between its own predictions and the pseudo-labels.

Phase 2

Given a labelled training set (X2, Y2), we run a forward pass through the updated student model B2 and compute a loss between B2's predictions and Y2. Rather than updating the student network again, I would now like to update the teacher model A. In other words, I would like to compute the gradient of the phase 2 loss w.r.t. model A. However, in my implementation, calling loss.backward() in phase 2 only computes the gradient w.r.t. model B. How can I compute the gradient w.r.t. model A?

Below is a much more detailed explanation and I have also pasted my code at the very end.

Given a set of inputs, X1,

Y_A1 = A(X1)
Y_B1 = B(X1)
loss1 = crossEntropy(Y_A1, Y_B1)

Calling loss1.backward() computes the gradient of loss1 w.r.t. all the parameters in A and B. Then, to update only model B (B -> B2), optimB.step() is called, where optimB can be something like

optimB = torch.optim.SGD(B.parameters(), lr=0.01)
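For plain SGD (no momentum, no weight decay), optimB.step() applies exactly B2 = B - lr * dloss1/dB, but in place and outside autograd. A small check, using a stand-in nn.Linear student and a made-up loss purely for illustration (neither is from the original code):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B = nn.Linear(2, 1, bias=False)          # stand-in for the student model
optimB = torch.optim.SGD(B.parameters(), lr=0.01)

X1 = torch.randn(5, 2)
loss1 = B(X1).pow(2).mean()              # stand-in loss, for illustration only

optimB.zero_grad()
loss1.backward()

# Out-of-place version of the SGD rule: B2 = B - lr * dloss1/dB
expected = B.weight.detach() - 0.01 * B.weight.grad

optimB.step()                            # updates B.weight in place, not recorded by autograd

print(torch.allclose(B.weight, expected))   # True
print(B.weight.grad_fn)                     # None: B is still a leaf tensor
```

Because the update happens in place and is not recorded, the new weights remain leaf tensors with no history of how they were computed.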

Then, given another set of inputs X2 and labels Y2:

Y_B2 = B2(X2)
loss2 = crossEntropy(Y_B2, Y2)

I assumed that calling loss2.backward() would also compute the gradient w.r.t. the parameters of model A (which involves a second-order derivative of loss1, taken w.r.t. B and then A), because

Y_B2 = B2(X2)
B2 = B - lr * dloss1/dB

To update A -> A2
A2 = A - lr * dloss2/dA
A2 = A - lr * d(crossEntropy(Y_B2, Y2))/dA
A2 = A - lr * d(crossEntropy(B2(X2), Y2))/dA
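Writing the dependence out with the chain rule (B itself does not depend on A, so only the gradient term inside B2 contributes; for parameter tensors the product is a vector-Jacobian product):

```latex
B_2 = B - \mathrm{lr}\,\frac{\partial\,\mathrm{loss}_1}{\partial B}
\qquad\Longrightarrow\qquad
\frac{\partial\,\mathrm{loss}_2}{\partial A}
= \frac{\partial\,\mathrm{loss}_2}{\partial B_2}\,\frac{\partial B_2}{\partial A}
= -\mathrm{lr}\,\frac{\partial\,\mathrm{loss}_2}{\partial B_2}\,
  \frac{\partial}{\partial A}\!\left(\frac{\partial\,\mathrm{loss}_1}{\partial B}\right)
```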

Therefore, to update A we need to compute d(dloss1/dB)/dA. However, in my implementation, calling loss2.backward() only computes dloss2/dB.
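One way to get this second-order term in PyTorch is to write B's update functionally instead of calling optimB.step(): take dloss1/dB with torch.autograd.grad(..., create_graph=True), build B2 = B - lr * dloss1/dB as new tensors, and evaluate the student with those tensors via torch.func.functional_call. This is a sketch, assuming PyTorch >= 2.0 (for torch.func), plain SGD for the inner step, and pseudo-labels that are not detached from A; make_model and lr_B are hypothetical names, and the model is an nn.Sequential with the same layers as linearModel:

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # PyTorch >= 2.0

torch.manual_seed(0)

def make_model():
    # Hypothetical helper: same layers as linearModel (Linear -> Tanh -> Linear -> sigmoid)
    return nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh(),
                         nn.Linear(2, 1, bias=False), nn.Sigmoid())

def loss_fcn(pred, target):
    # Binary cross-entropy; target may be a soft pseudo-label in [0, 1]
    return -torch.mean(target * torch.log(pred) + (1 - target) * torch.log(1 - pred))

X1, X2 = torch.randn(5, 2), torch.randn(5, 2)
Y2 = torch.randint(0, 2, (5, 1)).float()

modelA, modelB = make_model(), make_model()
optimA = torch.optim.SGD(modelA.parameters(), lr=0.01)
lr_B = 0.01  # hypothetical inner (student) learning rate

# Phase 1: pseudo-labels from A are NOT detached, so loss1 depends on A
loss1 = loss_fcn(modelB(X1), modelA(X1))
names, params_B = zip(*modelB.named_parameters())
# create_graph=True keeps dloss1/dB differentiable (needed for the second-order term)
grads_B = torch.autograd.grad(loss1, params_B, create_graph=True)
# Functional SGD step B2 = B - lr * dloss1/dB, recorded in the autograd graph
params_B2 = {n: p - lr_B * g for n, p, g in zip(names, params_B, grads_B)}

# Phase 2: run B2 on the labelled data without overwriting modelB's weights
Y_B2 = functional_call(modelB, params_B2, (X2,))
loss2 = loss_fcn(Y_B2, Y2)

optimA.zero_grad()
loss2.backward()   # flows loss2 -> B2 -> dloss1/dB -> pseudo-labels -> modelA
optimA.step()

print(modelA[0].weight.grad is not None)  # True
```

Note that loss2.backward() also writes gradients into modelB's parameters (B2 depends on B), so zero them before updating the student itself. Keeping the graph of the inner step costs extra memory, which grows if you unroll more than one student update.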

I have pasted my code below:

import torch
import torch.nn as nn

X1 = torch.randn(5,2, dtype=torch.float32)
X2 = torch.randn(5,2, dtype=torch.float32)
Y2 = torch.randint(0,2, (5,1), dtype=torch.float32)

# define linear model
class linearModel(nn.Module):
    def __init__(self):
        super(linearModel, self).__init__()

        self.layer1 = nn.Linear(2, 2, bias=False)
        self.layer2 = nn.Linear(2, 1, bias=False)
        self.activation = nn.Tanh()

    def forward(self, x):
        x = self.layer1(x)
        x = self.activation(x)
        x = self.layer2(x)
        return torch.sigmoid(x)

# define loss function: binary cross entropy (pred must lie strictly in (0, 1))
def loss_fcn(pred, target):
    loss = -torch.mean(target*torch.log(pred) + (1-target)*torch.log(1-pred))
    return loss

modelA = linearModel()
modelB = linearModel() 

optimA = torch.optim.SGD(modelA.parameters(), lr=0.01)
optimB = torch.optim.SGD(modelB.parameters(), lr=0.01)

# Phase 1
Y_A1 = modelA(X1)
Y_B1 = modelB(X1)
loss1 = loss_fcn(Y_B1, Y_A1)

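# Compute gradient w.r.t. B and update modelB
optimB.zero_grad()
loss1.backward()
optimB.step()

# Phase 2
Y_B2 = modelB(X2)   # model B has been updated
loss2 = loss_fcn(Y_B2, Y2)
# Compute gradients w.r.t. model A and update model A
optimA.zero_grad()
loss2.backward()
optimA.step()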

However, modelA.layer1.weight.grad returns None.
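Separately from the None gradient: binary cross-entropy is -mean(t*log(p) + (1-t)*log(1-p)), with the two log terms added before negating. A quick check, assuming predictions strictly inside (0, 1), that this form matches PyTorch's built-in torch.nn.functional.binary_cross_entropy (which additionally clamps its log terms at -100):

```python
import torch
import torch.nn.functional as F

def loss_fcn(pred, target):
    # Binary cross-entropy: the two log terms are added, then negated
    return -torch.mean(target * torch.log(pred) + (1 - target) * torch.log(1 - pred))

torch.manual_seed(0)
pred = torch.sigmoid(torch.randn(5, 1))          # probabilities in (0, 1)
target = torch.randint(0, 2, (5, 1)).float()

manual = loss_fcn(pred, target)
builtin = F.binary_cross_entropy(pred, target)   # default reduction='mean'
print(torch.allclose(manual, builtin))           # True
```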

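Based on your code, modelA's parameters are not part of the computation graph of loss2, so modelA.layer1.weight.grad is None when you print it. optimB.step() overwrites B's parameters in place without recording the update in autograd, so the updated B2 carries no dependence on A.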
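To see this concretely, here is a small sketch with stand-in nn.Linear models and made-up losses (for illustration only, not the original code): after the usual backward()/step() on B, asking autograd for the gradient of loss2 w.r.t. A's parameters returns None.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
A = nn.Linear(2, 1, bias=False)   # stand-in teacher
B = nn.Linear(2, 1, bias=False)   # stand-in student
optimB = torch.optim.SGD(B.parameters(), lr=0.01)

X1, X2 = torch.randn(5, 2), torch.randn(5, 2)

# Phase 1 with the usual backward()/step(); loss1 does depend on A
loss1 = (B(X1) - torch.sigmoid(A(X1))).pow(2).mean()   # stand-in loss
optimB.zero_grad()
loss1.backward()
optimB.step()          # in-place update: B stays a leaf with no link to A

# Phase 2: loss2 depends on B's new weights, but autograd sees no path to A
loss2 = B(X2).pow(2).mean()
grads_A = torch.autograd.grad(loss2, list(A.parameters()), allow_unused=True)
print(grads_A)         # (None,)
```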

Hi, thanks for your reply. Is there a way to manually compute the gradients of loss2 w.r.t. modelA's parameters? Thank you!

In this case, you should pass X2 to modelA():

Y_A2 = modelA(X2)
loss2 = loss_fcn(Y_A2, Y2)
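A runnable version of this suggestion, assuming the same architecture as linearModel (written as nn.Sequential for brevity) and binary cross-entropy. Note that it trains A directly on (X2, Y2); the gradient does not pass through B's update.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
modelA = nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh(),
                       nn.Linear(2, 1, bias=False), nn.Sigmoid())
optimA = torch.optim.SGD(modelA.parameters(), lr=0.01)

def loss_fcn(pred, target):
    # Binary cross-entropy
    return -torch.mean(target * torch.log(pred) + (1 - target) * torch.log(1 - pred))

X2 = torch.randn(5, 2)
Y2 = torch.randint(0, 2, (5, 1)).float()

Y_A2 = modelA(X2)            # A is now part of loss2's graph
loss2 = loss_fcn(Y_A2, Y2)

optimA.zero_grad()
loss2.backward()
print(modelA[0].weight.grad is not None)   # True
optimA.step()
```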