I’m working on some stuff related to what is described in the recent Meta Pseudo Labels paper (https://arxiv.org/abs/2003.10580). The process involves updating the weights of one network in a differentiable way and then using the same network (or probably a copy of it) with updated weights to compute another output.
The problem is that this turned out to be less straightforward than I initially thought…
My current code (one training iteration) looks like this:
```python
teacher_logits = teacher(unsupervised_batch)
student_logits = student(unsupervised_batch)

student_loss = kldiv(
    torch.log_softmax(student_logits, dim=1),
    torch.softmax(teacher_logits, dim=1)
)
student_loss.backward(create_graph=True)
teacher_optimizer.zero_grad()

# Here I want to obtain a new student but continue to track the gradients.
# Normally I would use p2.data -= lr * p1.grad.
# Updated student is just a copy of the original student.
for p1, p2 in zip(student.parameters(), updated_student.parameters()):
    p2 -= lr * p1.grad  # this expectedly fails as p2 is a leaf variable

student_logits = updated_student(student_data)
student_logits.squeeze_(dim=1)
student_loss = ce(student_logits, student_labels)
student_loss.backward()
teacher_optimizer.step()
```
I could make it work if I wrote the student in a functional fashion, so that `student(x, params)` is a function that takes a batch of data points and the parameters. While this is doable for small networks, it would be super inconvenient for larger models. I wish I could run `updated_student` with the updated weights in a differentiable fashion and stay within the regular PyTorch `nn.Module` paradigm.
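For illustration, here is a minimal sketch of one way this might be done without rewriting the student by hand, assuming PyTorch 2.0 or later, where `torch.func.functional_call` can run an ordinary `nn.Module` with a substituted set of parameter tensors. The toy `nn.Linear` models, the random data, and the learning rate are hypothetical placeholders. `F.kl_div` and `F.cross_entropy` stand in for `kldiv` and `ce` from the snippet above. I have not checked this against the paper's exact procedure, so treat it as a starting point and not a reference implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call  # assumes PyTorch >= 2.0

torch.manual_seed(0)

# Hypothetical toy models and data, only there to exercise the update.
teacher = nn.Linear(4, 3)
student = nn.Linear(4, 3)
teacher_optimizer = torch.optim.SGD(teacher.parameters(), lr=0.1)
lr = 0.1  # assumed inner (student) learning rate

unsupervised_batch = torch.randn(8, 4)
student_data = torch.randn(8, 4)
student_labels = torch.randint(0, 3, (8,))

# Student loss against the teacher's soft labels.
teacher_logits = teacher(unsupervised_batch)
student_logits = student(unsupervised_batch)
student_loss = F.kl_div(
    torch.log_softmax(student_logits, dim=1),
    torch.softmax(teacher_logits, dim=1),
    reduction="batchmean",
)

# Gradients w.r.t. the student parameters; create_graph=True keeps them
# differentiable with respect to the teacher's output.
params = dict(student.named_parameters())
grads = torch.autograd.grad(
    student_loss, list(params.values()), create_graph=True
)

# Out-of-place SGD step: the results are non-leaf tensors, so no
# in-place modification of a leaf variable is needed.
updated_params = {
    name: p - lr * g for (name, p), g in zip(params.items(), grads)
}

# Run the unchanged module with the updated weights.
updated_logits = functional_call(student, updated_params, (student_data,))
meta_loss = F.cross_entropy(updated_logits, student_labels)

# Backpropagate through the student update into the teacher.
teacher_optimizer.zero_grad()
meta_loss.backward()
teacher_optimizer.step()
```

Note that `meta_loss.backward()` also accumulates gradients into the original student parameters, so those would need to be zeroed before the student's own optimizer step. The actual student update (copying or stepping the real weights) is left out of this sketch.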