Hi,
I would like to do binary sentiment classification of texts using an LSTM.
My problem is that the model trains for a batch size of 1 but not when processing multiple sentences in a batch.
I do not get runtime errors but the model simply does not learn anything for higher batch sizes, so I suspect something might be wrong with the padding or how I use pack/pad_padded_sequence in the LSTM.
This is my model:
class RNN(nn.Module):
def __init__(self, vocab_size, embedding_size, hidden_size, num_layers, num_classes):
super(RNN, self).__init__()
self.embedding = nn.Embedding(vocab_size, embedding_size)
self.hidden_size = hidden_size
self.num_layers = num_layers
self.type = type
self.recurrent_layer = nn.LSTM(embedding_size, hidden_size, num_layers, dropout=0.5, batch_first=True)
self.fc = nn.Linear(hidden_size, num_classes)
def init_hidden(self, batch_size):
h_0 = Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size))
c_0 = Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size))
if torch.cuda.is_available():
h_0 = h_0.cuda()
c_0 = c_0.cuda()
return (h_0, c_0)
def forward(self, inputs, lengths):
embedded = self.embedding(inputs)
embedded = nn.utils.rnn.pack_padded_sequence(embedded, list(lengths.data), batch_first=True) # pack batch
initial_hidden_state = self.init_hidden(inputs.size()[0])
r_out, last_hidden_state = self.recurrent_layer(embedded, initial_hidden_state) # pass in LSTM model
r_out, recovered_lengths = nn.utils.rnn.pad_packed_sequence(r_out, batch_first=True) # unpack batch
idx = (lengths - 1).view(-1, 1).expand(r_out.size(0), r_out.size(2)).unsqueeze(1)
# get last hidden output of each sequence
r_out = r_out.gather(1, idx).squeeze(dim=1)
out = self.fc(r_out)
return out
And this is how I train it:
def train(model, X_train, y_train, learning_rate, num_epochs, batch_size):
# Loss and Optimizer
criterion = nn.CrossEntropyLoss() # contains softmax layer and cross entropy loss, averages over examples in batch
if torch.cuda.is_available():
criterion.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Train the Model
for epoch in range(num_epochs):
train_loss = 0.0
for i, (inputs, lengths, labels) in enumerate(get_minibatches(X_train, y_train, batch_size, shuffle=True)):
inputs = Variable(torch.LongTensor(inputs))
labels = Variable(torch.LongTensor(labels))
lengths = Variable(torch.LongTensor(lengths))
if torch.cuda.is_available():
inputs = inputs.cuda()
labels = labels.cuda()
lengths = lengths.cuda()
# Forward + Backward + Optimize
optimizer.zero_grad()
outputs = model(inputs, lengths)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.data[0]
print ('Epoch [%d/%d], Train loss: %.2f' %(epoch + 1, num_epochs, train_loss/(len(X_train)/batch_size)))
Here is how I create the minibatches and pad them
def pad(inputs):
lengths = [len(x) for x in inputs]
max_len = max(lengths)
for input in inputs:
for i in range(0, max_len - len(input)):
input.append(voc['PAD'])
return inputs, lengths
def get_minibatches(inputs, targets, batch_size, shuffle=False):
assert len(inputs) == len(targets)
examples = zip(inputs, targets)
if shuffle:
random.shuffle(examples)
# take steps of size batch_size, take at least one step
for start_idx in range(0, max(batch_size, len(inputs) - batch_size + 1), batch_size):
batch_examples = examples[start_idx:start_idx + batch_size]
batch_inputs, batch_targets = zip(*batch_examples)
# pad the inputs
batch_inputs, batch_lengths = pad(batch_inputs)
# sort according to length
batch_inputs, batch_lengths, batch_targets = zip(*sorted(zip(batch_inputs, batch_lengths, batch_targets), key=operator.itemgetter(1), reverse=True))
yield list(batch_inputs), list(batch_lengths), list(batch_targets)
I have already checked that the inputs are padded correctly, the inputs, lengths, targets match in the batches, I have also looked at the results of pack_padded_sequence, pad_padded_sequence and the r_out.gather operation and verified that they look correct and the correct last LSTM state is selected.
However, the network does not learn anything for batch sizes higher than 1, the loss always stays the same throughout the epochs.
Can anyone spot what I overlooked?