I wrote a bidirectional LSTM as following, and I get out of memory error after training for a while. I believe it has something to do with creating variables in the middle of forward method, but I am not sure about it. Do you guys have any idea what might be the reason and how to fix it? Or is there a better way to write a bidirectonal LSTM?
class BiLSTMModel(nn.Module):
def __init__(self, args):
super(BiLSTMModel, self).__init__()
self.embed = nn.Embedding(args.vocab_size, args.embedding_size)
self.flstm = nn.LSTM(args.embedding_size, args.hidden_size, batch_first=True)
self.blstm = nn.LSTM(args.embedding_size, args.hidden_size, batch_first=True)
self.linear = nn.Linear(args.hidden_size*2, args.label_size)
self.embed.weight.data.uniform_(-1, 1)
self.use_cuda = args.use_cuda
def forward(self, x, mask, is_eval=False):
'''
run the model, take input sentence and predict the logits
args:
x: encoded sentence
mask: the mask of the sentence
returns:
'''
x_embd = self.embed(x)
# forward lstm
fout, (hn, cn) = self.flstm(x_embd)
# calculate backward index
rev_index = torch.range(x.size(1) - 1, 0, -1).view(1, -1).expand(x.size(0), x.size(1)).long()
if self.use_cuda:
rev_index = rev_index.cuda()
# code.interact(local=locals())
mask_length = torch.sum(1 - mask.data, 1).unsqueeze(1).long().expand_as(rev_index)
rev_index -= mask_length
rev_index[rev_index < 0] = 0
rev_index = Variable(rev_index, volatile=is_eval)
# reverse the order of x and store it in bx
bx = Variable(x.data.new(x.size()).fill_(0), volatile=is_eval)
bx = torch.gather(x, 1, rev_index)
bx_embd = self.embed(bx)
# backward lstm
bout, (hn, cn) = self.blstm(bx_embd)
# concat forward hidden states with backward hidden states
out = torch.cat([fout, bout], 2)
length = mask.sum(1).unsqueeze(1).unsqueeze(2).expand(out.size(0), 1, out.size(2)).long() - 1
# gather the last hidden states
out = torch.gather(out, 1, length).contiguous().squeeze(1)
out = self.linear(out)
return out