LSTM on CPU won't release memory when all refs deleted

Hey,

Merely instantiating a bunch of LSTMs on a CPU device seems to allocate memory in such a way that it’s never released, even after gc.collect(). The same code run on the GPU releases the memory after a torch.cuda.empty_cache(). I haven’t been able to find any equivalent of empty_cache() for the CPU.

Is this expected behavior? My actual use-case involves training several models at once on CPU cores in a Kubernetes deployment, and involving LSTMs in any way fills memory until the Kubernetes OOM killer evicts the pod. The models themselves are quite small (and if I load a trained model in, they take up very little memory), but all memory temporarily used during training stays filled once training is done.

Code:

import torch
import torch.nn as nn
import gc

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
throwaway = torch.ones((1,1)).to(device) # load CUDA context

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, n_layers, dropout_perc):
        super().__init__()
        self.hidden_dim, self.n_layers = (hidden_dim, n_layers)
        self.rnn = nn.LSTM(input_dim,hidden_dim,n_layers,dropout=dropout_perc)
    def forward(self,x):
        outputs, (hidden, cell) = self.rnn(x)
        return hidden, cell

pile=[]
for i in range(500):
    pile.append(Encoder(102,64,4,0.5).to(device))

del pile
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

I’m running PyTorch 1.5.1 and Python 3.8.3 on Ubuntu 18.04 LTS.

Simpler models don’t seem to exhibit this behavior. For example, this code fully deallocates the memory once all the references are deleted:

import torch
import torch.nn as nn
import gc

class Bloat(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.fc = nn.Linear(size,size)
    def forward(self,x):
        return self.fc(x)

pile = []
for i in range(10):
    pile.append(Bloat(2**12))

del pile
gc.collect()

Are you seeing the increase in system memory usage only when you are using the GPU, or also with a CPU-only run?

Thanks for replying @ptrblck .

When using torch.device('cpu'), the memory used by instantiating the LSTM module Encoder increases and never comes back down.
When using torch.device('cuda:0'), the memory is released from the GPU, and most of it is released from system RAM as well. (I just ran the experiment, and 16 MB was unaccountably still allocated in system RAM.)

So the problem seems at least mostly restricted to torch.device('cpu').
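
For the GPU side, something like this sketch can confirm the release from inside the process (not necessarily how you'd measure it normally; memory_allocated/memory_reserved are the relevant counters, and empty_cache() is what returns the cached blocks, with device index 0 assumed):

import torch

print(torch.cuda.memory_allocated(0))  # bytes held by live tensors
print(torch.cuda.memory_reserved(0))   # bytes held by the caching allocator
torch.cuda.empty_cache()               # return unused cached blocks to the driver
print(torch.cuda.memory_reserved(0))   # should drop back toward memory_allocated()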

Oh, I just realized you might mean a CPU-only distribution of PyTorch? I’m using a CUDA-enabled build, but with a CPU device. I haven’t tried a CPU-only build because I do train on a GPU occasionally, though if this bug (if it is a bug) isn’t present in the CPU-only version of PyTorch, I can definitely switch to that for the actual deployment.

I just tried this on my Mac using the CPU-only distribution of PyTorch 1.5.1 for macOS, and it freed all of its memory after the del pile; gc.collect(). So perhaps this bug only affects using a CPU device with the GPU-capable distribution of PyTorch. That at least gives me an angle of attack, though it would be far more convenient if the GPU-capable distribution behaved itself in CPU mode for development and testing.

Ah, blast. I just tried 1.5.1+cpu on Linux, and it didn’t free the memory: the Mac build released everything after del pile; gc.collect(), but the Linux build didn’t.

I also just tried 1.6.0.dev20200625+cpu on Linux, and it didn’t free the memory either. The resident set only ever increases, never decreases, until I kill the Python process itself.

Also interesting: I can see that there aren’t any references to tensors left, using the code below. After del pile, the only result is the throwaway tensor:

for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            print(f'type:{type(obj)}; shape:{obj.size()}; grad_fn:{repr(obj.grad_fn)}; requires_grad:{obj.requires_grad}')
    except Exception as e:
        pass

So it’s not that there’s still some kind of lingering reference.

One last thing: if I do this multiple times (filling pile, deleting it, then filling it again), the memory usage doesn’t go up significantly unless I fill pile with more Encoders than I did the last time. So I think this is some kind of ever-expanding heap issue.
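
In case it helps, here’s roughly the loop I’ve been running to see that (a sketch: it reuses the Encoder class from my first snippet and reads the current RSS from /proc, so Linux only; the pile sizes are arbitrary):

import gc

def rss_mb():
    # current resident set size of this process (Linux only)
    with open('/proc/self/status') as f:
        for line in f:
            if line.startswith('VmRSS:'):
                return int(line.split()[1]) / 1024  # kB -> MB

for n in (100, 200, 200, 100, 300):  # arbitrary pile sizes
    pile = [Encoder(102, 64, 4, 0.5) for _ in range(n)]
    del pile
    gc.collect()
    print(f'after a pile of {n}: {rss_mb():.0f} MB resident')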

Thanks for the debugging so far.
Could you post how you are measuring the allocated system memory? Are you checking it with a system utility or inside the Python process directly?

I’m also having the same CPU memory issue with LSTMs; strangely, it only occurs if the batch size and hidden size are above a certain level. In the example code linked below, the memory measurement is taken inside the Python process directly.

https://discuss.pytorch.org/t/lstm-cpu-memory-leak-on-specific-batchsize-and-hidden-size/89135

@ptrblck I’m measuring allocated system memory by watching the memory of the Python process in htop rise and fall as I run my script in a REPL.
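
If it’s more useful to check from inside the process, something like this (using the third-party psutil package, which isn’t part of my original script) tracks the same number htop shows:

import os
import psutil  # third-party: pip install psutil

proc = psutil.Process(os.getpid())
print(f'RSS: {proc.memory_info().rss / 1024**2:.0f} MB')  # resident set size, as shown in htop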

I also just ran a longer experiment using the macOS build of CPU-only PyTorch 1.5.1 (which is the only place I’ve seen the LSTM memory released correctly so far). The memory usage remained bounded (below 400 MB), and it was completely released at each stage.

I just started running the same script on the same data, with the only difference being the Linux version of CPU-only PyTorch 1.5.1 (this one specifically: https://download.pytorch.org/whl/cpu/torch-1.5.1%2Bcpu-cp38-cp38-linux_x86_64.whl). In about five minutes it has already consumed a gigabyte, and continues to climb.

I’ll see if I can figure out what’s different between the two versions of CPU-only PyTorch tomorrow, but I’m out of my depth if it’s an MKL-DNN bug or something.

By the way, the script I’m currently running involves my actual application code, including training, which I can’t share. However, if it would be helpful, I’d be happy to craft a minimal example I can share that exhibits the same behavior. It doesn’t appear hard to replicate.

@pol1 That’s interesting. These are the parameters for the LSTM module in my leaking example:

batch_size = 16
hidden_size = 64
n_layers = 4

Does reducing them to batch_size=1 and hidden_size=16 help? Reducing them worked in my minimal example.
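
For reference, this is the sort of reduced configuration I mean (a sketch only: the input size and sequence length here are made up; only the batch and hidden sizes seemed to matter in my case):

import torch
import torch.nn as nn

batch_size, hidden_size = 1, 16   # reduced values that avoided the leak for me
rnn = nn.LSTM(input_size=102, hidden_size=hidden_size, num_layers=4, dropout=0.5)
x = torch.randn(35, batch_size, 102)  # (seq_len, batch, input_size); sizes are arbitrary
out, (h, c) = rnn(x)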

I’ll give it a shot tomorrow.


A minimal code snippet to reproduce this issue would be great, but your initial code might also do the job?
Please create an issue here with the code snippet, a brief description, and a link to this topic for further information, so that we can track and fix it.

CC @pol1 in case you would like to add your information to the same issue.

I was able to stabilize the leak by setting OMP_NUM_THREADS=4. More info is here: https://github.com/pytorch/pytorch/issues/32008
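
For completeness, this is roughly how I apply it (exporting the variable in the shell or pod spec before launching Python is the safer route; torch.set_num_threads is the in-process alternative):

import os
os.environ.setdefault('OMP_NUM_THREADS', '4')  # set before torch is imported, to be safe

import torch
torch.set_num_threads(4)   # limits intra-op CPU threads from inside the process
print(torch.get_num_threads())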

A fix has been merged and looks like it will be available in the upcoming 1.6 release. The nightly build works for this similar case: https://github.com/pytorch/pytorch/issues/40973


Hmm… I’ll definitely try the workarounds in this post, but if 1.6 is supposed to have resolved it, my issue may be different. The 1.6.0.dev20200625+cpu nightly leaked in my above allocation test.

Do you know if the fix for the issue you linked would have been in by then? It’s the latest 1.6 nightly I could find.

I think I’m convinced I have a distinct issue now. Neither using 1.6, nor MKL_DISABLE_FAST_MM=1, nor OMP_NUM_THREADS=4, nor any combination thereof solves the leak in my allocation test. I’ll file a PyTorch issue now.

Here’s the new issue: https://github.com/pytorch/pytorch/issues/41486

@pol1 reducing the hidden size didn’t solve the problem in my allocation test, so I’m pretty sure I have a different problem.