Bidirectional LSTM doesn't converge on plausible values


(Ben Eyal)

I’m trying to train a language model on Penn Treebank, similar to the example here: https://github.com/pytorch/examples/tree/master/word_language_model. The difference is that I’m using full sentences (so of varying lengths) instead of fixed-size sequences of words. My model is a bidirectional LSTM:

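```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class BiLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers, dropout_prob=0.5):
        super(BiLSTM, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.dropout = nn.Dropout(p=dropout_prob)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers, dropout=dropout_prob, bidirectional=True)
        # Forward and backward outputs are concatenated, hence 2 * hidden_size
        self.fc = nn.Linear(2 * hidden_size, vocab_size)

    def forward(self, input, lengths, hidden):
        # input: (max_len, batch) token ids, batch sorted by decreasing length
        embed = self.dropout(self.embedding(input))
        # lengths: list of ints or a 1-D CPU tensor
        packed = pack_padded_sequence(embed, lengths)
        packed_out, hidden = self.lstm(packed, hidden)
        # out: (max_len, batch, 2 * hidden_size), zero-filled past each length
        out, _ = pad_packed_sequence(packed_out)
        out = self.dropout(out)
        out = self.fc(out)  # (max_len, batch, vocab_size)
        return out, hidden

    def init_hidden(self, batch_size):
        # First dimension is num_layers * num_directions (2 for bidirectional)
        return (torch.zeros(2 * self.num_layers, batch_size, self.hidden_size).to(device),
                torch.zeros(2 * self.num_layers, batch_size, self.hidden_size).to(device))
```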
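One way to sanity-check the packing and padding is to push a tiny hand-made batch through the same kind of layers and look at the shapes. This is only a shape check under made-up sizes and token ids (PAD = 0 is an assumption), not a fix for the training problem:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

PAD = 0  # assumed padding index
vocab_size, embedding_dim, hidden_size, num_layers = 20, 8, 16, 2

# Three sentences of lengths 5, 3 and 2, already sorted longest first,
# laid out time-major as (max_len, batch) and padded with PAD.
X = torch.tensor([
    [4, 7, 9],
    [5, 8, 3],
    [6, 2, PAD],
    [3, PAD, PAD],
    [2, PAD, PAD],
])
lengths = torch.tensor([5, 3, 2])

embedding = nn.Embedding(vocab_size, embedding_dim)
lstm = nn.LSTM(embedding_dim, hidden_size, num_layers, bidirectional=True)

packed = pack_padded_sequence(embedding(X), lengths)
packed_out, (h_n, c_n) = lstm(packed)  # no hidden passed -> zeros
out, out_lengths = pad_packed_sequence(packed_out)

print(out.shape)    # torch.Size([5, 3, 32]): (max_len, batch, 2 * hidden_size)
print(h_n.shape)    # torch.Size([4, 3, 16]): (num_layers * 2, batch, hidden_size)
print(out_lengths)  # the original lengths
# Steps past a sentence's length come back as zeros.
print(out[3:, 1].abs().sum().item())
```

If the real batches produce the same layout, the model's output lines up time step by time step with a (max_len, batch) target tensor.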

I think I’m doing the packing and padding correctly, and the model runs, but I don’t think it trains…
I’m sorting the batches by length, as pack_padded_sequence requires, and the training loop is this:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# model, device, get_batches, train, valid, BATCH_SIZE, VOCAB_SIZE and PAD are defined elsewhere
loss_function = nn.CrossEntropyLoss(ignore_index=PAD).to(device)
optimizer = optim.Adam(model.parameters())
epochs = 40
train_batches = list(get_batches(train, BATCH_SIZE))
valid_batches = list(get_batches(valid, 1))
best_val_loss = float('inf')
for epoch in range(epochs):
    total_train_loss = 0
    total_val_loss = 0
    model.train()
    for batch in tqdm(train_batches):
        model.zero_grad()

        X, y, lengths = batch
        _, batch_size = X.size()
        hidden = model.init_hidden(batch_size)

        yhat, hidden = model(X, lengths, hidden)
        # Logits (max_len, batch, V) -> (N, V); targets flattened to (N,)
        loss = loss_function(yhat.contiguous().view(-1, VOCAB_SIZE), y.reshape(-1))
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
        hidden = (hidden[0].detach(), hidden[1].detach())
    total_train_loss /= len(train_batches)
    with torch.no_grad():
        model.eval()
        for batch in tqdm(valid_batches):
            X, y, lengths = batch
            _, batch_size = X.size()
            hidden = model.init_hidden(batch_size)

            yhat, hidden = model(X, lengths, hidden)
            loss = loss_function(yhat.contiguous().view(-1, VOCAB_SIZE), y.reshape(-1))
            total_val_loss += loss.item()
            hidden = (hidden[0].detach(), hidden[1].detach())
    total_val_loss /= len(valid_batches)
    # Stop at the first epoch where validation loss does not improve
    if total_val_loss < best_val_loss:
        best_val_loss = total_val_loss
    else:
        break
```
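The loop depends on get_batches, which isn't shown here. As an illustration only, a hypothetical version producing the layout the loop expects could look like this. It assumes each sentence is a list of token ids with at least two tokens, uses the sentence minus its last token as input and the sentence shifted by one as target, pads with PAD (assumed to be 0), and sorts each batch longest first for pack_padded_sequence:

```python
import torch

PAD = 0  # assumed padding index

def get_batches(sentences, batch_size, pad=PAD):
    """Hypothetical helper: yields (X, y, lengths), where X and y are
    (max_len, batch) LongTensors and lengths is sorted longest first."""
    for start in range(0, len(sentences), batch_size):
        chunk = sentences[start:start + batch_size]
        # pack_padded_sequence (default enforce_sorted=True) needs
        # the batch in decreasing order of length.
        chunk = sorted(chunk, key=len, reverse=True)
        # Language-model pairs: predict token t+1 from tokens up to t.
        inputs = [s[:-1] for s in chunk]
        targets = [s[1:] for s in chunk]
        lengths = [len(s) for s in inputs]
        max_len = lengths[0]
        X = torch.full((max_len, len(chunk)), pad, dtype=torch.long)
        y = torch.full((max_len, len(chunk)), pad, dtype=torch.long)
        for j, (inp, tgt) in enumerate(zip(inputs, targets)):
            X[:len(inp), j] = torch.tensor(inp)
            y[:len(tgt), j] = torch.tensor(tgt)
        yield X, y, torch.tensor(lengths)

# Tiny usage example with made-up token ids.
sentences = [[1, 5, 6, 2], [1, 7, 2], [1, 8, 9, 10, 2]]
X, y, lengths = next(get_batches(sentences, batch_size=3))
print(X)
print(y)
print(lengths)
```

Because X and y share the time-major (max_len, batch) layout of the model output, y.reshape(-1) lines up row for row with yhat.view(-1, VOCAB_SIZE), and padded positions are skipped through ignore_index=PAD.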

If I don’t divide the total losses by the number of batches, the loss is in the tens of thousands. I saw `yhat.contiguous().view(-1, VOCAB_SIZE)` somewhere and tried it to see if it made a difference; I don’t think it did. Before that, I just did `yhat = yhat.permute(1, 2, 0)`, and the loop also ran without any runtime errors, but the loss values were weird. I tried switching from Adam to SGD, but it didn’t seem to affect anything.
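A small check of the two loss layouts, using random logits as a stand-in for the model output (VOCAB_SIZE = 10 and PAD = 0 are assumptions for the sketch). It shows that flattening with view and the permute(1, 2, 0) form compute the same mean over non-PAD tokens, and that an untrained, uniform predictor scores about log(VOCAB_SIZE):

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

PAD = 0          # assumed padding index
VOCAB_SIZE = 10  # assumed small vocabulary for the sketch
max_len, batch = 4, 3

# Stand-ins for model output and padded targets, both time-major.
yhat = torch.randn(max_len, batch, VOCAB_SIZE)
y = torch.tensor([
    [8, 5, 7],
    [9, 6, 2],
    [3, 2, PAD],
    [2, PAD, PAD],
])

loss_function = nn.CrossEntropyLoss(ignore_index=PAD)

# Layout 1: (N, C) logits with (N,) targets.
loss_flat = loss_function(yhat.view(-1, VOCAB_SIZE), y.reshape(-1))
# Layout 2: (batch, C, seq) logits with (batch, seq) targets.
loss_perm = loss_function(yhat.permute(1, 2, 0), y.t())
# Both average over the same 9 non-PAD tokens, so they agree.
print(loss_flat.item(), loss_perm.item())

# Uniform logits give a loss of log(VOCAB_SIZE) and a perplexity of VOCAB_SIZE.
uniform = torch.zeros(max_len, batch, VOCAB_SIZE)
loss_uniform = loss_function(uniform.view(-1, VOCAB_SIZE), y.reshape(-1))
print(loss_uniform.item(), math.log(VOCAB_SIZE))
print(math.exp(loss_uniform.item()))
```

Since each batch loss is already a per-token mean, summing it over every batch makes the total grow with the number of batches; averaging over batches and taking exp gives a perplexity that can be compared against log(VOCAB_SIZE) at the start of training.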

Does anyone have an idea how I should train this model? I’m really out of ideas, and tutorials aren’t really helping :frowning_face:

Thanks!

