I am writing a character level RNN for regression learning, where each sequence in my training set has a real-number value assigned to it.
I am using pack_padded_sequence
for padding my sequences to the same length in order to run batched learning. The code is below:
class GenericRNN(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim, batch_size, train_embedding = True, device = "cuda"):
super(GenericRNN, self).__init__()
self.batch_size = batch_size
self.vocab_size, self.embedding_dim, self.hidden_dim = vocab_size, embedding_dim, hidden_dim
self.embedding = nn.Embedding(vocab_size , embedding_dim)
self.rnn = nn.LSTM(embedding_dim, hidden_dim)
self.fc = nn.Linear(hidden_dim, 1)
self.hidden = self.init_hidden(batch_size, device)
self.device = device
if train_embedding is False:
self.embedding.weight.requires_grad = False
def forward(self, char_index, length):
seq_lengths, perm_idx = length.sort(0, descending=True)
seq_tensor = char_index[perm_idx]
embeds = self.embedding(seq_tensor)
packed_input = pack_padded_sequence(embeds, seq_lengths, batch_first=True)
output, self.hidden = self.rnn(packed_input, self.hidden)
output, input_sizes = pad_packed_sequence(output, batch_first=True)
mask = torch.zeros(output.shape[:-1]).to(self.device)
mask[torch.arange(output.shape[0]), input_sizes - 1] = 1
mask = mask.unsqueeze(2).byte()
return self.fc(output.masked_select(mask).view(-1, self.hidden_dim))
def init_hidden(self, batch_size):
return (torch.zeros(1, batch_size, self.hidden_dim).to(self.device), torch.zeros(1, batch_size, self.hidden_dim).to(self.device))
What I find is when I use a batch_size of 1, the loss decreases very well but as soon as I increase the batch_size, the loss no longer decreases at all.
I think one possible cause is that my sequences vary in length quite a lot and I also need to somehow account for all the padding in my loss function (currently using L1Loss
). I see for classification, the the CrossEntropy loss allows ignoring pad index. Is there something similar for regression tasks?