Training LSTM, loss not decreasing


(Danny Sanchez) #1

Hi all,

I’m training an LSTM as an encoder for sentences. My loss function is torch.nn.MultiMarginLoss with the default parameters. For more context, here’s a link to the paper:

Here’s my LSTM implementation. As a note, I’m feeding in batches of sentence word-embedding vectors; each sentence has a variable length (the number of words in it), so the batch is padded and then packed with pack_padded_sequence:

import torch
import torch.nn as nn
from torch.autograd import Variable

torch.manual_seed(1)

class LSTM(nn.Module):
    """Sentence encoder that returns the last layer's final hidden state.

    ``config`` is a mapping read for the keys ``input_size``,
    ``hidden_size``, ``num_layers``, ``dropout`` and ``bidirectional``;
    all five are forwarded to ``nn.LSTM``.
    """

    def __init__(self, config):
        super(LSTM, self).__init__()
        self.config = config
        # batch_first=True makes the un-packed path (lengths=None) agree with
        # the packed path, which already treats dim 0 of the input as the
        # batch. Without it, a (batch, seq, feature) tensor fed directly to
        # the LSTM is read as (seq, batch, feature).
        self.lstm = nn.LSTM(input_size=config["input_size"],
                            hidden_size=config["hidden_size"],
                            num_layers=config["num_layers"],
                            dropout=config["dropout"],
                            bidirectional=config["bidirectional"],
                            batch_first=True)

    def forward(self, inputs, lengths=None):
        """Encode a batch-first batch of sequences.

        Args:
            inputs: tensor of shape (batch, seq_len, input_size).
            lengths: optional per-sequence lengths. When given, ``inputs``
                is packed with ``pack_padded_sequence(batch_first=True)`` so
                padding positions do not influence the final state.

        Returns:
            Tensor of shape (batch, hidden_size), or
            (batch, 2 * hidden_size) when the LSTM is bidirectional (the
            last layer's forward and backward final states concatenated).
        """
        batch_size = inputs.size(0)
        if lengths is not None:
            inputs = torch.nn.utils.rnn.pack_padded_sequence(
                inputs, lengths, batch_first=True)

        # No explicit (h0, c0): nn.LSTM initialises both to zeros of the
        # correct shape, dtype and device when they are omitted.
        _, (ht, _) = self.lstm(inputs)

        if self.config["bidirectional"]:
            # ht[-2:] holds the last layer's forward and backward states,
            # shape (2, batch, hidden); move batch first, then flatten the
            # two directions into one feature vector per example.
            return ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
        return ht[-1]

The following code is where I call my loss function:

# Training-step fragment: for each batch, encode candidate titles/bodies and
# the query question with self.lstm, score candidates by cosine similarity to
# the query, and take one Adam step on a multi-class margin loss.
# NOTE(review): `optim` and `f` are not defined in this fragment; presumably
# torch.optim and torch.nn.functional imported elsewhere -- confirm.
criterion = torch.nn.MultiMarginLoss()
optimizer = optim.Adam(self.lstm.parameters(), lr=LEARNING_RATE)
...
# Each item yielded by getFeatures is a (titles, bodies, question_info) triple.
for batch_num, (title_batch, body_batch, question_info_batch) in \
        enumerate(self.builder.getFeatures(self.train_path)):

    title_batch, title_lengths = self.builder.computePaddedBatch(title_batch)
    body_batch, body_lengths = self.builder.computePaddedBatch(body_batch)

    title_feature_matrix = self.getFeatureVariable(title_batch)
    body_feature_matrix = self.getFeatureVariable(body_batch)

    # NOTE(review): if self.lstm packs its input by length, row j of these
    # outputs must still correspond to question_info_batch[j] below. Verify
    # that computePaddedBatch does not reorder titles and bodies differently.
    title_final_states = self.lstm(title_feature_matrix, title_lengths)
    body_final_states = self.lstm(body_feature_matrix, body_lengths)

    relevant_states = []
    retrieved_states = []
    q_final_state = None
    batch_id = None
    for j in range(len(title_final_states)):
        # j:j+1 slicing keeps the leading dimension instead of dropping it.
        title_final_state = title_final_states[j:j+1]
        body_final_state = body_final_states[j:j+1]
        q_id, candidate_id, actual = question_info_batch[j]

        # Candidate representation: mean of its title and body encodings.
        final_state = (title_final_state+body_final_state)/2.0

        question_info = self.builder.getQuestionInfo(q_id)
        title = question_info["title"]
        body = question_info["body"]

        title_vector = self.builder.getSentenceVector(title)
        body_vector = self.builder.getSentenceVector(body)

        title_feature_vector = self.getFeatureVariable([title_vector])
        body_feature_vector = self.getFeatureVariable([body_vector])

        q_title_final_state = self.lstm(title_feature_vector, [len(title_vector)])
        q_body_final_state = self.lstm(body_feature_vector, [len(body_vector)])

        # NOTE(review): q_final_state is overwritten on every iteration and
        # only the value from the last j is used after the loop. This is only
        # correct if every row in the batch shares the same q_id -- confirm.
        # The query is also re-encoded once per candidate.
        q_final_state = (q_title_final_state+q_body_final_state)/2.0

        # `actual` selects the positive (relevant) vs. negative (retrieved) list.
        if actual:
            relevant_states.append((candidate_id, final_state))
        else:
            retrieved_states.append((candidate_id, final_state))

    x = []
    y = []
    scores = []
    neg_sims = []

    # A batch without at least one positive and one negative aborts here.
    assert len(relevant_states) > 0
    assert len(retrieved_states) > 0

    for retrieved_id, retrieved_state in retrieved_states:
        sim = f.cosine_similarity(q_final_state, retrieved_state)
        neg_sims.append(sim)

    # Build one row per positive: [positive_sim, neg_sim_0, neg_sim_1, ...].
    for relevant_id, relevant_state in relevant_states:
        sim = f.cosine_similarity(q_final_state, relevant_state)
        x.append(sim)
        x = x + neg_sims

        # The positive always sits at column 0, so the target class is 0.
        y.append(0)

    # Flatten, then reshape to (num_positives, num_negatives + 1).
    x = torch.cat(x)
    x = x.view(len(relevant_states), len(neg_sims)+1)

    y = Variable(torch.LongTensor(y))

    optimizer.zero_grad()
    loss = criterion(x, y)
    loss.backward()
    optimizer.step()

If there’s any additional information I can provide that would help with answering the question, please let me know!