I’m trying to train a language model on Penn Treebank, similar to the example here: https://github.com/pytorch/examples/tree/master/word_language_model only I’m using full sentences (so varying lengths) instead of fixed-sized sequences of words. My model is a bidirectional LSTM:
class BiLSTM(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers, dropout_prob=0.5):
super(BiLSTM, self).__init__()
self.embedding_dim = embedding_dim
self.hidden_size = hidden_size
self.num_layers = num_layers
self.dropout = nn.Dropout(p=dropout_prob)
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers, dropout=dropout_prob, bidirectional=True)
self.fc = nn.Linear(2 * hidden_size, vocab_size)
def forward(self, input, lengths, hidden):
embed = self.dropout(self.embedding(input))
packed = pack_padded_sequence(embed, lengths)
packed_out, hidden = self.lstm(packed, hidden)
out, _ = pad_packed_sequence(packed_out)
out = self.dropout(out)
out = self.fc(out)
return out, hidden
def init_hidden(self, batch_size):
return (torch.zeros(2 * self.num_layers, batch_size, self.hidden_size).to(device),
torch.zeros(2 * self.num_layers, batch_size, self.hidden_size).to(device))
I think I’m doing the packing and padding correctly, and the model runs, but I don’t think it trains…
I’m sorting by batches by length, as needed for pack_padded_sequence
, and the training loop is this:
loss_function = nn.CrossEntropyLoss(ignore_index=PAD).to(device)
optimizer = optim.Adam(model.parameters())
epochs = 40
train_batches = list(get_batches(train, BATCH_SIZE))
valid_batches = list(get_batches(valid, 1))
best_val_loss = float('inf')
for epoch in range(epochs):
total_train_loss = 0
total_val_loss = 0
model.train()
for batch in tqdm(train_batches):
model.zero_grad()
X, y, lengths = batch
_, batch_size = X.size()
hidden = model.init_hidden(batch_size)
yhat, hidden = model(X, lengths, hidden)
loss = loss_function(yhat.contiguous().view(-1, VOCAB_SIZE), y)
loss.backward()
optimizer.step()
total_train_loss += loss.item()
hidden = (hidden[0].detach(), hidden[1].detach())
total_train_loss /= len(train_batches)
with torch.no_grad():
model.eval()
for batch in tqdm(valid_batches):
X, y, lengths = batch
_, batch_size = X.size()
hidden = model.init_hidden(batch_size)
yhat, hidden = model(X, lengths, hidden)
loss = loss_function(yhat.contiguous().view(-1, VOCAB_SIZE), y)
total_val_loss += loss.item()
hidden = (hidden[0].detach(), hidden[1].detach())
total_val_loss /= len(valid_batches)
if total_val_loss < best_val_loss:
best_val_loss = total_val_loss
else:
break
If I’m not dividing the total losses by the number of batches, the loss is somewhere in the tens of thousands. The yhat.contiguous().view(-1, VOCAB_SIZE)
is something I saw somewhere and decided to try and see if it makes a difference, and I don’t think it did. Before that, I just did yhat = yhat.permute(1, 2, 0)
, and the loop still ran without any runtime errors, but weird values. I tried changing from Adam to SGD, but it didn’t seem to affect anything.
Anyone has an idea how do I need to train this model? I’m really all out of ideas, and tutorials aren’t really helping
Thanks!