Hi, I am working on a problem where I have two models: a teacher model (A) and a student model (B).
The teacher network generates pseudo-labels for an unlabelled training set
X1. These pseudo-labels are used as ground truth to train the student network: the student is updated based on the loss computed between the student's predictions and the pseudo-labels.
Given a labelled training set
(X2, Y2), we perform a forward pass through the updated student model B2 and compute a loss between the prediction from B2 and Y2. Rather than updating the student network again, I would now like to update the teacher model A. In other words, I want to compute the gradient of the phase-2 loss w.r.t. model A. However, in my implementation, when I call loss.backward() in phase 2, only the gradient w.r.t. model B is computed. May I ask how I can compute the gradient w.r.t. model A?
Below is a more detailed explanation, and I have pasted my code at the very end.
Given a set of inputs, X1,
```
Y_A1 = A(X1)
Y_B1 = B(X1)
loss1 = crossEntropy(Y_B1, Y_A1)
loss1.backward()
```
Calling loss1.backward() computes the gradient of
loss1 w.r.t. all the parameters in both A and B. Then, to update only model B,
```
optimB.step()   # B -> B2
```
is called, where optimB can be something like
```
optimB = torch.optim.SGD(B.parameters(), lr=0.01)
```
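A quick way to confirm that phase 1 populates gradients for both models (using the model/variable names from the code pasted at the end) would be:

```python
# After phase 1's backward pass, both A and B should hold gradients, because
# loss1 depends on Y_A1 = modelA(X1) as well as on Y_B1 = modelB(X1).
loss1.backward()
print(all(p.grad is not None for p in modelA.parameters()))  # True
print(all(p.grad is not None for p in modelB.parameters()))  # True
optimB.step()  # in-place SGD update: B -> B2
```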
Then, given another set of inputs X2 and labels Y2,
```
Y_B2 = B2(X2)
loss2 = crossEntropy(Y_B2, Y2)
loss2.backward()
```
I was under the assumption that by calling
loss2.backward(), the gradients w.r.t. the parameters of model A would also be computed (involving a second-order derivative of
loss1 w.r.t. model A), because
```
Y_B2 = B2(X2)
B2 = B - lr * dloss1/dB
```

To update A -> A2:

```
A2 = A - lr * dloss2/dA
   = A - lr * d(crossEntropy(Y_B2, Y2))/dA
   = A - lr * d(crossEntropy(B2(X2), Y2))/dA
```
Therefore, in order to update A we need to compute
d(dloss1/dB)/dA. However, in my implementation, when I call
loss2.backward(), only the gradient
dloss2/dB is computed.
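Spelling out the chain rule as I understand it (B2 depends on A only through the phase-1 update):

```
dloss2/dA = dloss2/dB2 * dB2/dA
          = dloss2/dB2 * d(B - lr * dloss1/dB)/dA
          = -lr * dloss2/dB2 * d(dloss1/dB)/dA
```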
I have pasted my code below:
```python
import torch
import torch.nn as nn

X1 = torch.randn(5, 2, dtype=torch.float32)
X2 = torch.randn(5, 2, dtype=torch.float32)
Y2 = torch.randint(0, 2, (5, 1), dtype=torch.float32)

# define linear model
class linearModel(nn.Module):
    def __init__(self):
        super(linearModel, self).__init__()
        self.layer1 = nn.Linear(2, 2, bias=False)
        self.layer2 = nn.Linear(2, 1, bias=False)
        self.activation = nn.Tanh()

    def forward(self, x):
        x = self.layer1(x)
        x = self.activation(x)
        x = self.layer2(x)
        return torch.sigmoid(x)

# define loss function: binary cross entropy
def loss_fcn(pred, target):
    loss = -torch.mean(target * torch.log(pred) + (1 - target) * torch.log(1 - pred))
    return loss

modelA = linearModel()
modelB = linearModel()
optimA = torch.optim.SGD(modelA.parameters(), lr=0.01)
optimB = torch.optim.SGD(modelB.parameters(), lr=0.01)

# Phase 1: teacher A produces pseudo-labels, student B is trained on them
Y_A1 = modelA(X1)
Y_B1 = modelB(X1)
loss1 = loss_fcn(Y_B1, Y_A1)

# Compute gradients and update model B only
loss1.backward()
optimB.step()

# Phase 2: evaluate the updated student on the labelled data (X2, Y2)
Y_B2 = modelB(X2)  # model B has been updated
loss2 = loss_fcn(Y_B2, Y2)

# Compute gradients w.r.t. model A and update model A
loss2.backward()
optimA.step()
```
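For reference, this is the kind of differentiable student update I think might be needed so that loss2 stays connected to model A. It is only a sketch of my current understanding (it assumes torch.func.functional_call is available, i.e. a recent PyTorch), so please correct me if this is the wrong direction:

```python
# Sketch of a differentiable inner update: instead of optimB.step(), build B2's
# parameters as tensors that still depend on model A through loss1.
from torch.func import functional_call

lr_inner = 0.01

# Phase 1: keep the graph when computing dloss1/dB so the update is differentiable w.r.t. A
Y_A1 = modelA(X1)
Y_B1 = modelB(X1)
loss1 = loss_fcn(Y_B1, Y_A1)
gradsB = torch.autograd.grad(loss1, tuple(modelB.parameters()), create_graph=True)

# B2 = B - lr * dloss1/dB, kept as a dict of tensors that are functions of A
paramsB2 = {name: p - lr_inner * g
            for (name, p), g in zip(modelB.named_parameters(), gradsB)}

# Phase 2: run the updated student functionally, then backprop into A
Y_B2 = functional_call(modelB, paramsB2, (X2,))
loss2 = loss_fcn(Y_B2, Y2)

optimA.zero_grad()
loss2.backward()   # gradients should now reach modelA.parameters()
optimA.step()
```

Is something like this functional update necessary to obtain d(dloss1/dB)/dA, or is there a way to make loss2.backward() reach model A while keeping the optimizer-based update of B?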