Initializing the right layers for bidirectional LSTM

Hey guys!

I’m currently building a bidirectional LSTM for text classification. I’m pretty new to coding with PyTorch and I’m wondering which layers I need for the bi-LSTM to work correctly. It’s currently running without errors, but I wonder whether the bidirectional part is actually working or whether it’s still a normal LSTM:

import torch
import torch.nn as nn

class LSTM(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, label_size, batch_size):
        super(LSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True)
        self.hidden2label = nn.Linear(hidden_dim * 2, label_size)  # hidden_dim * 2: forward and backward outputs are concatenated
        self.hidden = self.init_hidden()

    def init_hidden(self):
        # Shape (num_layers * num_directions, batch, hidden_dim): one layer, two directions
        h0 = torch.zeros(1 * 2, self.batch_size, self.hidden_dim)
        c0 = torch.zeros(1 * 2, self.batch_size, self.hidden_dim)
        return h0, c0

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        x = embeds.view(len(sentence), self.batch_size, -1)
        lstm_out, self.hidden = self.lstm(x, self.hidden)
        y = self.hidden2label(lstm_out[-1])
        return y

Does bidirectional=True do anything by itself? Or do I need to add another backward_lstm for doing the same thing, but the other way around?
And if so, how can I feed the backward layer the input in reverse?
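A quick way to check that bidirectional=True already does both passes: build a bidirectional nn.LSTM, copy its backward-direction parameters (nn.LSTM registers them as weight_ih_l0_reverse, weight_hh_l0_reverse, bias_ih_l0_reverse and bias_hh_l0_reverse) into an ordinary one-directional LSTM, feed that one the time-reversed input with torch.flip, and compare. The sizes and random input below are arbitrary; this is only a sanity check, not something the model itself needs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, emb, hid = 5, 3, 4, 6

# One bidirectional LSTM: it already contains a forward and a backward direction
bi = nn.LSTM(emb, hid, bidirectional=True)
x = torch.randn(seq_len, batch, emb)  # (seq_len, batch, features), batch_first=False
out, (h_n, c_n) = bi(x)
print(out.shape)  # torch.Size([5, 3, 12]): forward and backward outputs concatenated
print(h_n.shape)  # torch.Size([2, 3, 6]): one final hidden state per direction

# A hand-built "backward_lstm": a unidirectional LSTM given the reverse-direction weights
back = nn.LSTM(emb, hid)
with torch.no_grad():
    back.weight_ih_l0.copy_(bi.weight_ih_l0_reverse)
    back.weight_hh_l0.copy_(bi.weight_hh_l0_reverse)
    back.bias_ih_l0.copy_(bi.bias_ih_l0_reverse)
    back.bias_hh_l0.copy_(bi.bias_hh_l0_reverse)
    back_out, _ = back(torch.flip(x, dims=[0]))  # feed the sequence reversed in time

# Flip the result back so its time steps line up with `out`
back_out = torch.flip(back_out, dims=[0])
print(torch.allclose(out[:, :, hid:], back_out, atol=1e-5))  # True
```

So no separate backward LSTM is needed: the second half of the last dimension of the output is the backward direction, already put back into input order.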

Sorry, I’m pretty new to all this stuff and I’m sometimes really unsure about what is going on in PyTorch (._.)

Thank you!

I figured it out myself:

y = self.hidden2label(torch.cat((lstm_out[-1, :, :self.hidden_dim], lstm_out[0, :, self.hidden_dim:]), 1))

By replacing the second-to-last line of forward with that, I take the last hidden state of the forward direction and the last hidden state of the backward direction (which sits at the “first” position of the input) and concatenate them. So yeah, it works pretty well, thanks to that:
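The same two vectors are also available as the final hidden state h_n that nn.LSTM returns next to the output, which avoids the manual slicing. The sketch below is a hypothetical rewrite of the class from the question (BiLSTMClassifier is an illustrative name, not from the original) that drops the stored batch_size and self.hidden and lets nn.LSTM start from zeros on each call; it then checks that the slicing from the answer matches h_n. This equivalence assumes unpadded, equal-length sequences; for padded batches, torch.nn.utils.rnn.pack_padded_sequence is the usual way to make h_n reflect each sequence's real end.

```python
import torch
import torch.nn as nn


class BiLSTMClassifier(nn.Module):
    # Hypothetical revised classifier; the class name is illustrative only.
    def __init__(self, embedding_dim, hidden_dim, vocab_size, label_size):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True)
        self.hidden2label = nn.Linear(hidden_dim * 2, label_size)

    def forward(self, sentence):
        # sentence: (seq_len, batch) LongTensor of token indices
        embeds = self.word_embeddings(sentence)  # (seq_len, batch, embedding_dim)
        # No initial state passed, so nn.LSTM starts from zeros on every call
        _, (h_n, _) = self.lstm(embeds)
        # h_n: (num_layers * 2, batch, hidden_dim);
        # h_n[-2] is the forward direction's final state, h_n[-1] the backward one's
        features = torch.cat((h_n[-2], h_n[-1]), dim=1)  # (batch, 2 * hidden_dim)
        return self.hidden2label(features)


torch.manual_seed(0)
model = BiLSTMClassifier(embedding_dim=8, hidden_dim=5, vocab_size=20, label_size=3)
sentence = torch.randint(0, 20, (7, 4))  # seq_len=7, batch=4
logits = model(sentence)
print(logits.shape)  # torch.Size([4, 3])

# Compare with the slicing from the answer above
hid = model.hidden_dim
with torch.no_grad():
    lstm_out, (h_n, _) = model.lstm(model.word_embeddings(sentence))
sliced = torch.cat((lstm_out[-1, :, :hid], lstm_out[0, :, hid:]), 1)
from_h_n = torch.cat((h_n[-2], h_n[-1]), 1)
print(torch.allclose(sliced, from_h_n))  # True
```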


@smth I’m still confused! Is this interpretation correct? Won’t passing the whole input, as we usually do, already include this subset of the input anyway?

Sorry, can you specify your issue?
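On the follow-up question: the whole sequence is indeed fed in, but the backward direction reads it from the end, so at the last time step (lstm_out[-1]) its half of the output has only seen the final token. A small sketch with random inputs and arbitrary sizes makes this visible: changing every time step except the last leaves the backward half of lstm_out[-1] unchanged, while the backward half of lstm_out[0] (what the answer uses) changes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, emb, hid = 6, 2, 4, 5
lstm = nn.LSTM(emb, hid, bidirectional=True)

x = torch.randn(seq_len, batch, emb)
x_changed = x.clone()
x_changed[:-1] = torch.randn(seq_len - 1, batch, emb)  # change every time step except the last

with torch.no_grad():
    out_a, _ = lstm(x)
    out_b, _ = lstm(x_changed)

# Backward half at the last time step: it has only consumed x[-1], so it is unchanged
print(torch.allclose(out_a[-1, :, hid:], out_b[-1, :, hid:]))  # True
# Backward half at the first time step: it has consumed the whole sequence, so it changes
print(torch.allclose(out_a[0, :, hid:], out_b[0, :, hid:]))  # False
```

Using lstm_out[-1] alone therefore gives a classifier whose backward features ignore everything but the last token.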