Memory usage during backpropagation

I’m hoping to use a very large, pretrained, 80-layer transformer-based language model for a new task. My plan is to fine-tune the language model on my dataset and then use activations from its 40th layer to train simple linear probes on the new task.

For fine-tuning, I don’t have the memory capacity to train the whole model: I can only store gradients for one layer at a time on my GPU. So, I’m planning to set param.requires_grad = True for just the 40th layer of the model, and False for the rest. If I did this, would PyTorch still hold the gradients for the 41st–80th layers in memory for backpropagation? Or would it be able to delete gradients as it backpropagates? I’m thinking it should be possible to calculate gradients for the 80th layer, backpropagate gradients to the 79th, delete gradients for the 80th, and so on, so that only one transformer layer’s gradients are stored at a time.
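Roughly, on a toy stack of blocks, the freezing part of my plan would look like this (just a sketch; the blocks of a real pretrained model live under a model-specific attribute):

```python
import torch.nn as nn

# Toy stand-in for the 80-layer model: 8 linear "blocks".
model = nn.Sequential(*[nn.Linear(16, 16) for _ in range(8)])

train_idx = 3  # stand-in for "the 40th layer"
for i, block in enumerate(model):
    for param in block.parameters():
        # Only the chosen block's parameters will accumulate .grad
        # during backward.
        param.requires_grad = (i == train_idx)
```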

Hi @Marco_Conati

Setting requires_grad_(False) detaches the weights from the computational graph, so they will not be used in backpropagation.

only one transformer layer’s gradients are stored at a time

See torch.utils.checkpoint for this. Hope this helps!
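A minimal sketch of checkpointing a single block (assuming a recent PyTorch; older versions don’t accept the use_reentrant argument):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Checkpoint one block: its forward activations are not kept; they are
# recomputed during the backward pass, trading extra compute for memory.
block = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()
```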

Oh checkpoint looks very helpful. Thanks!

How does PyTorch calculate gradients for parameters earlier in the model if requires_grad is set to False at later layers?

For example, if I had a sequential model of linear layers and set requires_grad_(False) at the final layer, how can earlier layers backpropagate gradients? Doesn’t it break the chain rule?

The gradients of earlier parameters do not depend on the gradients of later parameters.

To calculate gradients of earlier parameters, autograd only requires the gradients w.r.t. the inputs or intermediate values in the network. On the other hand, the values in .grad (which are not computed when requires_grad is False) are the gradients w.r.t. the parameters themselves; those are only used for updating the parameters.
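You can check this with a small self-contained toy model (a sketch, not your model):

```python
import torch
import torch.nn as nn

# Freezing the final layer's parameters does not stop gradients from
# flowing back to earlier layers: the chain rule only needs gradients
# w.r.t. the intermediate activations, not w.r.t. the frozen weights.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 1))
for param in model[-1].parameters():
    param.requires_grad = False

out = model(torch.randn(4, 8))
out.sum().backward()

print(model[0].weight.grad is not None)  # True: earlier layer still gets grads
print(model[-1].weight.grad)             # None: frozen layer accumulates nothing
```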

Okay, I see, thanks. I think I understand how gradients for earlier layers can still be calculated with requires_grad set to False later in the model.

I just tried to freeze all but a middle layer of my model, and I am getting a “One of the differentiated Tensors does not require grad” error. Do you have any insight into how to debug this? My initial searches online turned up only cryptic answers.

I cannot reproduce this, can you share an example?
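In the meantime, one common way to hit that exact error (just a guess without a repro, and not necessarily your situation) is asking autograd directly for gradients w.r.t. a frozen tensor, e.g. via torch.autograd.grad:

```python
import torch

# The setup: w is "frozen", x is trainable.
w = torch.randn(3, requires_grad=False)
x = torch.randn(3, requires_grad=True)
loss = (w * x).sum()

try:
    # Requesting gradients w.r.t. the frozen tensor raises the error,
    # because w was excluded from the graph.
    torch.autograd.grad(loss, inputs=[w, x])
except RuntimeError as e:
    print(e)  # "One of the differentiated Tensors does not require grad"
```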