Bidirectional LSTM for text classification

Hello,

I’m trying to train a bidirectional LSTM for multi-label text classification. I’m using pre-trained word2vec vectors to represent the words. My input consists of indices into the word embeddings (padded with 0s) and the sequence lengths sorted in decreasing order (a sketch of how I build a batch is below the model). My model looks like this:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

class EmailLSTM(nn.Module):

    def __init__(self, input_size, hidden_size, num_classes, num_layers, dropout, weights):
        super(EmailLSTM, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding.from_pretrained(weights, padding_idx=0)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True,
                            bidirectional=True, dropout=dropout)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def init_hidden(self, batch_size):
        hidden = (torch.zeros(self.num_layers * 2, batch_size, self.hidden_size),
                  torch.zeros(self.num_layers * 2, batch_size, self.hidden_size))
        return hidden

    def forward(self, X, lengths):
        text_emb = self.embedding(X)

        hidden = self.init_hidden(X.shape[0])

        packed_input = pack_padded_sequence(text_emb, lengths, batch_first=True)
        packed_output, (h_n, _) = self.lstm(packed_input, hidden)

        forward_hidden = h_n[0]  # Forward hidden state at the last layer
        backward_hidden = h_n[1]  # Backward hidden state at the last layer

        combined_hidden = torch.cat((forward_hidden, backward_hidden), dim=1)

        logits = self.fc(combined_hidden)

        return logits
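
For context, I build a batch roughly like this (the names seqs, labels and model are just illustrative, not my exact code):

import torch
from torch.nn.utils.rnn import pad_sequence

# seqs: list of 1-D LongTensors with word indices, one per email
# labels: (num_emails, num_classes) multi-hot float tensor
lengths = torch.tensor([len(s) for s in seqs])

# Sort by length, longest first, since pack_padded_sequence expects that by default
lengths, sort_idx = lengths.sort(descending=True)
seqs = [seqs[i] for i in sort_idx]
labels = labels[sort_idx]

# Pad with 0, which is also the padding_idx of the embedding
X = pad_sequence(seqs, batch_first=True, padding_value=0)

logits = model(X, lengths)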

The loss function I’m using is BCEWithLogitsLoss, and I pass the pos_weight parameter to it, calculated according to the docs (a rough sketch of the loss setup is below the questions). My issue is that the validation loss doesn’t seem to drop during training, so I’m not sure whether my architecture and loss function are correct. Therefore, my questions are:

  1. Is this the correct way to deal with variable-length inputs?
  2. Does my architecture look appropriate for the task?
  3. Is the loss function appropriate for the task?
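
The loss setup looks roughly like this (the counts come from the training labels; the variable names are just illustrative):

import torch.nn as nn

# train_labels: (num_samples, num_classes) multi-hot float tensor
num_pos = train_labels.sum(dim=0)              # positives per class
num_neg = train_labels.shape[0] - num_pos      # negatives per class
pos_weight = num_neg / num_pos                 # as recommended in the BCEWithLogitsLoss docs

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# logits: (batch_size, num_classes), targets: multi-hot floats of the same shape
loss = criterion(logits, targets)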

Any help would be much appreciated!

I think the part of your forward that indexes into h_n is a bit off. h_n has a shape of (num_directions * num_layers, batch_size, hidden_size), so your code only works for num_layers=1. For more layers, you will pick the hidden states of the first layer, not the last one. My suggestion, rather verbose for clarity:

# Split the layer and direction dimensions (according to the docs)
h_n = h_n.view(num_layers, num_directions, batch_size, hidden_size)
# Take the hidden state of the last layer
h_n = h_n[-1]
# Separate the two directions and concatenate them
h_fwd, h_bwd = h_n[0], h_n[1]
h_n = torch.cat((h_fwd, h_bwd), dim=1)  # (batch_size, 2 * hidden_size)
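
Equivalently, if you want to skip the view: the flat layout is layer-major, so the last two entries of the original h_n returned by the LSTM are the last layer’s forward and backward states:

# Forward and backward hidden states of the last layer
combined_hidden = torch.cat((h_n[-2], h_n[-1]), dim=1)  # (batch_size, 2 * hidden_size)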