LSTM for document-level classification?

I wrote some code for document-level classification.
x is a three-dimensional document tensor of shape batch_size * number_of_sentences * sentence_length.
I first reshape it to a two-dimensional tensor and run it through an embedding layer and the LSTM. I then take the last hidden state of each sentence using “gather”, indexed by the sentence lengths computed from the mask.
Then I average these sentence embeddings (the LSTM last hidden states) over each document to get the document embedding.
When I run my code, it performs very badly, even worse than simple word averaging. I feel there must be some problem with my code. Is it because I cannot change the value of “length” in the middle of the computation graph? Or is there some other obvious error in my code?
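
To make the gather step concrete, here is a tiny standalone example (made-up sizes and lengths) of what that step is supposed to do: for each sequence, pick the LSTM output at its last non-padded timestep. My full code is below.

import torch

N, T, H = 4, 10, 16                            # toy sizes
output = torch.randn(N, T, H)                  # pretend LSTM output with batch_first=True
mask = torch.zeros(N, T, dtype=torch.long)
for i, n in enumerate([3, 10, 1, 7]):          # pretend sentence lengths
    mask[i, :n] = 1

idx = (mask.sum(1) - 1).clamp(min=0)           # index of the last real timestep per sequence
idx = idx.view(N, 1, 1).expand(N, 1, H)        # shape that gather expects: (N, 1, H)
last = output.gather(1, idx).squeeze(1)        # (N, H)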

import torch
import torch.nn as nn


class LSTMModel(nn.Module):
    def __init__(self, args):
        super(LSTMModel, self).__init__()
        self.embedding_size = args.embedding_size
        self.hidden_size = args.hidden_size
        self.embed = nn.Embedding(args.vocab_size, args.embedding_size)
        self.lstm = nn.LSTM(self.embedding_size, self.hidden_size, batch_first=True)
        self.linear = nn.Linear(self.hidden_size, args.num_targets)
        self.embed.weight.data.uniform_(-1, 1)

    def forward(self, x, mask, is_eval=False):
        # Flatten documents so each sentence is its own sequence.
        x_reshaped = x.view(x.size(0) * x.size(1), x.size(2))  # (B*num_sentences) * sentence_length
        x_embd = self.embed(x_reshaped)  # (B*num_sentences) * sentence_length * embedding_size
        output, (hn, cn) = self.lstm(x_embd)  # output: (B*num_sentences) * sentence_length * hidden_size

        # Index of the last non-padded timestep per sentence, shaped for gather.
        length = (mask.view(x.size(0) * x.size(1), x.size(2)).sum(1)
                  .unsqueeze(1).unsqueeze(2)
                  .expand(output.size(0), 1, output.size(2)).long() - 1)
        length.data[length.data < 0] = 0  # empty sentences fall back to index 0; shape (B*num_sentences)*1*hidden_size
        out = torch.gather(output, 1, length).squeeze(1)  # (B*num_sentences) * hidden_size
        out = out.contiguous().view(x.size(0), x.size(1), self.hidden_size)

        # Average sentence embeddings over the (non-empty) sentences in each document.
        sentence_sum = torch.sum(out, 1)
        num_sentences = torch.sum(mask[:, :, 0], 1).unsqueeze(1).expand_as(sentence_sum) + 1e-9
        doc_embd = sentence_sum / num_sentences

        out = self.linear(doc_embd)
        return out
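
For comparison, here is a minimal sketch (toy sizes; enforce_sorted=False needs a reasonably recent PyTorch) of getting the last valid hidden state by packing the padded sequences instead of gathering; with packing, hn already holds the state at each sequence's last real timestep.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
x_embd = torch.randn(4, 10, 8)                       # (N, T, embedding_size), padded
lengths = torch.tensor([3, 10, 1, 7])                # real length of each sequence (must be >= 1)

# Packing tells the LSTM where each sequence ends, so no manual gather is needed.
packed = pack_padded_sequence(x_embd, lengths, batch_first=True, enforce_sorted=False)
_, (hn, cn) = lstm(packed)
last_hidden = hn[-1]                                 # (N, hidden_size)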

Looks like the problem was that I set the learning rate too high, LOL… But there might be other reasons as well.
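
In case it helps anyone, the kind of setup I mean is something like the following; the optimizer choice, the 1e-3 learning rate, and the clipping value are just illustrative starting points, and train_loader is a hypothetical data loader, not my actual code.

import torch
import torch.nn.functional as F

model = LSTMModel(args)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # a smaller learning rate

for x, mask, y in train_loader:                  # hypothetical (x, mask, label) batches
    optimizer.zero_grad()
    logits = model(x, mask)
    loss = F.cross_entropy(logits, y)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)  # optional: keep LSTM gradients bounded
    optimizer.step()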