Expressing nn.Module parameter as a function of another tensor

I’m working on something closely related to the method described in the recent Meta Pseudo Labels paper. The process involves updating the weights of one network in a differentiable way and then using the same network (or, more likely, a copy of it) with the updated weights to compute another output.
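Stripped of the networks, a "differentiable weight update" is an ordinary SGD step written out-of-place. The inner gradient is computed with `create_graph=True` so that the outer loss can be differentiated through the step itself. Here is a minimal sketch on scalars. `t`, `w`, `x`, `y` and `lr` are hypothetical toy values, not anything taken from the paper.

```python
import torch

# Hypothetical toy setup: a scalar "teacher" parameter t defines the target,
# and a scalar "student" weight w takes one SGD step towards it.
t = torch.tensor(2.0, requires_grad=True)   # teacher parameter
w = torch.tensor(0.5, requires_grad=True)   # student weight
x = torch.tensor(3.0)                       # input
lr = 0.1

# Inner step: student loss against the teacher-defined target t * x.
inner_loss = (w * x - t * x) ** 2
# create_graph=True records how the gradient itself depends on w and t,
# so the update below remains differentiable.
(g,) = torch.autograd.grad(inner_loss, w, create_graph=True)
w_new = w - lr * g  # out-of-place update: w_new is a non-leaf tensor

# Outer step: evaluate the updated student on a labelled example.
y = torch.tensor(4.0)
outer_loss = (w_new * x - y) ** 2
outer_loss.backward()
print(t.grad)  # approx. 60.48: the teacher gets a gradient through w_new
```

The teacher's gradient exists only because `w_new` was built out-of-place from a graph-carrying gradient. An in-place update of `w` would either be rejected or cut that path.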

The problem is that doing this with nn.Module networks turns out to be less straightforward than I initially thought.

My current code (one training iteration) looks like this:

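```python
import copy
import torch
import torch.nn as nn

# teacher, student, lr and the data batches are defined elsewhere.
kldiv = nn.KLDivLoss(reduction="batchmean")
ce = nn.CrossEntropyLoss()

teacher_logits = teacher(unsupervised_batch)
student_logits = student(unsupervised_batch)
student_loss = kldiv(
    torch.log_softmax(student_logits, dim=1),
    torch.softmax(teacher_logits, dim=1),
)

# Here I want to obtain a new student but continue to track the gradients.
# Normally I would use p -= lr * p.grad. updated_student is just a copy of
# the original student; create_graph=True keeps the gradients differentiable.
updated_student = copy.deepcopy(student)
grads = torch.autograd.grad(student_loss, student.parameters(), create_graph=True)
for g, p2 in zip(grads, updated_student.parameters()):
    p2 -= lr * g    # this expectedly fails: p2 is a leaf tensor that requires grad

student_logits = updated_student(student_data)
student_loss = ce(student_logits, student_labels)
```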
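The self-contained reproduction below shows why the loop fails, and why an out-of-place update does not fit back into an nn.Module either. The `nn.Linear` layer is a hypothetical stand-in for one layer of `updated_student`. The comments describe my understanding of autograd and `nn.Module.__setattr__`, not an official statement of intent.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 2)  # hypothetical stand-in for one layer of updated_student
lr = 0.1
loss = layer(torch.randn(4, 3)).pow(2).mean()
grads = torch.autograd.grad(loss, list(layer.parameters()), create_graph=True)

# 1) In-place update of a leaf parameter that requires grad: RuntimeError.
try:
    layer.weight -= lr * grads[0]
    inplace_failed = False
except RuntimeError as err:
    inplace_failed = True
    print("in-place:", err)

# 2) An out-of-place update works and stays on the graph (non-leaf, has grad_fn)...
new_weight = layer.weight - lr * grads[0]
print(new_weight.is_leaf, new_weight.grad_fn is not None)  # False True

# 3) ...but nn.Module refuses to store a plain tensor under a parameter name.
try:
    layer.weight = new_weight
    assign_failed = False
except TypeError as err:
    assign_failed = True
    print("assignment:", err)

# 4) Wrapping it in nn.Parameter detaches it into a fresh leaf, cutting the graph.
wrapped = nn.Parameter(new_weight)
print(wrapped.is_leaf, wrapped.grad_fn)  # True None
```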


I could make it work by writing the student in a functional style, so that student(x, params) takes a batch of data points and the parameters explicitly. That is doable for small networks, but it would be very inconvenient for larger models. I would like to run updated_student with the updated weights in a differentiable way while staying within the regular PyTorch nn.Module paradigm.
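One workaround worth checking, assuming PyTorch 2.0 or later, is `torch.func.functional_call`. It runs an existing nn.Module with a dictionary of replacement tensors in place of its registered parameters, so the module definition itself does not need to change. The sketch below applies it to one differentiable update step. The two small MLPs, the data shapes and `lr` are hypothetical placeholders, not the Meta Pseudo Labels architecture or hyperparameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call

torch.manual_seed(0)

# Hypothetical small models and data; any nn.Module should work the same way.
teacher = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
student = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
unsupervised_batch = torch.randn(32, 8)
student_data = torch.randn(16, 8)
student_labels = torch.randint(0, 3, (16,))
lr = 0.1

# Student loss on unlabelled data against the teacher's soft targets.
teacher_logits = teacher(unsupervised_batch)
student_logits = student(unsupervised_batch)
student_loss = F.kl_div(
    F.log_softmax(student_logits, dim=1),
    F.softmax(teacher_logits, dim=1),
    reduction="batchmean",
)

# Differentiable SGD step: the gradients keep their own graph (create_graph=True).
params = dict(student.named_parameters())
grads = torch.autograd.grad(student_loss, list(params.values()), create_graph=True)
updated_params = {name: p - lr * g for (name, p), g in zip(params.items(), grads)}

# Run the *same* module with the updated tensors substituted for its parameters.
new_logits = functional_call(student, updated_params, (student_data,))
new_loss = F.cross_entropy(new_logits, student_labels)

# The loss of the updated student back-propagates all the way to the teacher.
new_loss.backward()
teacher_grad_norm = sum(p.grad.norm() for p in teacher.parameters())
print(teacher_grad_norm)
```

With this approach, `updated_student` never needs to exist as a separate module. The updated weights live in a plain dictionary of non-leaf tensors, which keeps the graph intact.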