Hello, I’ve encountered a memory leak in an LSTM model and have condensed the issue into the following code.
I’ve tried the following solutions:
- Detach hidden state using repackage_hidden() https://discuss.pytorch.org/t/help-clarifying-repackage-hidden-in-word-language-model/226/8
However, the problem still persists using PyTorch 1.5.1, and my machine has 64 GB of RAM. Any help is appreciated!
"""Minimal reproduction of an apparent memory leak in nn.LSTM CPU inference.

Repeatedly runs an LSTM forward pass under torch.no_grad() and prints the
process RSS each iteration.  The smaller batch/hidden combinations listed in
main() reportedly do not leak; batch=6 with hidden_size=512 does (observed on
PyTorch 1.5.1).

NOTE(review): the hidden state is detached each iteration via
repackage_hidden(), so autograd history retention has already been ruled out
as the cause — presumably the growth comes from inside the LSTM kernel
(e.g. backend primitive caching); confirm against the PyTorch backend in use.
"""
import gc
import os

import torch
import torch.nn as nn


def repackage_hidden(h):
    """Wrap hidden states in new Tensors, detaching them from their history.

    Args:
        h: a Tensor, or an arbitrarily nested tuple of Tensors (e.g. the
            (h_n, c_n) pair returned by nn.LSTM).

    Returns:
        The same structure with every Tensor replaced by ``tensor.detach()``.
    """
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)


def main(iterations=1000, batch=6, hidden_size=512):
    """Run the leak-reproduction loop, printing RSS once per iteration.

    Args:
        iterations: number of forward passes to run (original repro: 1000).
        batch: input batch size.  Combinations reported NOT to leak:
            batch=3 / hidden_size=256, batch=3 / hidden_size=512,
            batch=6 / hidden_size=256.  The default batch=6 /
            hidden_size=512 is the leaking configuration.
        hidden_size: LSTM hidden size.
    """
    # Lazy import: psutil is a third-party package used only for RSS
    # reporting, so the module stays importable without it.
    import psutil

    rnn = nn.LSTM(320, hidden_size, num_layers=5, bidirectional=True)
    x = torch.randn(5, batch, 320)
    # 10 = num_layers * num_directions for a 5-layer bidirectional LSTM.
    h0 = torch.randn(10, batch, hidden_size)
    c0 = torch.randn(10, batch, hidden_size)

    # Hoisted out of the loop: the Process handle is loop-invariant.
    proc = psutil.Process(os.getpid())

    with torch.no_grad():
        for i in range(iterations):
            print(i, proc.memory_info().rss)
            output, hidden = rnn(x, (h0, c0))
            # Detach to prove the leak is not held autograd history.
            hidden = repackage_hidden(hidden)
            gc.collect()


if __name__ == "__main__":
    main()