Best way to proceed to find memory leaks during backprop?

Hello!

I’m currently getting CUDA “out of memory” errors during training after a certain number of iterations. They occur only when back-propagation is used, and specifically when new tensors are allocated in the backward methods of my autograd::Function subclasses. I’m using custom activation functions & layers for which I wrote custom CUDA code, which is then wrapped in autograd::Function subclasses. I’m 99% sure this is due to memory leaks and not problem size: I did something very similar in Python/PyTorch and had no issues at all, and in my C++ code the errors appear only after a certain number of iterations.

What I’d really like to know is how one should go about identifying memory leaks in a libtorch program in general, because a direct use of, say, valgrind doesn’t help me much, and debugging by guesswork is currently a real pain. For example, since libtorch uses a caching CUDA allocator (as far as I understand), I’d like to check during debugging whether certain tensors are actually being freed (i.e. their space is marked as “available” by the caching allocator) once they are no longer referenced. If they aren’t, that would help me localize the bug in my code that keeps those tensors alive.

Thank you!

Since you are assuming your custom CUDA extension is causing the trouble, I would try to come up with a minimal standalone code snippet using your module only (forward and backward) and check the allocations via print(torch.cuda.memory_summary()) in each iteration. If you see the memory usage growing, it would indeed point to your custom code. In that case, first check whether you are storing any tensors, which would prevent PyTorch from freeing them. Also, are you using libtorch, or are you manually managing the CUDA memory via raw malloc/free calls?

Great idea, thanks a lot ptrblck! I’m working entirely in C++ though (both for training & inference), so no torch.cuda.memory_summary() for me right now. But by looking at the source code of the Python API, I found that I can get what I need using c10::cuda::CUDACachingAllocator::getDeviceStats and the c10::cuda::CUDACachingAllocator::StatType enum.
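For anyone hitting the same issue, here is roughly what I ended up printing once per iteration (a minimal sketch; field names and the exact getDeviceStats signature may differ between libtorch versions):

```cpp
#include <c10/cuda/CUDACachingAllocator.h>
#include <iostream>

// Print the currently allocated and reserved bytes (aggregated over both
// pools) for a given CUDA device. Calling this once per training iteration
// makes a steadily growing "allocated" counter easy to spot.
void print_cuda_memory_stats(int device = 0) {
  using namespace c10::cuda::CUDACachingAllocator;
  const DeviceStats stats = getDeviceStats(device);
  const auto agg = static_cast<size_t>(StatType::AGGREGATE);
  std::cout << "allocated: " << stats.allocated_bytes[agg].current
            << " bytes, reserved: " << stats.reserved_bytes[agg].current
            << " bytes" << std::endl;
}
```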

Found the culprit thanks to the memory stats. In a torch::autograd::Function, when I create a tensor during forward and store it in ctx->saved_data (for example ctx->saved_data["some_tensor"] = some_tensor;, where some_tensor contains intermediary results I need to reuse in the backward method), the tensor is not released even after backward finishes. However, if I add ctx->saved_data["some_tensor"] = torch::Tensor(); at the end of backward (explicitly dropping the reference to some_tensor), the memory leak disappears.
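To illustrate, here is a stripped-down sketch of the pattern (MyExp and its gradient are just placeholders, not my actual custom layer):

```cpp
#include <torch/torch.h>

using torch::autograd::AutogradContext;
using torch::autograd::Function;
using torch::autograd::variable_list;

struct MyExp : public Function<MyExp> {
  static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input) {
    auto some_tensor = input.exp();
    // Intermediary result stored directly in saved_data (the pattern that leaked).
    ctx->saved_data["some_tensor"] = some_tensor;
    return some_tensor;
  }

  static variable_list backward(AutogradContext* ctx, variable_list grad_outputs) {
    auto some_tensor = ctx->saved_data["some_tensor"].toTensor();
    auto grad_input = grad_outputs[0] * some_tensor;  // d exp(x)/dx = exp(x)
    // Without this explicit reset, the IValue in saved_data keeps the tensor
    // alive after backward, so its memory is never returned to the caching allocator.
    ctx->saved_data["some_tensor"] = torch::Tensor();
    return {grad_input};
  }
};
```

Calling it via MyExp::apply(x) and checking the allocator stats after each iteration is how I noticed that the “allocated” counter stops growing only with the explicit reset.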

@ptrblck can you please confirm whether or not it is expected behavior that tensors stored inside ctx are not released even after backward has finished?

I think the issue is caused by incorrectly storing the tensors directly in ctx, and if I’m not mistaken, a few similar issues have been posted here in the past.
Use ctx->save_for_backward instead and the memory leak should be gone.
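For reference, a minimal sketch of that pattern, reusing the placeholder MyExp function from above rather than the original custom layer:

```cpp
#include <torch/torch.h>

using torch::autograd::AutogradContext;
using torch::autograd::Function;
using torch::autograd::variable_list;

struct MyExp : public Function<MyExp> {
  static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input) {
    auto output = input.exp();
    // Let autograd manage the saved tensor's lifetime: it is released
    // automatically once backward has consumed it.
    ctx->save_for_backward({output});
    return output;
  }

  static variable_list backward(AutogradContext* ctx, variable_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto output = saved[0];
    return {grad_outputs[0] * output};  // d exp(x)/dx = exp(x)
  }
};
```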

I can confirm the problem is also gone when using ctx->save_for_backward instead of saving directly into the saved_data hash map.
