Hello, I’ve encountered a memory leak in an LSTM model and condensed the issue into the code below.
I’ve tried the following solutions:
- Detaching the hidden state with repackage_hidden() (https://discuss.pytorch.org/t/help-clarifying-repackage-hidden-in-word-language-model/226/8)
- gc.collect()
- torch.no_grad()
However, the problem still persists on PyTorch 1.5.1, and my machine has 64 GB of RAM. Any help is appreciated!
import torch
import torch.nn as nn
import psutil, os, gc
def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)
# Doesn't leak memory
# batch = 3
# hidden_size = 256
# batch = 3
# hidden_size = 512
# batch = 6
# hidden_size = 256
# Leaks memory
batch = 6
hidden_size = 512
rnn = nn.LSTM(320, hidden_size, num_layers=5, bidirectional=True)
x = torch.randn(5, batch, 320)
h0 = torch.randn(10, batch, hidden_size)
c0 = torch.randn(10, batch, hidden_size)
with torch.no_grad():
    for i in range(1000):
        # Print the iteration and the resident set size (RSS) of this process
        print(i, psutil.Process(os.getpid()).memory_info().rss)
        output, hidden = rnn(x, (h0, c0))
        hidden = repackage_hidden(hidden)
        gc.collect()
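
In case it helps with diagnosis, here is a minimal sketch (my own addition, not part of the original repro) that samples both the Python heap via tracemalloc and the process RSS via psutil; if RSS keeps growing while the tracemalloc numbers stay flat, the growth would be coming from native allocations inside the LSTM kernels rather than from Python objects.

import os
import tracemalloc

import psutil
import torch
import torch.nn as nn

# Hypothetical diagnostic: compare Python-heap growth (tracemalloc)
# against total process RSS (psutil). Divergence between the two would
# point at native-side allocations rather than Python-level ones.
tracemalloc.start()

rnn = nn.LSTM(320, 512, num_layers=5, bidirectional=True)
x = torch.randn(5, 6, 320)
h0 = torch.randn(10, 6, 512)
c0 = torch.randn(10, 6, 512)
proc = psutil.Process(os.getpid())

with torch.no_grad():
    for i in range(1000):
        output, hidden = rnn(x, (h0, c0))
        if i % 100 == 0:
            py_current, py_peak = tracemalloc.get_traced_memory()
            print(f"iter {i}: rss={proc.memory_info().rss} "
                  f"python_current={py_current} python_peak={py_peak}")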