Hi, I am working on a problem where I have two models: a teacher model (A) and a student model (B).
The teacher network generates pseudo-labels for an unlabelled training set
X1. These pseudo-labels are used as ground truth to train the student network: the student is updated based on the loss computed between the student's predictions and the pseudo-labels.
Given a labelled training set
(X2, Y2), we perform a forward pass through the updated student model B2 and compute a loss between the prediction from B2 and Y2. Rather than updating the student network again, I would now like to update the teacher model A. In other words, I want to compute the gradient of the phase-2 loss w.r.t. model A. However, in my implementation, when I call loss.backward() in phase 2, only the gradient w.r.t. model B is computed. May I ask how I can compute the gradient w.r.t. model A?
Below is a more detailed explanation, and I have pasted my code at the very end.
Given a set of inputs, X1,
```
Y_A1 = A(X1)
Y_B1 = B(X1)
loss1 = crossEntropy(Y_B1, Y_A1)
loss1.backward()
```
Calling loss1.backward() computes the gradient of
loss1 w.r.t. all the parameters in both A and B. Then, to update only model B,
```
optimB.step()   # B -> B2
```
is called, where optimB can be something like
```
optimB = torch.optim.SGD(B.parameters(), lr=0.01)
```
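A quick way to confirm that phase 1 populates gradients for both models (using the model/variable names from the code pasted at the end) would be:

```python
# After phase 1's backward pass, both A and B should hold gradients, because
# loss1 depends on Y_A1 = modelA(X1) as well as on Y_B1 = modelB(X1).
loss1.backward()
print(all(p.grad is not None for p in modelA.parameters()))  # True
print(all(p.grad is not None for p in modelB.parameters()))  # True
optimB.step()  # in-place SGD update: B -> B2
```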
Then, given another set of inputs X2 and labels Y2,
```
Y_B2 = B2(X2)
loss2 = crossEntropy(Y_B2, Y2)
loss2.backward()
```
I was under the assumption that by calling
loss2.backward(), the gradients w.r.t. the parameters of model A would also be computed (involving a second-order derivative of
loss1 w.r.t. model A), because
```
Y_B2 = B2(X2)
B2 = B - lr * dloss1/dB
```

To update A -> A2:

```
A2 = A - lr * dloss2/dA
   = A - lr * d(crossEntropy(Y_B2, Y2))/dA
   = A - lr * d(crossEntropy(B2(X2), Y2))/dA
```
Therefore, in order to update A we need to compute
d(dloss1/dB)/dA. However, in my implementation, when I call
loss2.backward(), only the gradient
dloss2/dB is computed.
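Spelling out the chain rule as I understand it (B2 depends on A only through the phase-1 update):

```
dloss2/dA = dloss2/dB2 * dB2/dA
          = dloss2/dB2 * d(B - lr * dloss1/dB)/dA
          = -lr * dloss2/dB2 * d(dloss1/dB)/dA
```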
I have pasted my code below:
```python
import torch
import torch.nn as nn

X1 = torch.randn(5, 2, dtype=torch.float32)
X2 = torch.randn(5, 2, dtype=torch.float32)
Y2 = torch.randint(0, 2, (5, 1), dtype=torch.float32)

# define linear model
class linearModel(nn.Module):
    def __init__(self):
        super(linearModel, self).__init__()
        self.layer1 = nn.Linear(2, 2, bias=False)
        self.layer2 = nn.Linear(2, 1, bias=False)
        self.activation = nn.Tanh()

    def forward(self, x):
        x = self.layer1(x)
        x = self.activation(x)
        x = self.layer2(x)
        return torch.sigmoid(x)

# define loss function: binary cross entropy
def loss_fcn(pred, target):
    loss = -torch.mean(target * torch.log(pred) + (1 - target) * torch.log(1 - pred))
    return loss

modelA = linearModel()
modelB = linearModel()
optimA = torch.optim.SGD(modelA.parameters(), lr=0.01)
optimB = torch.optim.SGD(modelB.parameters(), lr=0.01)

# Phase 1: teacher A produces pseudo-labels, student B is trained on them
Y_A1 = modelA(X1)
Y_B1 = modelB(X1)
loss1 = loss_fcn(Y_B1, Y_A1)

# Compute gradients and update model B only
loss1.backward()
optimB.step()

# Phase 2: evaluate the updated student on the labelled data (X2, Y2)
Y_B2 = modelB(X2)  # model B has been updated
loss2 = loss_fcn(Y_B2, Y2)

# Compute gradients w.r.t. model A and update model A
loss2.backward()
optimA.step()
```
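For reference, this is the kind of differentiable student update I think might be needed so that loss2 stays connected to model A. It is only a sketch of my current understanding (it assumes torch.func.functional_call is available, i.e. a recent PyTorch), so please correct me if this is the wrong direction:

```python
# Sketch of a differentiable inner update: instead of optimB.step(), build B2's
# parameters as tensors that still depend on model A through loss1.
from torch.func import functional_call

lr_inner = 0.01

# Phase 1: keep the graph when computing dloss1/dB so the update is differentiable w.r.t. A
Y_A1 = modelA(X1)
Y_B1 = modelB(X1)
loss1 = loss_fcn(Y_B1, Y_A1)
gradsB = torch.autograd.grad(loss1, tuple(modelB.parameters()), create_graph=True)

# B2 = B - lr * dloss1/dB, kept as a dict of tensors that are functions of A
paramsB2 = {name: p - lr_inner * g
            for (name, p), g in zip(modelB.named_parameters(), gradsB)}

# Phase 2: run the updated student functionally, then backprop into A
Y_B2 = functional_call(modelB, paramsB2, (X2,))
loss2 = loss_fcn(Y_B2, Y2)

optimA.zero_grad()
loss2.backward()   # gradients should now reach modelA.parameters()
optimA.step()
```

Is something like this functional update necessary to obtain d(dloss1/dB)/dA, or is there a way to make loss2.backward() reach model A while keeping the optimizer-based update of B?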