I am trying to implement the Meta Pseudo Labels paper (https://arxiv.org/pdf/2003.10580v2.pdf), where a teacher model is updated based on a student model’s performance on a validation set. It would look something like this:

```python
import torch.nn.functional as F
import torch.optim as optim

modelA_optim = optim.Adam(modelA.parameters(), lr=0.01)
modelB_optim = optim.Adam(modelB.parameters(), lr=0.01)

# Model A (teacher) produces soft pseudo labels for the training batch
pseudo_labels = modelA(X_train).softmax(dim=1)

# Train model B on pseudo labels
logits = modelB(X_train)
loss1 = F.cross_entropy(logits, pseudo_labels)  # probability targets need PyTorch >= 1.10
loss1.backward()
modelB_optim.step()

# Update model A with model B's loss on a validation set
logits = modelB(X_val)
loss2 = F.cross_entropy(logits, y_val)
loss2.backward()
modelA_optim.step()
```

This in itself wouldn’t update model A’s weights with the right signal, because the gradient of loss2 cannot backpropagate through model B’s update and on to model A via the gradient of loss1.

I have tried using the higher library to make model B functional, so that I can backpropagate through its single update (because of loss1, model B’s weights go from weights_t to weights_t+1). However, torch.autograd.grad(loss2, modelA.parameters()) yields zero gradients for model A. I suspect this is because the backward graph is broken where loss1 connects model A to model B. Is there a way to address this?

I think the problem lies in the optimization step. As you can see in the SGD step implementation, the step function operates on p.data instead of p (the parameter), so the computational graph does not track these updates (Equation 1 of the paper, for instance).
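A minimal toy sketch of that effect, assuming a single scalar teacher weight and a single scalar student weight with squared-error losses standing in for the real models and cross-entropy (all names here are hypothetical). Because torch.optim.SGD updates student_w in place outside autograd, loss2 ends up with no path back to teacher_w:

```python
import torch

# Hypothetical toy setup: one teacher weight, one student weight
teacher_w = torch.tensor(2.0, requires_grad=True)
student_w = torch.nn.Parameter(torch.tensor(0.5))
opt = torch.optim.SGD([student_w], lr=0.1)

# Student loss depends on the teacher through the pseudo label
pseudo_label = teacher_w * 1.0
loss1 = (student_w - pseudo_label) ** 2
loss1.backward()
opt.step()  # in-place update that autograd does not record

# Validation loss only sees the (leaf) student weight
loss2 = (student_w - 1.0) ** 2
grad = torch.autograd.grad(loss2, teacher_w, allow_unused=True)
print(grad)  # (None,): no path from loss2 back to teacher_w
```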
You may also need to set create_graph=True in your backward call so that second-order derivatives through loss1 can be computed, since you will need them to backpropagate the gradients of loss2.
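One way to keep Equation 1 inside the graph is to write the student update out-of-place on tensors autograd tracks, computing the inner gradient with create_graph=True. A toy sketch under the same scalar assumptions as before (hypothetical names, squared error instead of cross-entropy):

```python
import torch

lr = 0.1
teacher_w = torch.tensor(2.0, requires_grad=True)  # hypothetical teacher parameter
student_w = torch.tensor(0.5, requires_grad=True)  # hypothetical student parameter

# Inner step: student loss on the teacher's pseudo label
pseudo_label = teacher_w * 1.0
loss1 = (student_w - pseudo_label) ** 2

# create_graph=True records how the gradient itself depends on teacher_w
(g,) = torch.autograd.grad(loss1, student_w, create_graph=True)
student_w_new = student_w - lr * g  # out-of-place update, tracked by autograd

# Outer step: validation loss of the updated student
loss2 = (student_w_new - 1.0) ** 2
(teacher_grad,) = torch.autograd.grad(loss2, teacher_w)
print(teacher_grad)
```

Here the updated weight is s' = s - 2·lr·(s - a) = 0.8, so ∂loss2/∂a = 2(s' - 1)·2·lr = 2(-0.2)(0.2) = -0.08, and the printed gradient should be approximately that value.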

I agree that the optimization step is where the problem is in the code snippet above (even after setting create_graph=True). I looked into the higher library (https://github.com/facebookresearch/higher), which implements a differentiable optimizer. This lets you track the parameter updates across many SGD steps. The code would look like this:

```python
import higher
import torch
import torch.nn.functional as F

with higher.innerloop_ctx(modelB, modelB_optim, copy_initial_weights=False) as (fmodel, diffopt):
    pseudo_labels = modelA(X_train).softmax(dim=1)
    # Train model B on pseudo labels (differentiable step on fmodel's weights)
    logits = fmodel(X_train)
    loss1 = F.cross_entropy(logits, pseudo_labels)
    diffopt.step(loss1)
    # Update model A with model B's loss on a validation set
    logits = fmodel(X_val)
    loss2 = F.cross_entropy(logits, y_val)
    grad = torch.autograd.grad(loss2, modelA.parameters())
```

The gradients are all zero. I suspect that because higher takes a snapshot of model B only (to make it functional), the gradients cannot be backpropagated all the way to model A.

If there were a way to extend fmodel’s graph all the way back to model A, this would probably work.
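Without higher, one way to extend the graph from the updated student back to the teacher is to build the updated student weights out-of-place and evaluate the student with torch.func.functional_call (PyTorch 2.0+). A minimal sketch, assuming two small hypothetical nn.Linear models in place of the real teacher and student, random toy data, and a single plain SGD inner step:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call

torch.manual_seed(0)
teacher = nn.Linear(4, 3)  # hypothetical small teacher
student = nn.Linear(4, 3)  # hypothetical small student
lr = 0.1

x_train = torch.randn(8, 4)
x_val = torch.randn(8, 4)
y_val = torch.randint(0, 3, (8,))

# Current student weights as a name -> tensor dict
params = dict(student.named_parameters())

# Soft pseudo labels from the teacher stay in the graph
pseudo_labels = teacher(x_train).softmax(dim=1)
loss1 = F.cross_entropy(functional_call(student, params, (x_train,)), pseudo_labels)

# One differentiable SGD step on the student
grads = torch.autograd.grad(loss1, list(params.values()), create_graph=True)
new_params = {name: p - lr * g for (name, p), g in zip(params.items(), grads)}

# Validation loss of the updated student, differentiated w.r.t. the teacher
loss2 = F.cross_entropy(functional_call(student, new_params, (x_val,)), y_val)
teacher_grads = torch.autograd.grad(loss2, list(teacher.parameters()))
print([g.norm().item() for g in teacher_grads])
```

Because new_params are ordinary tensors derived from the teacher’s output, loss2 has a path back to the teacher’s weights; the real models would need more memory, since the whole inner step is kept in the graph.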

My bad, it was supposed to be fmodel in the original snippet. However, even after using fmodel, torch.autograd.grad returns all zeros. I think this is because the backward graph doesn’t extend to model A. I’m trying to find out how to fix this so that the gradients are not zero!

You’re absolutely right! This depends on the model, mostly on its depth. I was using ResNets as teachers and students, and it appears that the second-order gradients are so small that they are effectively zero.

Thank you very much @Ouasfi for taking the time to run an example. As a side note, the gradient is zero if you interchange your teacher and student model definitions. Time to dig deeper into the gradient norms, I guess!
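To tell a broken graph apart from gradients that are merely tiny, a small helper can report per-parameter gradient norms. A sketch, assuming any scalar loss and any nn.Module (the helper name grad_norms and the toy models are hypothetical): an entry of None means there is no autograd path from the loss to that parameter, while a float near zero points to vanishing gradients.

```python
import torch
import torch.nn as nn

def grad_norms(loss, module):
    """Per-parameter gradient norms of `loss` w.r.t. `module`'s parameters."""
    names, params = zip(*module.named_parameters())
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    # None  -> loss does not depend on this parameter (graph is broken)
    # float -> gradient norm; values near zero suggest vanishing gradients
    return {n: (None if g is None else g.norm().item()) for n, g in zip(names, grads)}

# Usage on two hypothetical toy models
torch.manual_seed(0)
teacher, other = nn.Linear(2, 2), nn.Linear(2, 2)
x = torch.randn(1, 2)
print(grad_norms(other(x).sum(), teacher))           # every entry is None
print(grad_norms(other(teacher(x)).sum(), teacher))  # finite norms
```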

I’m also trying to implement Meta Pseudo Labels. This is my training loop, but I’m not sure whether teacher_loss.backward() updates the parameters in the right way.

Any advice would be much appreciated!

```python
import higher

for i in range(epochs):
    print(i)
    with higher.innerloop_ctx(student, student_optim, copy_initial_weights=False) as (fstudent, diffopt):
        pseudo_labels = teacher(x_train)  # labels from teacher
        logits = fstudent(x_train)
        student_loss = student_loss_func(logits, pseudo_labels)
        diffopt.step(student_loss)  # zero grad done implicitly
        logits = fstudent(x_val)
        teacher_loss = teacher_loss_func(logits, y_val)
        # grad = torch.autograd.grad(teacher_loss, teacher.parameters())
        # not sure how to pass grad explicitly to model parameters
        teacher_loss.backward()
        teacher_optim.step()
        teacher_optim.zero_grad()
```
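On passing the gradients explicitly: one common pattern is to write the tensors returned by torch.autograd.grad into each parameter’s .grad and then call the optimizer as usual. A minimal sketch, assuming a hypothetical nn.Linear stand-in for the teacher and a placeholder loss instead of the real MPL teacher loss (apply_grads is a hypothetical helper, not part of PyTorch):

```python
import torch
import torch.nn as nn

def apply_grads(optimizer, params, grads):
    """Write explicitly computed gradients into .grad, then take one optimizer step."""
    optimizer.zero_grad()
    for p, g in zip(params, grads):
        p.grad = g.detach()  # stored gradient should carry no graph
    optimizer.step()

torch.manual_seed(0)
teacher = nn.Linear(3, 2)  # hypothetical stand-in for the teacher
teacher_optim = torch.optim.SGD(teacher.parameters(), lr=0.1)
before = [p.detach().clone() for p in teacher.parameters()]

teacher_loss = teacher(torch.randn(4, 3)).pow(2).mean()  # placeholder teacher loss
params = list(teacher.parameters())
grads = torch.autograd.grad(teacher_loss, params)
apply_grads(teacher_optim, params, grads)

# With plain SGD (no momentum), each update is p_new = p_old - lr * g
for p, p0, g in zip(params, before, grads):
    assert torch.allclose(p, p0 - 0.1 * g)
```

If the teacher’s .grad fields start out zeroed, this should match teacher_loss.backward() followed by teacher_optim.step(); the explicit route is mainly useful when you want to inspect, clip, or rescale the gradients before stepping.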