It seems that torch.utils.checkpoint doesn't save GPU memory?

I need to show that gradient checkpointing really does reduce GPU memory usage during the backward pass. When I look at the output of pytorch_memlab, there are two columns on the left showing active_bytes and reserved_bytes. In my test, active_bytes reads 3.83G while reserved_bytes reads 9.35G. Why does PyTorch still reserve that much GPU memory?

The reserved memory refers to the cache, which PyTorch can reuse for new allocations.
Based on the stats you are seeing, the peak memory usage was probably larger at some point, but PyTorch was able to release that memory back into its cache, so it can be reused the next time memory is needed without allocating new device memory via cudaMalloc.
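A minimal sketch of how to inspect these counters directly with PyTorch's own memory stats (the model and tensor sizes here are made up for illustration, not taken from the thread):

```python
import torch
import torch.nn as nn

model = nn.Linear(4096, 4096).cuda()
x = torch.randn(1024, 4096, device="cuda")

loss = model(x).sum()
loss.backward()

gb = 1024 ** 3
print(f"allocated now : {torch.cuda.memory_allocated() / gb:.2f} GB")      # memory held by live tensors
print(f"peak allocated: {torch.cuda.max_memory_allocated() / gb:.2f} GB")  # the earlier peak
print(f"reserved      : {torch.cuda.memory_reserved() / gb:.2f} GB")       # live tensors + cached blocks
```

The gap between reserved and allocated is memory PyTorch is holding in its cache for later reuse.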


So to continue the reasoning with my example: say I have a 4G GPU card. I still cannot train the model even though the active memory usage is only 3.83G, because the 9.35G that ends up cached was actually needed at some point to produce that 3.83G?

I’m not familiar with your use case and don’t know how you are using the checkpoint utility or which operations are using the cache.
Generally, checkpointing trades compute for memory: the intermediate activations are not stored during the forward pass but are recomputed during the backward pass.
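A minimal sketch of how checkpointing is typically applied; the model and sizes are invented for illustration, and `use_reentrant=False` is the recommended mode on recent PyTorch versions:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    def __init__(self, dim=1024):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)

blocks = nn.ModuleList([Block() for _ in range(8)]).cuda()
x = torch.randn(64, 1024, device="cuda", requires_grad=True)

out = x
for block in blocks:
    # Activations inside each block are not kept; they are recomputed
    # during backward, trading extra compute for lower peak memory.
    out = checkpoint(block, out, use_reentrant=False)

out.sum().backward()
```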

Yes, but I cannot just tell people that… I need to graph it so people will say, “Well, now I’m convinced that PyTorch really did that.” So my current approach is to use the pytorch_memlab package.
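One way to collect numbers for such a graph, alongside or instead of pytorch_memlab, is PyTorch's own peak-memory counters. In this sketch, `train_step` is a hypothetical function standing in for a single forward/backward iteration, run once with and once without checkpointing:

```python
import torch

def peak_memory_gb(run_step):
    """Run one training step and return the peak allocated GPU memory in GB."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    run_step()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024 ** 3

# 'train_step' is hypothetical; replace it with your own training iteration.
# peak_ckpt = peak_memory_gb(lambda: train_step(use_checkpoint=True))
# peak_full = peak_memory_gb(lambda: train_step(use_checkpoint=False))
# Plot peak_ckpt vs. peak_full to show the effect of checkpointing.
```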

OK, cool. So based on this, some of my operations might be the ones using that cached memory?

Sorry, I don’t understand your current concern or issue: the only thing you are describing is that checkpointing apparently doesn’t work because you are seeing memory in the internal cache, without giving any further details.

Every allocation is returned to the cache when the last reference to the tensor is deleted, to avoid synchronizing and expensive cudaMalloc calls.
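A small sketch of that behavior (the tensor size is arbitrary):

```python
import torch

x = torch.randn(256, 1024, 1024, device="cuda")  # roughly 1 GB of float32

del x                                  # the block goes back to PyTorch's cache
print(torch.cuda.memory_allocated())   # drops back down
print(torch.cuda.memory_reserved())    # still holds the cached ~1 GB

torch.cuda.empty_cache()               # returns cached, unused blocks to the driver
print(torch.cuda.memory_reserved())    # now (close to) zero again
```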


Sorry, I don’t understand your current concerns or issues

My concern: checkpoint(function, *args) should have the same requires_grad as function(*args) · Issue #7617 · pytorch/pytorch · GitHub

Anyway, after re-reading your first comment in this thread, it makes sense now.