How to prevent this memory leak?

Here’s a simple repro of my situation:

import torch
import torch.nn as nn

feature_size = 1000
window = torch.zeros((5, feature_size, 1), requires_grad=True).cuda()
input_encoder = nn.Linear(3 * 224 * 224, feature_size).cuda()
step_count = 0

while True:
    # Create new element
    new_element = torch.zeros((1, 3, 224, 224), requires_grad=True).cuda()
    new_element = input_encoder(new_element.reshape(1, -1))

    # Remove last element and add the new one
    window = window[:-1]
    window = torch.cat((new_element.unsqueeze(-1), window), dim=0)

    step_count += 1
    print("Step: {}".format(step_count))

On my Titan Xp (12 GB memory), this runs out of memory after just over 9300 steps. I initially tested with PyTorch 1.0.0 and CUDA 8, and I also reproduced it with PyTorch 1.1.0 and CUDA 10, in both cases using Python 3.5 on Windows 10.

Ideally, a new element gets added to window, the oldest one gets removed, and all of the resources associated with the removed one get freed. I’m not sure how best to achieve that, though. Neither gc.collect() nor torch.cuda.empty_cache() has fixed the problem.
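Both of those calls can only release memory that nothing references any more. One way to check whether something still references the old elements is to walk the autograd graph hanging off window. The sketch below is a shrunk, CPU-only copy of the repro (the sizes are changed so it runs anywhere), and reachable_nodes is a hypothetical helper that follows grad_fn.next_functions. Under those assumptions, the graph gains nodes on every step, and the encoder node from step 0 is still reachable long after its row has left the window, because window[:-1] and torch.cat link each new window to the previous one.

```python
import torch
import torch.nn as nn


def reachable_nodes(tensor):
    """Hypothetical helper: collect every autograd node reachable from tensor.grad_fn."""
    seen = set()
    stack = [tensor.grad_fn]
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        # next_functions holds one (node, input_nr) pair per input of this op
        stack.extend(next_node for next_node, _ in node.next_functions)
    return seen


# Same structure as the repro, shrunk and kept on the CPU
feature_size = 8
window = torch.zeros((5, feature_size, 1), requires_grad=True)
input_encoder = nn.Linear(3 * 4 * 4, feature_size)

first_encoder_node = None
node_counts = []
for step in range(20):
    new_element = torch.zeros((1, 3, 4, 4), requires_grad=True)
    new_element = input_encoder(new_element.reshape(1, -1))
    if first_encoder_node is None:
        first_encoder_node = new_element.grad_fn  # the encoder op from step 0
    window = window[:-1]
    window = torch.cat((new_element.unsqueeze(-1), window), dim=0)
    node_counts.append(len(reachable_nodes(window)))

print(node_counts)  # grows on every step
print(first_encoder_node in reachable_nodes(window))  # True
```

Each of those encoder nodes presumably still holds the input it saved for the weight gradient (the reshaped 1 × 150528 tensor in the real repro), which would explain the steady growth in GPU memory.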

If I define memReport:

import gc
import torch

def memReport():
    for obj in gc.get_objects():
        if torch.is_tensor(obj):
            print(type(obj), obj.size())

and run it in the loop, the result is always:

<class 'torch.Tensor'> torch.Size([1, 1000])
<class 'torch.Tensor'> torch.Size([5, 1000, 1])
<class 'torch.nn.parameter.Parameter'> torch.Size([1000, 150528])
<class 'torch.nn.parameter.Parameter'> torch.Size([1000])
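The four tensors listed don’t account for the growth, which is consistent with the memory being held inside autograd graph nodes rather than by Python-visible tensors. A variant of memReport that also asks the CUDA caching allocator how much memory is occupied might make the growth visible. The name mem_report and the returned count are assumptions for illustration; torch.cuda.memory_allocated() reports the bytes occupied by live tensors on the current device, regardless of what references them.

```python
import gc

import torch


def mem_report():
    """Hypothetical variant of memReport: list Python-visible tensors, plus CUDA usage."""
    tensor_count = 0
    for obj in gc.get_objects():
        if torch.is_tensor(obj):
            tensor_count += 1
            print(type(obj), obj.size())
    # Tensors saved inside autograd graph nodes are owned from the C++ side and
    # may not show up in the scan above; the allocator counter includes them either way.
    if torch.cuda.is_available():
        print("CUDA bytes allocated:", torch.cuda.memory_allocated())
    return tensor_count
```

Calling it every few hundred steps and watching the allocated byte count climb while the tensor list stays the same would point at the graph rather than at a forgotten Python reference.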

Does anyone have advice or pointers on freeing these resources and preventing this memory leak?

Thanks!

Could you post a more complete sample? This one doesn’t call backward(...), which is what frees the intermediate buffers the graph saved during the forward pass. I’m assuming you omitted it to post pseudocode. From the example:

    window = torch.cat((new_element.unsqueeze(-1), window), dim=0)

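That line is the likely culprit: the new window’s history links back through every previous window, so the graphs from earlier steps stay alive. If it’s okay for you to cut the gradient into new_element there, try: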

    window = torch.cat((new_element.detach().unsqueeze(-1), window), dim=0)
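The detach version stops the growth because window’s history no longer reaches the encoder. A shrunk CPU sketch, with a hypothetical reachable_nodes helper, checks that none of the encoder nodes is reachable from window. A chain of cat/slice nodes still builds up, since the initial window requires grad, but as far as I know those nodes only record shapes and indices, not large saved tensors. The flip side is that no gradient can reach input_encoder through window.

```python
import torch
import torch.nn as nn


def reachable_nodes(tensor):
    """Hypothetical helper: collect every autograd node reachable from tensor.grad_fn."""
    seen, stack = set(), [tensor.grad_fn]
    while stack:
        node = stack.pop()
        if node is not None and node not in seen:
            seen.add(node)
            stack.extend(next_node for next_node, _ in node.next_functions)
    return seen


feature_size = 8
window = torch.zeros((5, feature_size, 1), requires_grad=True)
input_encoder = nn.Linear(3 * 4 * 4, feature_size)

# Handles kept only for the check below (holding them keeps those graphs alive,
# so a real training loop would not do this)
encoder_nodes = []
for step in range(20):
    new_element = torch.zeros((1, 3, 4, 4), requires_grad=True)
    new_element = input_encoder(new_element.reshape(1, -1))
    encoder_nodes.append(new_element.grad_fn)
    window = window[:-1]
    window = torch.cat((new_element.detach().unsqueeze(-1), window), dim=0)

reachable = reachable_nodes(window)
print(any(node in reachable for node in encoder_nodes))  # False: encoder history is cut off
print(window.requires_grad)  # True: the cat/slice chain back to the initial window remains
```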

Very interesting. Adding:

loss = window.mean()
loss.backward(retain_graph=True)

to the loop fixes the leak. Can you (or someone) explain what exactly is happening here?

It’s true that this is a minimal repro of a bigger codebase (though I did not intend it as pseudocode). By design, my actual code also doesn’t run backward very often at the moment, so it would be useful if there were some way to prevent this memory accumulation between backward calls. (Or, if it’s unavoidable, I’d at least like to understand why.)

What would the ramifications be of calling backward on an essentially fake loss, and then later (when I’m ready) doing the full zero_grad/backward/step with my real loss? What is backward(retain_graph=True) actually doing here to free my memory? (I initially assumed retain_graph would mean there was no benefit to running it, but I seem to have been wrong about that.) Is there a less hacky way of doing this?

Regarding your other suggestion: with the detach it doesn’t run out of memory, but I’m pretty sure it breaks the graph chain I’m trying to construct. Anyway, the backward seems to have been the issue.

Thanks so much for your insight. I had assumed that backward only freed the tensors created by the previous backward; I didn’t know it also affected tensors created during the forward pass.
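What backward frees is the set of tensors the forward pass saved inside the graph nodes. retain_graph=True asks autograd to keep those tensors so the same graph can be backpropagated again, so by itself it releases nothing, and the node objects stay alive as long as some tensor (here, window) still references them through grad_fn. The toy example below, unrelated to the repro, shows this and also why a backward on a stand-in loss has side effects: gradients accumulate into .grad until zero_grad clears them, and once a backward without retain_graph has run, a later backward through the same nodes fails.

```python
import torch

x = torch.ones(3, requires_grad=True)
y = (x * x).sum()  # the multiply saves x inside its graph node for the backward pass

y.backward(retain_graph=True)  # computes gradients and keeps the saved tensors
y.backward()                   # accumulates gradients again, then frees the saved tensors

reuse_failed = False
try:
    y.backward()  # the saved tensors are gone, so autograd raises
except RuntimeError:
    reuse_failed = True

print(x.grad)        # tensor([4., 4., 4.]): d(sum x^2)/dx = 2x, accumulated twice
print(reuse_failed)  # True
```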

See if one of the following threads suits your case.

Oh, I was wrong, my apologies. Adding the backward just slowed the network down so much that the memory usage looked flat, and I wasn’t paying close enough attention: by step 4000 (doing a backward every 1000 steps), half the memory was consumed.

My issue is basically that I know what I need to remove; I just don’t have any way to sever it completely from the graph, as far as I can tell. The linked threads are interesting but not quite what I’m looking for. And when window is a tensor, window[-1] = window[-1].detach() doesn’t seem to do much.
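A likely reason window[-1] = window[-1].detach() has no effect: assigning into a slice is recorded by autograd as an in-place copy, and the tensor’s new grad_fn (a CopySlices node) keeps an edge to its previous grad_fn, so nothing is severed. A small sketch with made-up shapes:

```python
import torch

old_rows = torch.zeros((4, 3, 1), requires_grad=True)  # stands in for the older window
new_row = torch.ones((1, 3, 1), requires_grad=True)
window = torch.cat((new_row, old_rows), dim=0)
cat_node = window.grad_fn  # the cat node linking window to both inputs

window[-1] = window[-1].detach()  # recorded as an in-place copy into a view of window

print(type(window.grad_fn).__name__)  # CopySlices
print(window.grad_fn.next_functions[0][0] is cat_node)  # True: the old history is still attached
```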

I think I’ve concluded that what I want basically can’t be done. Once an element is put into a tensor, it seems to live on in the graph in some form that I can’t purge, even after the element is no longer in the tensor.

This would have been convenient for other reasons in my graph, but I think I’ll just have to make do with a list and torch.stack.
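A minimal sketch of the list-based approach, assuming a collections.deque with maxlen holds the per-element tensors (the deque is my choice, not something from the thread) and torch.stack rebuilds the window whenever it is needed. Because window is rebuilt from the current elements rather than derived from the previous window, an element that falls out of the deque takes its graph with it once nothing else references it. Sizes are shrunk, everything runs on the CPU, and reachable_nodes is a hypothetical helper that walks grad_fn.next_functions.

```python
from collections import deque

import torch
import torch.nn as nn


def reachable_nodes(tensor):
    """Hypothetical helper: collect every autograd node reachable from tensor.grad_fn."""
    seen, stack = set(), [tensor.grad_fn]
    while stack:
        node = stack.pop()
        if node is not None and node not in seen:
            seen.add(node)
            stack.extend(next_node for next_node, _ in node.next_functions)
    return seen


feature_size = 8
window_size = 5
input_encoder = nn.Linear(3 * 4 * 4, feature_size)

# Placeholder elements; appendleft on a full deque discards from the right (oldest) end
elements = deque(
    (torch.zeros((feature_size, 1), requires_grad=True) for _ in range(window_size)),
    maxlen=window_size,
)

node_counts = []
for step in range(20):
    new_element = torch.zeros((1, 3, 4, 4), requires_grad=True)
    new_element = input_encoder(new_element.reshape(1, -1))
    elements.appendleft(new_element.reshape(feature_size, 1))  # newest first, as in the repro
    window = torch.stack(list(elements), dim=0)  # rebuilt from the current elements only
    node_counts.append(len(reachable_nodes(window)))

print(window.shape)  # torch.Size([5, 8, 1])
print(node_counts[window_size - 1:])  # the same count at every step once the placeholders are gone
```

The node count stays flat once the placeholders have been pushed out. Anything kept around for a later backward, such as outputs computed from each window, still keeps its own elements’ graphs alive, which is unavoidable if gradients are supposed to reach them.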

Anyway, thanks for the help. :)