Code Review - Sentiment Analysis [Beginner]

Hm, two things stand out to me.

Firstly, it seems you create your model modelObj inside the loop, i.e., once for every batch. That re-initializes the model each time, so you lose all the updates to its weights. The model should be created only once, so I would move this line to, say, before dataloader1 = ....
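To illustrate why this matters, here is a minimal sketch with a hypothetical toy model (a single nn.Linear layer) and one made-up batch reused for 50 steps, standing in for your modelObj and dataloader1. It counts how many distinct model objects see the data and tracks the loss of the model that is created once.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for one batch from the dataloader: 8 samples, 4 features, 2 classes
x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))
num_steps = 50  # pretend we see this batch 50 times

def train(create_model_per_batch):
    model = nn.Linear(4, 2)  # correct place: build the model once, before the loop
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    models_used, losses = [], []
    for _ in range(num_steps):
        if create_model_per_batch:
            # Anti-pattern: a fresh, randomly initialized model on every batch,
            # so every previous weight update is discarded
            model = nn.Linear(4, 2)
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        models_used.append(model)
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    # Count distinct model objects (the list keeps them alive, so ids are unique)
    return len({id(m) for m in models_used}), losses

n_once, losses_once = train(create_model_per_batch=False)
n_each, losses_each = train(create_model_per_batch=True)
print(n_once, n_each)                    # 1 50
print(losses_once[0] > losses_once[-1])  # True: the single model keeps learning
```

With the model created inside the loop, each step trains a brand-new model for exactly one update, so nothing accumulates across batches.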

Secondly, I’m still a bit suspicious about the forward() method. Maybe everything is correct, but I can’t test it. It’s a bit strange that you permute the embeddings. Usually the input, and therefore the embeddings, have batch_size as the first dimension. You also define your nn.LSTM with batch_first=True, so everything should be fine without permuting. However, since your code runs, I assume the shapes work out, i.e., your input is shaped (seq_len, batch_size) and the permute turns the embeddings into the (batch_size, seq_len, embed_dim) layout the LSTM expects.
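A quick way to check the shapes is to run the layers on dummy data. The sketch below uses made-up sizes and a hypothetical (batch_size, seq_len) tensor of token ids. It also shows that feeding permuted embeddings to a batch_first=True LSTM raises no error; the LSTM simply treats seq_len as the batch dimension, so it's worth printing the shapes explicitly.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration
batch_size, seq_len, vocab_size, embed_dim, hidden_size = 4, 7, 100, 16, 32

embed = nn.Embedding(vocab_size, embed_dim)
lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)

tokens = torch.randint(0, vocab_size, (batch_size, seq_len))  # (batch, seq_len)
embeds = embed(tokens)                                          # (batch, seq_len, embed_dim)

# batch_first=True: feed the embeddings as they are, no permute needed
output, (h_n, c_n) = lstm(embeds)
print(output.shape)  # torch.Size([4, 7, 32]) -> (batch, seq_len, hidden_size)
print(h_n.shape)     # torch.Size([1, 4, 32]) -> (num_layers * num_directions, batch, hidden_size)

# Pitfall: permuting to (seq_len, batch, embed_dim) still runs without any error,
# but the LSTM now treats seq_len as the batch dimension
wrong_output, (wrong_h_n, _) = lstm(embeds.permute(1, 0, 2))
print(wrong_h_n.shape)  # torch.Size([1, 7, 32]) -> "batch" is now 7
```

Note that batch_first only affects the input and output tensors; the hidden state h_n always has the batch in its second dimension.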

However, the processing of hidden seems a bit off. hidden initially has a shape of (num_layers * num_directions, batch, hidden_size). After permuting, it’s (batch, num_layers * num_directions, hidden_size), and after the .view() it’s (batch, num_layers * num_directions * hidden_size). Most people use only the hidden state of the last layer, but concatenating all of them shouldn’t be a problem in principle. But again, I’m not sure whether the data gets messed up along the way, similar to the example I’ve already linked to.
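To check that nothing gets mixed up, the following sketch (made-up sizes, a hypothetical 2-layer bidirectional LSTM with batch_first=True) compares both ways of handling hidden. For the last layer, the final forward hidden state should match output at the last time step, and the final backward hidden state should match output at the first time step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration
batch, seq_len, embed_dim, hidden_size = 3, 6, 8, 5
num_layers, num_directions = 2, 2

lstm = nn.LSTM(embed_dim, hidden_size, num_layers=num_layers,
               bidirectional=(num_directions == 2), batch_first=True)
x = torch.randn(batch, seq_len, embed_dim)
output, (h_n, c_n) = lstm(x)
# h_n: (num_layers * num_directions, batch, hidden_size), even with batch_first=True

# Option 1 (as in your code): concatenate all layers and directions per sample.
# reshape() instead of view(), since the permuted tensor is no longer contiguous.
flat = h_n.permute(1, 0, 2).reshape(batch, -1)
# -> (batch, num_layers * num_directions * hidden_size)

# Option 2: split layers and directions, then keep only the last layer
last_layer = h_n.view(num_layers, num_directions, batch, hidden_size)[-1]
forward_last = last_layer[0]   # (batch, hidden_size)
backward_last = last_layer[1]  # (batch, hidden_size)

# Sanity checks against the per-time-step output of the last layer:
# the forward direction finishes at the last time step,
# the backward direction finishes at the first time step
print(torch.allclose(forward_last, output[:, -1, :hidden_size]))  # True
print(torch.allclose(backward_last, output[:, 0, hidden_size:]))  # True
print(flat.shape)  # torch.Size([3, 20])
```

Each row of flat is just the per-sample concatenation of all layers and directions, so Option 1 keeps samples separate; it simply carries more features than Option 2.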

To get a first basic version running, I would do the following:

  def forward(self, input, hidden=None, verbose=False):
    embeds = self.embed(input).permute(1, 0, 2)  # <-- double-check this :)
    output, (hidden, cell) = self.lstm_cell(embeds, hidden)
    # hidden has shape (num_layers * num_directions, batch, hidden_size),
    # independent of batch_first, so the batch size can be read from dim 1
    batch = hidden.size(1)
    # Split layers and directions (useful if you want to try bidirectional=True later on)
    hidden = hidden.view(self.num_layers, self.num_directions, batch, self.hidden_size)
    # Get the last hidden state with respect to the layers
    hidden = hidden[-1]
    # Get rid of the direction dimension (won't work for bidirectional=True)
    hidden = hidden.squeeze(0)
    # The shape of hidden is now (batch, hidden_size),
    # so self.lf needs to be nn.Linear(self.hidden_size, output_size)
    linear = self.lf(hidden)
    return linear
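For completeness, here is a minimal sketch of a module around this forward(). The class name, constructor arguments, and sizes are hypothetical, and it assumes your input is (batch, seq_len), so the permute is dropped and the LSTM uses batch_first=True. It mainly shows the attributes the method relies on (self.num_layers, self.num_directions, self.hidden_size, self.embed, self.lstm_cell, self.lf).

```python
import torch
import torch.nn as nn

class SentimentLSTM(nn.Module):  # hypothetical name and constructor arguments
    def __init__(self, vocab_size, embed_dim, hidden_size, output_size,
                 num_layers=1, bidirectional=False):
        super().__init__()
        self.num_layers = num_layers
        self.num_directions = 2 if bidirectional else 1
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm_cell = nn.LSTM(embed_dim, hidden_size, num_layers=num_layers,
                                 bidirectional=bidirectional, batch_first=True)
        self.lf = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden=None):
        # Assumes input is (batch, seq_len), so no permute is needed with batch_first=True
        embeds = self.embed(input)
        output, (hidden, cell) = self.lstm_cell(embeds, hidden)
        batch = hidden.size(1)
        hidden = hidden.view(self.num_layers, self.num_directions, batch, self.hidden_size)
        hidden = hidden[-1]         # last layer: (num_directions, batch, hidden_size)
        hidden = hidden.squeeze(0)  # (batch, hidden_size) in the unidirectional case
        return self.lf(hidden)

model = SentimentLSTM(vocab_size=1000, embed_dim=32, hidden_size=64,
                      output_size=2, num_layers=2)
tokens = torch.randint(0, 1000, (8, 20))  # (batch, seq_len)
logits = model(tokens)
print(logits.shape)  # torch.Size([8, 2])
```

The logits can go straight into nn.CrossEntropyLoss together with a (batch,) tensor of class labels.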

Here’s a complete example of a GRU/LSTM-based text classifier. The important parts are the forward() method and the handling of the hidden state. This model solves exactly your task and even comes with attention :)