Hi,
I have a significant memory leak when using LSTMCell with batch norm: memory consumption increases with every forward pass.
I am on version 0.2.1, and most of the memory-leak issues I found on the forum were already fixed in earlier versions.
My code is attached below; any help would be appreciated. Thanks:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class LstmKpiPredictor(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, batch_size):
        super(LstmKpiPredictor, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        # embedding layers applied to each event before the LSTM cells
        self.linear = nn.Linear(568, embedding_dim)
        self.linear2 = nn.Linear(embedding_dim, embedding_dim)
        self.lstm1 = nn.LSTMCell(embedding_dim, hidden_dim)
        self.lstm2 = nn.LSTMCell(hidden_dim, hidden_dim)
        self.lstm3 = nn.LSTMCell(embedding_dim, hidden_dim)
        # batch norm before and after the recurrent part
        self.normbatch_beforelstm = nn.BatchNorm1d(568)
        self.normbatch_afterlstm = nn.BatchNorm1d(batch_size)
        self.output = nn.Linear(hidden_dim, 1)
        self.output2 = nn.Linear(hidden_dim, 1)
        self.batch_size = batch_size
    def get_init_hidden_values(self):
        # initial (h, c) pairs for the LSTM cells
        return (Variable(torch.zeros(1, self.hidden_dim).float(), requires_grad=False, volatile=True),
                Variable(torch.zeros(1, self.hidden_dim).float(), requires_grad=False),
                Variable(torch.zeros(1, self.hidden_dim).float(), requires_grad=False),
                Variable(torch.zeros(1, self.hidden_dim).float(), requires_grad=False),
                Variable(torch.zeros(1, self.hidden_dim).float(), requires_grad=False),
                Variable(torch.zeros(1, self.hidden_dim).float(), requires_grad=False))
    def forward(self, recordings, sequences_length):
        print('starting forward pass')
        outputs = []
        after_normbatch = self.normbatch_beforelstm(recordings)
        # run for each recording in the batch
        for recording_idx, recording in enumerate(after_normbatch.chunk(after_normbatch.size(0), dim=0)):
            # fresh hidden and cell states for this recording
            h_t = Variable(torch.zeros(1, self.hidden_dim).float(), requires_grad=False)
            c_t = Variable(torch.zeros(1, self.hidden_dim).float(), requires_grad=False)
            h_t2 = Variable(torch.zeros(1, self.hidden_dim).float(), requires_grad=False)
            c_t2 = Variable(torch.zeros(1, self.hidden_dim).float(), requires_grad=False)
            if torch.cuda.is_available():
                h_t = h_t.cuda()
                c_t = c_t.cuda()
                h_t2 = h_t2.cuda()
                c_t2 = c_t2.cuda()
                # h_t3 = c_t.cuda()
                # c_t3 = c_t.cuda()
            # iterate over the events (time steps) of this recording
            for event_idx, event in enumerate(recording.chunk(recording.size(2), dim=2)):
                data = event.data.clone()
                # stop once the end of this recording's sequence is reached
                if event_idx == sequences_length[recording_idx]:
                    break
                event_data = Variable(data.view(data.shape[1]))
                event_data = F.relu(self.linear(event_data))
                h_t, c_t = self.lstm1(event_data, (h_t, c_t))
                h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            # keep the last hidden state of each recording
            outputs += [h_t2]
        outputs = torch.stack(outputs, 1)
        outputs_after_norm = self.normbatch_afterlstm(outputs)
        predict_vector = self.output(outputs_after_norm)
        print('finished forward pass')
        return predict_vector
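
For context, this is roughly how I observe the growth: I print the process RSS after every forward pass and it keeps going up. The loop below is only a simplified sketch of my setup (the loader, the model construction, and the psutil-based measurement are placeholders rather than my exact training script):

import os
import psutil

process = psutil.Process(os.getpid())

# loader yields (recordings, sequences_length) batches and model is an LstmKpiPredictor
# instance; both stand in for my real data pipeline.
for step, (recordings, sequences_length) in enumerate(loader):
    predict_vector = model(recordings, sequences_length)
    # resident memory grows with each forward pass
    print('step %d: %d MB' % (step, process.memory_info().rss // (1024 ** 2)))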