What is the best way of creating a queue in PyTorch? Why does appending to a list cause CUDA to run out of memory?

I’ve created a queue in my forward function as shown below:

import torch

que_a = [torch.rand((1, 256, 256))] * 3  # init queue with length 3
que_b = [torch.rand((1, 256, 256))] * 3  # init queue with length 3

i = 0
while i < 1e7:
    # Model does something and produces var1 and var2
    que_a.pop(0); que_b.pop(0)  # pop the first element of each queue
    que_a.append(var1); que_b.append(var2)  # push var1 and var2 onto the end of each queue
    i += 1

Over time this caused CUDA to run out of memory. Why does this happen?

RuntimeError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 7.79 GiB total capacity; 4.50 GiB already allocated; 82.88 MiB free; 4.54 GiB reserved in total by PyTorch)
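
As an aside on the first part of the question above: a fixed-length queue is often expressed with collections.deque, which discards the oldest entry automatically once maxlen is reached. A minimal sketch, not from the original post, reusing the tensor shape from above:

from collections import deque
import torch

# A deque with maxlen=3 drops the oldest element automatically on append,
# so no explicit pop(0) is needed.
que_a = deque([torch.rand((1, 256, 256)) for _ in range(3)], maxlen=3)
que_a.append(torch.rand((1, 256, 256)))  # the oldest entry is discarded
print(len(que_a))  # still 3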

I cannot reproduce this issue using a simple CUDA tensor:

l = [torch.randn(1, device='cuda')]

for i in range(100):
    print('iter{}, mem allocated {}MB'.format(i, torch.cuda.memory_allocated()/1024**2))
    l.pop()
    l.append(torch.randn(1, device='cuda'))

and get the same memory usage in each iteration.

However, based on your comment:

# Model does something

I guess that var1 and var2 might be created by the model. If that’s the case, note that each tensor could still be attached to the computation graph, so not only the tensor itself but the entire graph would be kept alive by the lists if you don’t detach() these tensors.
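
To illustrate this point, here is a minimal sketch with a made-up layer and shapes (not the original model); on a GPU the allocated memory grows roughly twice as fast when the undetached outputs are stored:

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fc = nn.Linear(2560, 2560).to(device)

queue = []
for i in range(5):
    inp = torch.rand(1, 2560, 2560, device=device)
    out = fc(inp)
    print(out.grad_fn is not None)  # True: out is attached to a computation graph
    queue.append(out)               # keeps the graph (and the saved inp) alive
    # queue.append(out.detach())    # stores only the values; the graph can be freed
    print('iter {}: {:.1f} MB allocated'.format(i, torch.cuda.memory_allocated() / 1024**2))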


Thank you very much @ptrblck; you are correct that the model creates var1 and var2. I’ve just tested your solution of detaching the tensor and it solved the issue.

I will write a detailed description of the problem and the solution here shortly.

Following up on my previous reply:

I’ve created this simplified model that replicates the issue:

import torch
import torch.nn as nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.que = [torch.rand((1, 2560, 2560), device=device)] * 3  # init queue with length 3
        self.fc = nn.Linear(2560, 2560)
        
    def forward(self, x):
        t = 0
        while t <= 1000:
            print("Iter count: ", t)
            x = self.fc(x)#.detach() # uncomment .detach() to fix the CUDA out-of-memory error
            self.que.pop(0); self.que.append(x) # pop the first element from the queue and append x to the end
            t += 1

        return x

if __name__ == '__main__':
    model = Model().to(device)
    x=torch.rand((1,2560,2560), device=device)

    model(x)

I’ve also noticed that assigning the output of self.fc to a different variable fixes the CUDA out-of-memory issue as well:

def forward(self, x):
    t = 0
    while t <= 1000:
        print("Iter count: ", t)
        y = self.fc(x)
        self.que.pop(0); self.que.append(y) # pop the first element from the queue and append y to the end
        t += 1

    return y

Thanks for the update!
I’m not sure if the self.que usage is really the culprit here; I think you are increasing the memory usage by reusing x in the first example.

        while t<=1000:
            x = self.fc(x)

This code snippet will pass the output as the new input to the model, so that intermediate activations cannot be freed. The storage in self.que might be irrelevant here. Even if you append to and pop from it properly, the model itself would still hold references to the “old” inputs and PyTorch won’t free the graph.
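
The growth can be reproduced without any queue at all; a minimal sketch with a made-up layer (not the original model), where memory increases each iteration because every output keeps its input alive through the graph:

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fc = nn.Linear(2560, 2560).to(device)
x = torch.rand(1, 2560, 2560, device=device)

for t in range(10):
    x = fc(x)             # the new x still references the old x through its graph
    # x = fc(x).detach()  # replacing the line above with this cuts the history and keeps memory roughly flat
    print('iter {}: {:.1f} MB allocated'.format(t, torch.cuda.memory_allocated() / 1024**2))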