Hey everyone! Suppose I have a large network and computing the gradients for the entire network at once is ruled out by memory constraints. I could of course use checkpointing, but I was wondering whether it makes sense to actually update only a portion of the network in each iteration.
My questions would be as follows:
- Assuming that in each iteration I pick one layer and set requires_grad = True for it, freezing all other layers, would that give me a practical reduction in memory usage, or does the autograd graph not work that way?
- What would be an efficient way of implementing a scheme where, say, in every iteration I unfreeze only X% of the layers at random? I could of course freeze the entire network and then randomly unfreeze some layers in each iteration, but iterating over all layers every time seems inefficient. What would be a neat way of doing this?
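A minimal sketch of the random X% scheme, assuming the candidate layers are collected once into a Python list (the helper `unfreeze_random_fraction` is hypothetical). It uses `random.sample` to pick the layers and `Module.requires_grad_` to toggle each one; the loop runs over layers, not over parameter elements, so its cost is usually small next to a forward pass, though that is worth timing for a very deep model.

```python
import random
from torch import nn

def unfreeze_random_fraction(layers: list, frac: float, rng: random.Random) -> list:
    """Freeze all `layers`, then unfreeze a random subset of about `frac` of them.

    Hypothetical helper: at least one layer is always left trainable so that
    loss.backward() has something to differentiate.
    """
    k = max(1, round(frac * len(layers)))            # number of layers to train this step
    chosen = set(rng.sample(range(len(layers)), k))  # sample k distinct layer indices
    for i, layer in enumerate(layers):
        layer.requires_grad_(i in chosen)            # toggles weight and bias together
    return sorted(chosen)

if __name__ == "__main__":
    model = nn.Sequential(*[nn.Linear(32, 32) for _ in range(10)])
    layers = list(model)                             # build the candidate list once
    rng = random.Random(0)
    for step in range(3):
        print(step, unfreeze_random_fraction(layers, 0.3, rng))
```

If you use this, calling `optimizer.zero_grad(set_to_none=True)` each step keeps the frozen layers' `.grad` at `None`; with zero-filled gradients instead, optimizers with momentum or weight decay would still move the frozen weights.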
EDIT: This is how I am currently doing it, but it does not seem to be any better in terms of efficiency; the epoch runtime increases a lot. I overwrite the modules of the network using the following wrapper:
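```python
import random
import torch
from torch import nn

class RandomGradWrapper(nn.Module):  # hypothetical name: the class line was missing from the snippet
    """Takes in a layer and overwrites its forward pass to sample from a Bernoulli distribution whether the weight tensor requires grad or not."""
    def __init__(self, original_layer: nn.Module, prob: float, **kwargs):
        super().__init__()  # required so the wrapped layer is registered as a submodule
        self.original = original_layer
        self.prob = prob
    def forward(self, x: torch.Tensor):
        # Sample from a Bernoulli(prob) distribution whether the weight gets a gradient this step
        self.original.weight.requires_grad_(random.random() < self.prob)
        return self.original(x)  # the snippet as posted did not return the layer's output
```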
Great thoughts! One thing to keep in mind is that memory is used for:
- network parameters,
- activations (computation results in the forward pass), in particular those saved for the backward,
- optimizer state (e.g. momentum, second moments for Adam).
When you set some layer’s requires_grad to False, you save the memory for its gradients and, if it is among the first layer(s), also for the activations saved for the backward pass (happily, this works well for fine-tuning only the last layers). So if you count parameters, gradients (= size of params) and optimizer state (= 2 × size of params for Adam), the gradients make up ~25% of that memory, and freezing some layers cuts only part of that 25%. Relative to the total, the saving is even smaller, because the activations are still needed if you backprop to lower layers.
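The ~25% figure is simple bookkeeping: with fp32 values and Adam's two moment buffers, every parameter costs one value for itself, one for its gradient and two for the optimizer state. A back-of-the-envelope helper (the name `memory_shares` is hypothetical; `activation_bytes` stands for whatever the saved activations take and has to be measured separately):

```python
def memory_shares(n_params: int, bytes_per_value: int = 4,
                  optimizer_slots: int = 2, activation_bytes: int = 0) -> dict:
    """Rough training-memory split; optimizer_slots=2 models Adam's exp_avg and exp_avg_sq."""
    params = n_params * bytes_per_value
    grads = n_params * bytes_per_value                 # one gradient value per parameter
    optim = optimizer_slots * n_params * bytes_per_value
    total = params + grads + optim + activation_bytes
    return {
        "params": params,
        "grads": grads,
        "optimizer": optim,
        "activations": activation_bytes,
        "grad_share": grads / total,                   # upper bound on what freezing can save on grads
    }

if __name__ == "__main__":
    print(memory_shares(100_000_000))                             # no activations counted
    print(memory_shares(100_000_000, activation_bytes=800_000_000))
```

Without activations the gradient share is exactly a quarter; any activation memory pushes it lower, which is the "less than 25% of the total" point.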
Also, when you pick layers at random as you propose, you will still peak at “full size” at some point, which is bad if you hope to avoid running out of memory. Treating all the network’s parameters as a large vector
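One concrete mechanism that pushes memory back toward full size (an illustration, not necessarily everything meant here): optimizers such as Adam create per-parameter state lazily, the first time a parameter has a gradient, and never free it. The toy loop below (hypothetical helper `adam_state_over_time`, a small stack of `nn.Linear` layers, random freezing of whole layers) records how many parameters have Adam state after each step.

```python
import random
import torch
from torch import nn

def adam_state_over_time(num_layers=4, prob=0.5, iters=50, seed=0):
    """Return, per step, how many parameter tensors have Adam state (hypothetical helper)."""
    random.seed(seed)
    torch.manual_seed(seed)
    model = nn.Sequential(*[nn.Linear(16, 16) for _ in range(num_layers)])
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    history = []
    for _ in range(iters):
        # Freeze/unfreeze whole layers at random (weight and bias together)
        chosen = [random.random() < prob for _ in range(num_layers)]
        if not any(chosen):
            chosen[random.randrange(num_layers)] = True  # backward needs at least one trainable layer
        for layer, on in zip(model, chosen):
            layer.requires_grad_(on)
        opt.zero_grad(set_to_none=True)  # frozen params keep .grad = None this step
        loss = model(torch.randn(8, 16)).pow(2).mean()
        loss.backward()
        opt.step()  # Adam creates state only for params whose .grad is not None
        history.append(len(opt.state))
    return history

if __name__ == "__main__":
    print(adam_state_over_time())
```

In this toy, the count climbs to every parameter after a handful of steps, so the optimizer state ends up as large as in full training; only gradients and activations vary from step to step.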
Finally, training performance (in terms of improvement per training step) might also go down.
People have been doing various things to avoid keeping everything on the GPU: checkpointing, moving things to the CPU, running the optimizer on the CPU and keeping its state there, and FSDP, which sends parameters and gradients over the network instead of keeping them on each GPU.
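Checkpointing, the first item on that list, trades compute for memory: the activations inside each checkpointed segment are dropped after the forward pass and recomputed during backward. A minimal sketch with `torch.utils.checkpoint.checkpoint` (the model `CheckpointedMLP` and its sizes are hypothetical; `use_reentrant=False` is the variant the PyTorch docs recommend, and unlike the reentrant one it still produces parameter gradients when the input itself does not require grad):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class CheckpointedMLP(nn.Module):
    """Hypothetical stack of Linear+ReLU blocks, each optionally checkpointed."""
    def __init__(self, width: int = 32, depth: int = 6, use_ckpt: bool = True):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(width, width), nn.ReLU()) for _ in range(depth)
        )
        self.use_ckpt = use_ckpt

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_ckpt and self.training:
                # Activations inside `block` are not kept; they are recomputed in backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

if __name__ == "__main__":
    model = CheckpointedMLP()
    model(torch.randn(4, 32)).sum().backward()
    print(model.blocks[0][0].weight.grad.shape)
```

The gradients are the same as without checkpointing; the price is roughly one extra forward pass of each block during backward.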
So, in summary, it might be worthwhile to look a bit more closely at how much memory you need to save and how much you expect this approach to save before digging too deep into the coding. That said, everyone always wants to do more with less memory, so your efforts are very much needed!
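The static part of that budget can be measured directly from a model and its optimizer by summing `numel() * element_size()` over parameters, gradients and optimizer-state tensors. A sketch (the helper names `tensor_bytes` and `memory_report` are hypothetical; activations are not included here, while on a GPU `torch.cuda.max_memory_allocated()` reports the peak including them):

```python
import torch
from torch import nn

def tensor_bytes(tensors) -> int:
    """Total storage in bytes, skipping None entries (e.g. params without .grad)."""
    return sum(t.numel() * t.element_size() for t in tensors if t is not None)

def memory_report(model: nn.Module, opt: torch.optim.Optimizer) -> dict:
    params = list(model.parameters())
    # Optimizer state is a dict per parameter; keep only tensor entries
    # (newer PyTorch versions also store Adam's 'step' as a small tensor).
    state_tensors = [v for s in opt.state.values() for v in s.values() if torch.is_tensor(v)]
    return {
        "params": tensor_bytes(params),
        "grads": tensor_bytes(p.grad for p in params),
        "optimizer_state": tensor_bytes(state_tensors),
    }

if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10))
    opt = torch.optim.Adam(model.parameters())
    model(torch.randn(32, 1024)).sum().backward()
    opt.step()
    print(memory_report(model, opt))
```

Running it once with all layers trainable and once with the intended freezing pattern shows how much the gradients actually account for in a given model.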
Thanks for the answer, @tom! I am aware that memory is used to store a variety of things other than the gradients themselves. Let me directly quote parts of what you said:
Also, when you do random as you propose, you will actually peak at “full size” which is bad if you hope to avoid out of memory.
Sure, in practice I would not do it randomly as illustrated above; that was just to convey the rough idea. In fact, I might be even more interested in efficiency speedups (not in terms of improvement per training step, but in wall-clock time per training step) than in memory savings.
When you set some layer’s requires grad to false, you save the gradients and, if it is the first layer(s), the activations (happily this works well for finetuning only the last layers).
Maybe I have to go back to the backpropagation algorithm and write it out, but I may be missing something here. Say I have the first (i.e. closest to the input) layer active (requires_grad == True) and all others frozen. To compute the gradients w.r.t. the first layer’s weights, I would need to store all activations of the forward pass, right? So technically, I would get a memory improvement only from not having to store the gradients of the other parameters, but I would still have to save all activations. And the speedup would be negligible, because I have to backpropagate through the entire graph anyway, so I would basically get the other gradients almost for free?
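This can be checked empirically rather than on paper by counting the bytes autograd saves for the backward pass with `torch.autograd.graph.saved_tensors_hooks`. The sketch below is a toy under stated assumptions (a hypothetical stack of Linear+ReLU blocks and hypothetical helpers `build` and `saved_bytes`); it compares unfreezing only the first Linear layer, only the last one, or all of them.

```python
import torch
from torch import nn

def build(depth: int = 6, width: int = 256) -> nn.Sequential:
    torch.manual_seed(0)
    layers = []
    for _ in range(depth):
        layers += [nn.Linear(width, width), nn.ReLU()]
    return nn.Sequential(*layers)

def saved_bytes(model: nn.Sequential, trainable: set, batch: int = 64) -> int:
    """Bytes saved for backward when only the Linear layers with these
    0-based indices (counting Linear layers only) require grad."""
    linears = [m for m in model if isinstance(m, nn.Linear)]
    for i, lin in enumerate(linears):
        lin.requires_grad_(i in trainable)
    total = 0

    def pack(t: torch.Tensor) -> torch.Tensor:
        nonlocal total
        total += t.numel() * t.element_size()  # count every tensor autograd saves
        return t

    x = torch.randn(batch, linears[0].in_features)  # input does not require grad
    with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
        loss = model(x).pow(2).mean()
    loss.backward()  # frozen layers get no .grad, but the graph is still traversed
    return total

if __name__ == "__main__":
    m = build()
    print("first only:", saved_bytes(m, {0}))
    print("last only: ", saved_bytes(m, {5}))
    print("all:       ", saved_bytes(m, set(range(6))))
```

Unfreezing only the last layer should save far less than unfreezing only the first, because the frozen prefix never enters the autograd graph. How close "first only" comes to "all" depends on which tensors each op's backward actually needs, so it is worth measuring rather than assuming.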
Apart from that, I am wondering what happens in the backend that makes my current code so slow (measured in wall-clock time per iteration, not improvement per step). Bernoulli sampling and setting requires_grad should be fast, yet the runtime almost doubles compared to leaving all parameters active at all times. What am I missing here?
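Rather than guessing, one can time a training step with and without the wrapper and read a profiler table to see where the extra time goes. A sketch, assuming a user-supplied `step_fn` that performs one full training iteration (the helpers `make_step`, `seconds_per_step` and `profile_table` are hypothetical; `torch.profiler` is PyTorch's profiler API). The explicit `torch.cuda.synchronize()` matters because CUDA kernels run asynchronously, so timing without it can be misleading.

```python
import time
import torch
from torch import nn
from torch.profiler import profile, ProfilerActivity

def make_step(model: nn.Module, opt: torch.optim.Optimizer, batch: torch.Tensor):
    """Return a closure that runs one forward/backward/optimizer step."""
    def step():
        opt.zero_grad(set_to_none=True)
        model(batch).pow(2).mean().backward()
        opt.step()
    return step

def seconds_per_step(step_fn, steps: int = 20, warmup: int = 5) -> float:
    """Average wall-clock seconds per call of step_fn."""
    for _ in range(warmup):
        step_fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for queued kernels before starting the clock
    start = time.perf_counter()
    for _ in range(steps):
        step_fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / steps

def profile_table(step_fn, row_limit: int = 10) -> str:
    """Profile one step and return the top operators by total CPU time."""
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities) as prof:
        step_fn()
    return prof.key_averages().table(sort_by="cpu_time_total", row_limit=row_limit)

if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    step = make_step(model, opt, torch.randn(32, 128))
    print(f"{seconds_per_step(step) * 1e3:.3f} ms/step")
    print(profile_table(step))
```

Running `seconds_per_step` on the plain model and on the wrapped one gives the per-iteration comparison; the profiler table then shows which operators account for the difference.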
Appreciated, but not so sure about that yet!