I would like to do binary sentiment classification of texts using an LSTM.
My problem is that the model trains for a batch size of 1 but not when processing multiple sentences in a batch.
I do not get runtime errors but the model simply does not learn anything for higher batch sizes, so I suspect something might be wrong with the padding or how I use pack/pad_padded_sequence in the LSTM.
This is my model:
class RNN(nn.Module): def __init__(self, vocab_size, embedding_size, hidden_size, num_layers, num_classes): super(RNN, self).__init__() self.embedding = nn.Embedding(vocab_size, embedding_size) self.hidden_size = hidden_size self.num_layers = num_layers self.type = type self.recurrent_layer = nn.LSTM(embedding_size, hidden_size, num_layers, dropout=0.5, batch_first=True) self.fc = nn.Linear(hidden_size, num_classes) def init_hidden(self, batch_size): h_0 = Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size)) c_0 = Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size)) if torch.cuda.is_available(): h_0 = h_0.cuda() c_0 = c_0.cuda() return (h_0, c_0) def forward(self, inputs, lengths): embedded = self.embedding(inputs) embedded = nn.utils.rnn.pack_padded_sequence(embedded, list(lengths.data), batch_first=True) # pack batch initial_hidden_state = self.init_hidden(inputs.size()) r_out, last_hidden_state = self.recurrent_layer(embedded, initial_hidden_state) # pass in LSTM model r_out, recovered_lengths = nn.utils.rnn.pad_packed_sequence(r_out, batch_first=True) # unpack batch idx = (lengths - 1).view(-1, 1).expand(r_out.size(0), r_out.size(2)).unsqueeze(1) # get last hidden output of each sequence r_out = r_out.gather(1, idx).squeeze(dim=1) out = self.fc(r_out) return out
And this is how I train it:
def train(model, X_train, y_train, learning_rate, num_epochs, batch_size): # Loss and Optimizer criterion = nn.CrossEntropyLoss() # contains softmax layer and cross entropy loss, averages over examples in batch if torch.cuda.is_available(): criterion.cuda() optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # Train the Model for epoch in range(num_epochs): train_loss = 0.0 for i, (inputs, lengths, labels) in enumerate(get_minibatches(X_train, y_train, batch_size, shuffle=True)): inputs = Variable(torch.LongTensor(inputs)) labels = Variable(torch.LongTensor(labels)) lengths = Variable(torch.LongTensor(lengths)) if torch.cuda.is_available(): inputs = inputs.cuda() labels = labels.cuda() lengths = lengths.cuda() # Forward + Backward + Optimize optimizer.zero_grad() outputs = model(inputs, lengths) loss = criterion(outputs, labels) loss.backward() optimizer.step() train_loss += loss.data print ('Epoch [%d/%d], Train loss: %.2f' %(epoch + 1, num_epochs, train_loss/(len(X_train)/batch_size)))
Here is how I create the minibatches and pad them
def pad(inputs): lengths = [len(x) for x in inputs] max_len = max(lengths) for input in inputs: for i in range(0, max_len - len(input)): input.append(voc['PAD']) return inputs, lengths def get_minibatches(inputs, targets, batch_size, shuffle=False): assert len(inputs) == len(targets) examples = zip(inputs, targets) if shuffle: random.shuffle(examples) # take steps of size batch_size, take at least one step for start_idx in range(0, max(batch_size, len(inputs) - batch_size + 1), batch_size): batch_examples = examples[start_idx:start_idx + batch_size] batch_inputs, batch_targets = zip(*batch_examples) # pad the inputs batch_inputs, batch_lengths = pad(batch_inputs) # sort according to length batch_inputs, batch_lengths, batch_targets = zip(*sorted(zip(batch_inputs, batch_lengths, batch_targets), key=operator.itemgetter(1), reverse=True)) yield list(batch_inputs), list(batch_lengths), list(batch_targets)
I have already checked that the inputs are padded correctly, the inputs, lengths, targets match in the batches, I have also looked at the results of pack_padded_sequence, pad_padded_sequence and the r_out.gather operation and verified that they look correct and the correct last LSTM state is selected.
However, the network does not learn anything for batch sizes higher than 1, the loss always stays the same throughout the epochs.
Can anyone spot what I overlooked?