How to efficiently reduce GPU memory for knowledge distillation training

I am new to Pytorch, and I am working on some knowledge distillation task:
For instance, we have a large teacher network (pre-trained with imagenet) and a small student network.
We need to freeze the teacher network and use its predictions to guide the training of the small student network. In other words, the teacher network is not trainable, but the student network will be updated.

Currently, I simply wrapped the teacher's forward pass in “with torch.no_grad()”, but training still consumes too much GPU memory. Since there is no need to backpropagate through the big teacher at all, I was wondering whether there are other ways to handle its gradients and activations.
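For reference, a minimal sketch of the setup described above, with hypothetical toy modules standing in for the pre-trained teacher and the small student (the layer sizes and the temperature T=2 are made-up illustration values):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-ins for the pre-trained teacher and the small student.
teacher = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
student = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 10))

teacher.eval()                       # inference behavior for BN/dropout layers
for p in teacher.parameters():
    p.requires_grad_(False)          # belt and braces: never request teacher grads

optimizer = torch.optim.SGD(student.parameters(), lr=0.1)

x = torch.randn(8, 32)
with torch.no_grad():                # no autograd graph is built for the teacher
    soft_targets = F.softmax(teacher(x) / 2.0, dim=1)   # temperature T=2

log_probs = F.log_softmax(student(x) / 2.0, dim=1)
loss = F.kl_div(log_probs, soft_targets, reduction="batchmean")

optimizer.zero_grad()
loss.backward()                      # gradients flow only through the student
optimizer.step()
```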

What can I do to further reduce GPU memory usage?

If you use torch.no_grad you are disabling the autograd recording of things for backpropagation, so you would have a hard time reducing memory use much further. One idea would be to do (part of?) the teacher computation in fp16 instead of fp32, but then you have to watch out for overflows.
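One way to sketch the reduced-precision idea is to run the teacher's forward under `torch.autocast` inside the `no_grad` block, so matmuls use fp16 on GPU (here falling back to bfloat16 on CPU, since CPU autocast targets bf16; the toy teacher module is again a hypothetical stand-in):

```python
import torch
import torch.nn as nn

# Hypothetical teacher stand-in; in practice this is the pre-trained ImageNet model.
teacher = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
teacher.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 on GPU; on CPU, autocast targets bfloat16, so fall back to that there.
dtype = torch.float16 if device == "cuda" else torch.bfloat16
teacher.to(device)

x = torch.randn(8, 32, device=device)
with torch.no_grad(), torch.autocast(device_type=device, dtype=dtype):
    logits = teacher(x)           # heavy ops run in reduced precision
soft_targets = logits.float()     # cast back to fp32 before computing the loss
```

Casting the logits back to fp32 before the (numerically sensitive) softmax/loss computation is one way to limit the overflow risk mentioned above.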

In the student, you could use checkpointing (i.e. keeping fewer activations between the forward and backward passes and recomputing the rest during backward) to save memory.

Best regards