Bidirectional LSTM out of memory

I wrote a bidirectional LSTM as follows, and I get an out-of-memory error after training for a while. I believe it has something to do with creating Variables in the middle of the forward method, but I am not sure. Do you guys have any idea what the reason might be and how to fix it? Or is there a better way to write a bidirectional LSTM?

import torch
import torch.nn as nn
from torch.autograd import Variable

class BiLSTMModel(nn.Module):
	def __init__(self, args):
		super(BiLSTMModel, self).__init__()
		self.embed = nn.Embedding(args.vocab_size, args.embedding_size)
		self.flstm = nn.LSTM(args.embedding_size, args.hidden_size, batch_first=True)
		self.blstm = nn.LSTM(args.embedding_size, args.hidden_size, batch_first=True)
		self.linear = nn.Linear(args.hidden_size*2, args.label_size)

		self.use_cuda = args.use_cuda

	def forward(self, x, mask, is_eval=False):
		'''
		Run the model: take the input sentence and predict the logits.
			x: encoded sentence
			mask: the mask of the sentence (1 for real tokens, 0 for padding)
		'''
		x_embd = self.embed(x)
		# forward lstm
		fout, (hn, cn) = self.flstm(x_embd)
		# calculate backward index
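		# torch.range is deprecated; arange excludes the end, so stop at -1 to include 0.
		# repeat (not expand) gives a real copy, so the in-place -= below is allowed.
		rev_index = torch.arange(x.size(1) - 1, -1, -1).view(1, -1).repeat(x.size(0), 1)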
		if self.use_cuda:
			rev_index = rev_index.cuda()
		# code.interact(local=locals())
		# number of padding positions in each sentence
		mask_length = torch.sum(1 - mask, 1).unsqueeze(1).long().expand_as(rev_index)
		rev_index -= mask_length
		rev_index[rev_index < 0] = 0
		rev_index = Variable(rev_index, volatile=is_eval)

		# reverse the order of x (within each sentence's real length) and store it in bx
		bx = torch.gather(x, 1, rev_index)
		bx_embd = self.embed(bx)
		# backward lstm
		bout, (hn, cn) = self.blstm(bx_embd)

		# concat forward hidden states with backward hidden states
		out = torch.cat([fout, bout], 2)
		length = mask.sum(1).unsqueeze(1).unsqueeze(2).expand(out.size(0), 1, out.size(2)).long() - 1

		# gather the last hidden states
		out = torch.gather(out, 1, length).contiguous().squeeze(1)
		out = self.linear(out)
		return out
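The index arithmetic in forward reduces to rev_index = (seq_len - 1 - j) - (number of padding positions) = length - 1 - j, clamped at 0 for padded slots. Here is a standalone sketch for checking that logic in isolation. The helper name reverse_padded is hypothetical (it is not in the post), and it assumes right-padded batches and a mask with 1 for real tokens. Unlike the code above, it writes the original padding tokens back into the padded slots instead of leaving copies of the first token there.

```python
import torch

def reverse_padded(x, mask):
    """Reverse each row of a right-padded batch, keeping padding at the end.

    x:    LongTensor of shape (batch, seq_len)
    mask: tensor of shape (batch, seq_len), 1 for real tokens, 0 for padding
    """
    batch, seq_len = x.shape
    lengths = mask.long().sum(dim=1, keepdim=True)  # (batch, 1) real lengths
    # positions 0..seq_len-1 for every row
    pos = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch, seq_len)
    # real position j reads from length - 1 - j; padded positions read index 0
    rev_index = (lengths - 1 - pos).clamp(min=0)
    out = torch.gather(x, 1, rev_index)
    # put the original padding tokens back where mask == 0
    return torch.where(mask.bool(), out, x)

x = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(reverse_padded(x, mask).tolist())  # [[3, 2, 1, 0], [5, 4, 0, 0]]
```

Applying the helper twice with the same mask returns the original batch, which is a quick sanity check on the indices.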

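nn.LSTM() has a bidirectional argument; setting bidirectional=True runs both directions inside a single module, so you don't need two LSTMs or the manual reversal.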
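A minimal sketch of what that looks like. The sizes are hypothetical and only for illustration. It shows the output shapes and one way to feed the final state of each direction to the classifier. This version does not handle padding: h_n[0] is the forward state after the last time step, padding included, and the backward direction starts reading from the padded end.

```python
import torch
import torch.nn as nn

# hypothetical sizes, chosen only for illustration
vocab_size, embedding_size, hidden_size, label_size = 100, 16, 32, 5

embed = nn.Embedding(vocab_size, embedding_size)
lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True, bidirectional=True)
linear = nn.Linear(2 * hidden_size, label_size)

x = torch.randint(0, vocab_size, (4, 7))  # (batch, seq_len)
out, (h_n, c_n) = lstm(embed(x))

# out: forward and backward hidden states concatenated on the last dim
print(out.shape)     # torch.Size([4, 7, 64])
# h_n: (num_layers * num_directions, batch, hidden_size); 0 = forward, 1 = backward
print(h_n.shape)     # torch.Size([2, 4, 32])

# final state of each direction, concatenated and fed to the classifier
logits = linear(torch.cat([h_n[0], h_n[1]], dim=1))
print(logits.shape)  # torch.Size([4, 5])
```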

In terms of memory, are you using nn.DataParallel?

Do you know if nn.LSTM(bidirectional=True) works with masking?
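One common way to get mask-aware behaviour from a single bidirectional nn.LSTM is to pack the batch with torch.nn.utils.rnn.pack_padded_sequence so that padded positions are skipped. This assumes the mask means "1 for real tokens, right-padded". The sketch below takes that approach. The class name PackedBiLSTM is hypothetical, and enforce_sorted=False requires PyTorch 1.1 or later.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

class PackedBiLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size, label_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size,
                            batch_first=True, bidirectional=True)
        self.linear = nn.Linear(2 * hidden_size, label_size)

    def forward(self, x, mask):
        # mask: 1 for real tokens, 0 for padding (right-padded batch)
        lengths = mask.long().sum(dim=1).cpu()  # lengths must be a CPU tensor
        packed = pack_padded_sequence(self.embed(x), lengths,
                                      batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        # with packing, h_n[0] is the forward state after the last real token and
        # h_n[1] is the backward state after reading back to position 0;
        # h_n is returned in the original batch order
        return self.linear(torch.cat([h_n[0], h_n[1]], dim=1))
```

If per-token outputs are needed as well, torch.nn.utils.rnn.pad_packed_sequence turns the packed output back into a padded tensor.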