Memory leak when appending tensors to a list

Here is the snippet:

    for ii,line in enumerate(lines):

        if line['evidence'][-1][-1][-1] is not None:
            #feats, labels = tf_idf_claim(line)
            feats, ev_sent, _ = fasttext_claim(line,ft_model,db,params)
            labels = ind2indicator(ev_sent,feats.shape[-1])
            all_labels.append(labels.numpy())

            pred_labels, scores = model(feats.unsqueeze_(0).transpose_(1,2).cuda(device_id))
            all_scores[ii] = scores

The GPU runs out of memory (OOM) after a while. This is running on a Tesla K80.
If I change the last line to all_scores[ii] = scores.detach().cpu().numpy(), the problem goes away.
This is the only process running on that GPU.

I’m not sure how scores was calculated, but it could still hold a reference to the computation graph.
If that’s the case, you are storing the whole computation graph in your list in each iteration, which eventually fills up your GPU memory.
Detaching the tensor should fix this issue as you’ve already mentioned.
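
As a minimal sketch of the fix described above (the small model and random features here are hypothetical stand-ins, not the model from the question), storing a detached copy keeps only the values and lets each iteration's graph be freed. If the loop is pure inference, wrapping the forward pass in torch.no_grad() avoids building the graph in the first place:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the model in the question: any module ending in a softmax.
model = nn.Sequential(nn.Linear(8, 4), nn.Softmax(dim=-1))

all_scores = {}
for ii in range(5):
    feats = torch.randn(1, 8)      # placeholder features
    scores = model(feats)          # attached to the graph (has a grad_fn)
    # Store a graph-free tensor; .numpy() could be appended as in the question.
    all_scores[ii] = scores.detach().cpu()

# Alternative for pure inference: no graph is recorded at all.
with torch.no_grad():
    scores_no_graph = model(torch.randn(1, 8))

print(scores.grad_fn is not None)        # True: the raw output is still attached
print(all_scores[0].grad_fn)             # None
print(scores_no_graph.grad_fn)           # None
```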

The scores come from softmax outputs and are tensors of shape [1,10,100]. Does this mean that any tensor calculated by the model holds a reference to the whole graph?
I am new to PyTorch, so my question might be dumb. Thanks for your help.

In this case the softmax output needs the computation graph to be able to calculate the gradients in the backward pass.
You can check it by printing the grad_fn:

print(scores.grad_fn)
# Should print something like
# <SoftmaxBackward at 0x7f371721a668>

If you store something from your model (e.g. for debugging purposes) and don’t need to calculate gradients with it anymore, I would recommend calling detach() on it. The call has no effect if the tensor is already detached.
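
A small self-contained check of this, as a sketch on the CPU with made-up input values: the softmax output carries a grad_fn, the detached tensor does not, and detaching a second time changes nothing. The exact name of the backward node differs between PyTorch versions, so it is only printed here.

```python
import torch

logits = torch.randn(1, 10, requires_grad=True)
scores = torch.softmax(logits, dim=-1)
print(scores.grad_fn)        # some softmax backward node; the name varies by version

detached = scores.detach()
print(detached.grad_fn)      # None

# Detaching an already detached tensor is harmless.
again = detached.detach()
print(again.grad_fn, again.requires_grad)

# detach() does not copy the data; it returns a view on the same storage.
same_storage = detached.data_ptr() == scores.data_ptr()
print(same_storage)
```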

That makes sense to me now. Thanks for the detailed answer.

Hi, I have a question about this sentence:

you are storing the whole computation graph

This sentence does not seem intuitive to me. Just looking at the code, the list
only stores the variable scores. Does this mean that the information needed by autograd (grad_fn and so on) for every tensor in the computation is stored along with all_scores[ii]?
Thanks!

Yes, that’s correct. scores is attached to a computation graph via its grad_fn and if you store the scores tensor in e.g. a list the entire graph with all intermediate tensors needed to compute the gradients in this computation graph will be stored with it.
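
One way to see what is reachable from a stored tensor is to walk the graph through grad_fn.next_functions. The helper count_graph_nodes below is a hypothetical name introduced for illustration, and the tiny model is a stand-in. This is only a sketch: it counts graph nodes and does not measure memory.

```python
import torch
import torch.nn as nn

def count_graph_nodes(tensor):
    """Count the autograd nodes reachable from tensor.grad_fn (hypothetical helper)."""
    seen = set()
    stack = [tensor.grad_fn]
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        # next_functions holds (node, input_index) pairs for the inputs of this node.
        stack.extend(fn for fn, _ in node.next_functions)
    return len(seen)

model = nn.Sequential(nn.Linear(4, 3), nn.Softmax(dim=-1))
scores = model(torch.randn(2, 4))

attached_nodes = count_graph_nodes(scores)
detached_nodes = count_graph_nodes(scores.detach())
print(attached_nodes, detached_nodes)
```

Storing scores keeps all of these nodes (and the tensors they saved for the backward pass) alive, while the detached tensor reaches none of them.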


So if my GPU has 12 GB of memory and my model takes 6.1 GB during training, and I store the model output (scores in this case) in a list, do I get an OOM immediately because the total memory consumption would be 6.1 * 2 = 12.2 GB?

No, since the memory is not only used by the last computation graph, but also the model parameters and potentially other data on the GPU.


OK. Is there any way to measure or calculate the memory usage of the computation graph?
By the way, does torch destroy the computation graph (freeing the memory it takes) when backward() is called, and rebuild it on the next forward() call?

An easy way would be to check the allocated memory while the output of the forward pass (and thus the computation graph) is alive, delete the output, and re-check the allocated memory. The delta would then correspond to the memory usage of the computation graph.

Yes, that’s the case in the default setup. PyTorch will not be able to free the computation graph and the intermediate tensors after the backward call if you are storing references to the computation graph as already described. Also, using retain_graph=True in backward will not free the graph since you are explicitly keeping it alive.
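
The measurement idea can be sketched as follows. graph_memory_mb is a hypothetical helper name, and the model and input sizes are arbitrary. The delta includes the output tensor itself as well as the intermediates kept for the backward pass, and the function returns None when no CUDA device is available.

```python
import torch
import torch.nn as nn

def graph_memory_mb(model, x):
    """MB released when the forward output is deleted, or None without CUDA (hypothetical helper)."""
    if not torch.cuda.is_available():
        return None
    out = model(x)
    with_graph = torch.cuda.memory_allocated()     # output and graph are alive
    del out                                        # drop the last reference to the graph
    without_graph = torch.cuda.memory_allocated()
    return (with_graph - without_graph) / 1024**2

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256)).to(device)
x = torch.randn(64, 256, device=device)

delta = graph_memory_mb(model, x)
print(delta)
```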

An easy way would be to check the allocated memory while the output of the forward pass (and thus the computation graph) is alive, delete the output, and re-check the allocated memory. The delta would then correspond to the memory usage of the computation graph.

This is a great idea! I will try it if I need it in the future.

PyTorch will not be able to free the computation graph and the intermediate tensors after the backward call if you are storing references to the computation graph as already described.
Let’s say I have a graph that looks like this:

while True:
    c = a + b
    e = c + d
    g = e + f
    g.backward()

Assume that a, b, c, d, e, f, and g all require grad. Normally, backward() is called inside the while loop
and everything in the graph is destroyed, as described above. But if I somehow keep c alive, e.g. by setting retain_graph=True, I have two questions:

  1. During the backward call, is c the only tensor that is kept in the graph, or do other tensors stay alive in the graph because c is not erased?
  2. To my understanding, if the loop goes on and the graph is never cleaned up completely, an OOM would occur at some point. Is that right?

  1. None of the mentioned tensors will be deleted in the backward call, since you have explicitly created references to them by creating the variables. The backward call will only free tensors in the computation graph if no object holds a reference to them anymore.

  2. Also no, since you are overwriting all tensors explicitly in each iteration and there are no intermediates which wouldn’t be freed. You can also run your example and check the actual memory usage to verify it.
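
To see the effect of retain_graph on a small scale, here is a CPU sketch with made-up tensors. It uses a multiplication instead of the additions from the question, because multiplication saves its inputs for the backward pass, which makes the freeing visible: once backward() has run without retain_graph=True, a further backward() through the same graph fails.

```python
import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
loss = (a * b).sum()               # mul saves a and b for its backward pass

loss.backward(retain_graph=True)   # saved tensors are kept alive
loss.backward()                    # works once more, then the saved tensors are freed
grad_after_two = a.grad.clone()    # gradients accumulate: 2 * b

try:
    loss.backward()                # the graph's saved tensors are gone now
    freed = False
except RuntimeError:
    freed = True

print(freed)
```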


I don’t completely understand the phrase “holds a reference to”. To my knowledge, if I write something like this:

a = 0
a = 1

In this case, can I say that a no longer holds a reference to the object 0 and now holds a reference to the object 1? Or, to put it more simply, that a now points to the object 1? Correct me if I am wrong.

Another case confuses me as well. In a comment on the “Memory leak issue” from YOLOX, the authors say they solved the memory leak by using detach():

def update(self, values=None, **kwargs):
    if values is None:
        values = {}
    values.update(kwargs)
    for k, v in values.items():
        if isinstance(v, torch.Tensor):
            v = v.detach()
        self[k].update(v)

Why should detach() be used here? If the keys in the dictionary are fixed, then even without detach(), every iteration (or every so often) I store a tensor that is still attached to the computation graph into self[k]. I understand that this increases the memory usage, but shouldn’t the additional memory usage per iteration be zero, since each new value replaces the old one?

self["loss"] = loss_t1
self["loss"] = loss_t2 #the loss_t1 is destroyed
....

In this case, why would the memory usage keep increasing? Thanks.

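class A:
    def __init__(self, my_pony):
        self.my_pony = my_pony

b = bytearray(10**8)  # stand-in for "something large" (about 100 MB)
a = A(b)
b = None  # the name b is rebound, but a.my_pony still refers to the bytearray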

“something large” won’t be garbage collected, because a still holds a reference to it (via a.my_pony).
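
The same idea can be checked directly with a weak reference, which observes an object without keeping it alive. This is a plain-Python sketch. The class Big and the variable probe are names made up for the illustration.

```python
import gc
import weakref

class Big:
    """Stand-in for 'something large'."""
    def __init__(self):
        self.payload = bytearray(10**6)

class A:
    def __init__(self, my_pony):
        self.my_pony = my_pony

b = Big()
probe = weakref.ref(b)     # lets us check whether the object still exists
a = A(b)
b = None                   # the name b no longer refers to the object

alive_after_rebinding = probe() is not None   # a.my_pony still holds it
a.my_pony = None           # drop the last reference
gc.collect()
alive_after_release = probe() is not None

print(alive_after_rebinding, alive_after_release)
```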

Sorry for interrupting you. Could you reply to the question I posted above to help clear up my confusion? Thanks!

@Omroth already explained the reference holding via a small example, but maybe this code would also help you as you could play around with it to understand a bit more about the memory allocation:

import torch
import torch.nn as nn

print(torch.cuda.memory_allocated()/1024**2)
# 0.0

device = 'cuda'
x = torch.randn(1024, 1024, device=device)
# expected: 4MB = 1024 * 1024 * 4 / 1024**2
print(torch.cuda.memory_allocated()/1024**2)
# 4.0

model = nn.Sequential(
    nn.Linear(1024, 1024, bias=False),
    nn.Linear(1024, 1024, bias=False),
    nn.Linear(1024, 1024, bias=False),
    nn.Linear(1024, 1024, bias=False)
).to(device)
# expected: 20MB = 4MB (previous) + 16MB via 1024 * 1024 * 4 * 4 / 1024**2
print(torch.cuda.memory_allocated()/1024**2)
# 20.0

for _ in range(3):
    out = model(x)
    print(torch.cuda.memory_allocated()/1024**2)

# expected: 36MB = 20MB (previous) + 16MB for intermediates
# 36.0
# 36.0
# 36.0

del out
# expected: 20MB (as previously)
print(torch.cuda.memory_allocated()/1024**2)
# 20.0

for _ in range(3):
    out = model(x)
    print(torch.cuda.memory_allocated()/1024**2)
    out.mean().backward()
    print(torch.cuda.memory_allocated()/1024**2)

# 36.0 # 1st forward: 20MB (base) + 16 (intermediates)
# 40.0 # 1st backward: intermediates deleted, but 20MB (base) + 16MB for gradients + 4MB for `out`
# 52.0 # 2nd forward: 20MB (base) + 16 (grad) + 16MB (intermediates)
# 40.0 # 2nd backward: same as 1st backward
# 52.0 # 3rd forward: same as 2nd forward
# 40.0 # 3rd backward: same as 1st backward

del out
print(torch.cuda.memory_allocated()/1024**2)
# 36.0 # 20MB (base) + 16MB (grad)

model.zero_grad(set_to_none=True)
print(torch.cuda.memory_allocated()/1024**2)
# 20.0 # 20MB (base)

l = 0
for _ in range(3):
    out = model(x)
    l += out.mean()
    print(torch.cuda.memory_allocated()/1024**2)

# 36.00048828125
# 48.00048828125
# 60.00048828125

l = l.backward()
print(torch.cuda.memory_allocated()/1024**2)
# 40.00048828125

model.zero_grad(set_to_none=True)
print(torch.cuda.memory_allocated()/1024**2)
# 24.00048828125

del out
print(torch.cuda.memory_allocated()/1024**2)
# 20.00048828125

Also note that all these issues are, strictly speaking, not memory leaks, even though the increase in memory usage is often called a leak by users. As you can see in my example, the increase in memory usage is expected, and deleting the right tensors frees the memory and makes it reusable. This means that your script is not leaking memory (with a real leak the memory would be lost and you wouldn’t be able to recover it).
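
The growth in the l += out.mean() loop above comes from the running sum staying attached to every iteration's graph. If the sum is only needed for logging, a common pattern is to accumulate a plain Python number instead. A CPU sketch with a hypothetical small model:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 1)
x = torch.randn(4, 16)

attached_total = 0     # becomes a tensor that references all three graphs
logged_total = 0.0     # stays a plain float, no graph is kept alive
for _ in range(3):
    loss = model(x).mean()
    attached_total += loss
    logged_total += loss.item()

print(attached_total.grad_fn is not None)   # the sum is still part of a graph
print(type(logged_total))
```

Accumulating the attached tensor is the right choice when you intend to call backward() on the sum, as in the example above. The float version is only meant for statistics.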


Wow, this is really awesome and informative. Thanks!
I have a question about this part of the code:

for _ in range(3):
    print("Before forwarding: {} MB.".format(torch.cuda.memory_allocated()/1024**2))
    out = model(x)
    print("After forwarding: {} MB".format(torch.cuda.memory_allocated()/1024**2))
    print("The id of out: {}".format(id(out)))
    del out

The output is:

Before forwarding: 20.0 MB.
After forwarding: 36.0 MB
The id of out: 140380622503616
Before forwarding: 20.0 MB.
After forwarding: 36.0 MB
The id of out: 140380622499328
Before forwarding: 20.0 MB.
After forwarding: 36.0 MB
The id of out: 140380622502848

When del out is performed, the memory is freed as expected. But without del out, the allocated memory stays at 36.0 MB after the first iteration. Why does this happen? Wouldn’t out be collected at the end of each iteration, so that no object holds a reference to the computation graph anymore and the memory taken by the graph is freed? Thanks!

No, out would still be alive until it is overwritten with the new output of the model. Or when do you think it would be deleted?
Just because a loop iteration ends doesn’t mean that Python frees the variables assigned in it:

for i in range(2):
    print("iter {}".format(i))
    if 'out' in locals() or 'out' in globals():
        print("out is alive in iter {}".format(i))
    out = 1
    
# iter 0
# iter 1
# out is alive in iter 1

Based on your assumption, "out is alive in iter 1" should not have been printed, because out would somehow have been deleted automatically at the end of iter 0.
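
A weak reference makes the lifetime visible in the same way for real objects. In this plain-Python sketch (the class Output is a made-up stand-in for a model output), the object from the previous iteration is still alive at the start of the next one and is only released once out is rebound or deleted:

```python
import gc
import weakref

class Output:
    """Stand-in for a model output."""

probes = []
alive_at_iteration_start = []
for i in range(3):
    if probes:
        # Is the object created in the previous iteration still alive?
        alive_at_iteration_start.append(probes[-1]() is not None)
    out = Output()                 # rebinding releases the previous object
    probes.append(weakref.ref(out))

gc.collect()
alive_after_loop = [p() is not None for p in probes]

del out                            # explicitly drop the last one
gc.collect()
last_alive_after_del = probes[-1]() is not None

print(alive_at_iteration_start, alive_after_loop, last_alive_after_del)
```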
