I’m working on something related to the recent Meta Pseudo Labels paper (https://arxiv.org/abs/2003.10580). The process involves updating the weights of one network in a differentiable way and then using the same network (or a copy of it) with the updated weights to compute another output.

The problem is that this turns out to be less straightforward than I initially thought…

My current code (one training iteration) looks like this:

```
teacher_logits = teacher(unsupervised_batch)
student_logits = student(unsupervised_batch)
# kldiv and ce are nn.KLDivLoss / nn.CrossEntropyLoss instances defined elsewhere
student_loss = kldiv(
    torch.log_softmax(student_logits, dim=1),
    torch.softmax(teacher_logits, dim=1),
)
student_loss.backward(create_graph=True)
# Clear the teacher grads produced by the first backward; the teacher
# should only be updated through the student's performance below.
teacher_optimizer.zero_grad()
# Here I want to obtain a new student but continue to track the gradients.
# Normally I would use p2.data -= lr * p1.grad.
# updated_student is just a copy of the original student.
for p1, p2 in zip(student.parameters(), updated_student.parameters()):
    p2 -= lr * p1.grad  # this expectedly fails, as p2 is a leaf variable
student_logits = updated_student(student_data)
student_logits.squeeze_(dim=1)
student_loss = ce(student_logits, student_labels)
student_loss.backward()
teacher_optimizer.step()
```
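The failing line can be reproduced in isolation. A minimal sketch (a single toy parameter standing in for the student's weights, not the actual model): the in-place update on a leaf parameter raises a `RuntimeError`, while building the updated weight as a *new* out-of-place tensor produces a non-leaf result that still tracks the dependence on the original parameter, so a later loss can backprop through the update.

```python
import torch

w = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
loss = (w ** 2).sum()
loss.backward(create_graph=True)  # w.grad itself stays in the graph

failed = False
try:
    w -= 0.1 * w.grad  # in-place update on a leaf that requires grad
except RuntimeError as e:
    failed = True
    print("in-place update failed:", e)

# Out-of-place update: w_new is a non-leaf tensor whose history includes
# both w and w.grad, so a second loss can differentiate through the update.
w_new = w - 0.1 * w.grad
print(w_new.requires_grad, w_new.is_leaf)
```

The catch, of course, is that `w_new` is a plain tensor, not a `Parameter`, so it cannot simply be assigned back into a module, which is exactly the inconvenience described below.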

I could make it work if I wrote the student in a functional fashion, such that `student(x, params)` is a function taking a batch of data points and the parameters. While this is doable for small networks, it would be very inconvenient for larger models. I wish I could run `updated_student` with the updated weights in a differentiable fashion while staying within the regular PyTorch `nn.Module` paradigm.
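For what it's worth, newer PyTorch versions (2.x; `torch.nn.utils.stateless` in 1.12+) expose exactly this as `torch.func.functional_call`, which runs an unmodified `nn.Module` with an externally supplied parameter dict. A minimal sketch of the pattern above, with a tiny placeholder model, learning rate, and random data standing in for the real student and batches:

```python
import torch
from torch.func import functional_call  # PyTorch >= 2.0

student = torch.nn.Linear(4, 2)  # placeholder for the real student
x = torch.randn(3, 4)            # placeholder batch
lr = 0.1

# Inner loss; create_graph=True keeps the gradients differentiable.
loss = student(x).pow(2).mean()
grads = torch.autograd.grad(loss, tuple(student.parameters()),
                            create_graph=True)

# Build updated (non-leaf) parameters without touching student's own weights.
updated_params = {
    name: p - lr * g
    for (name, p), g in zip(student.named_parameters(), grads)
}

# Forward through the *same* module with the updated weights; gradients
# still flow back through the inner update to the original parameters.
out = functional_call(student, updated_params, (x,))
second_loss = out.pow(2).mean()
second_loss.backward()  # populates student.weight.grad / student.bias.grad
```

In the Meta Pseudo Labels setting, `second_loss` would be the supervised loss of the updated student, and the backward pass would deposit gradients on the teacher's parameters instead.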