How to efficiently reduce GPU memory for knowledge distillation training

I am new to PyTorch, and I am working on a knowledge distillation task:
For instance, we have a large teacher network (pre-trained on ImageNet) and a small student network.
We need to freeze the teacher network and use its predictions to guide the training of the small student network. In other words, the teacher network is not trainable, but the student network will be updated.

Currently, I simply wrap the teacher's forward pass in “with torch.no_grad()”, but training still consumes too much GPU memory. Since there is no need to backpropagate through the big teacher at all, I was wondering whether there are other ways to handle the gradients and activations.
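Here is a simplified sketch of my current training step (`teacher`, `student`, `optimizer`, and the temperature are placeholders for my actual setup):

```python
import torch
import torch.nn.functional as F

def train_step(teacher, student, optimizer, images, T=4.0):
    with torch.no_grad():                 # teacher runs without autograd tracking
        teacher_logits = teacher(images)

    student_logits = student(images)      # student keeps activations for backprop

    # KL divergence between softened teacher and student distributions
    loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```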

What can I do to further reduce GPU memory usage?

If you use torch.no_grad, you are already disabling autograd's recording of intermediate results for backpropagation, so you will have a hard time reducing the teacher's memory use much further. One idea would be to do (part of?) the teacher computation in fp16 instead of fp32, but you have to watch out for overflows.
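For example (a minimal sketch, assuming `teacher` and `images` are the ones from your training step; casting to half precision roughly halves the teacher's weight and activation memory):

```python
import torch

# Convert the frozen teacher to fp16 once, before training starts.
teacher = teacher.half().eval()

with torch.no_grad():
    # Run the teacher in fp16 and cast the logits back to fp32 so the
    # distillation loss is computed in full precision. Check the outputs
    # for inf/nan if you suspect overflow.
    teacher_logits = teacher(images.half()).float()
```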

In the student, you could use checkpointing (i.e. keeping fewer activations between the forward and backward passes and recomputing them during backward) to save memory.
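A rough sketch of what that could look like (the sub-module names `stem`, `block1`, `block2`, and `head` are hypothetical stand-ins for however your student is actually split up):

```python
import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedStudent(torch.nn.Module):
    def __init__(self, stem, block1, block2, head):
        super().__init__()
        self.stem = stem      # first layers, run normally
        self.block1 = block1  # checkpointed segments
        self.block2 = block2
        self.head = head      # classifier / final layers

    def forward(self, x):
        # Run the stem normally so its output requires grad; checkpointing
        # the very first segment on a no-grad input can silently drop
        # parameter gradients with the default (reentrant) implementation.
        x = self.stem(x)
        # Inside each checkpointed block only the block input is stored;
        # intermediate activations are recomputed during the backward pass.
        x = checkpoint(self.block1, x)
        x = checkpoint(self.block2, x)
        return self.head(x)
```

The trade-off is roughly one extra forward pass through the checkpointed blocks per training step.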

Best regards

Thomas