LSTM classification training loss doesn't change

I’m training a simple LSTM classifier for a 3-class classification task. I wasn’t expecting any of these issues, and I can’t find where I went wrong in my code.

I’m currently using PyTorch’s built-in nn.Embedding on the pre-processed token IDs (in place of explicit one-hot vectors). The rest should be quite straightforward.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

vocab_size = len(vocab_ids) + 1
embedding_dim = 100
hidden_dim = 128
num_layers = 1
class lstm_sentiment(torch.nn.Module):
    
    def __init__(self):
        super(lstm_sentiment, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(hidden_dim, 3)
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, x, hidden):
        embeds = self.embedding(x)
        output, hidden = self.lstm(embeds, hidden)
        output = output[:, -1, :]  # keep only the last time step
        output = self.dropout(output)
        output = self.linear(output)
        output = F.relu(output)
        output = self.softmax(output)
        return output, hidden
    
    def init_hidden(self):
        hidden = (torch.zeros(num_layers, batch_size, hidden_dim).cuda(), 
                  torch.zeros(num_layers, batch_size, hidden_dim).cuda())
        return hidden

There are several places I’m not quite sure about.

  1. An LSTM needs a hidden state, which we can initialize with either random numbers or all zeros. But I’m struggling with where to call the init_hidden() method. In an official PyTorch tutorial that used an LSTM, init_hidden() was called inside the forward(self, x) function, but I have a different opinion. In my understanding, if I return the hidden state output by the LSTM from forward(self, x, hidden) to my training loop, the next batch in the same epoch will continue from the hidden state produced by the last batch, instead of starting from a freshly initialized all-zero or random hidden state. I don’t think initializing a new hidden state every time we enter forward(self, x) would benefit training, but I’m not sure if I’m correct.
  2. When I call model.init_hidden() in the training loop, it seems I have to call hidden = tuple([each.data for each in hidden]), and I’m not quite sure why (see the minimal repro sketched after this list).
  3. My whole dataset has 100,000 samples, but for now I’m only using 10,000 to speed things up. I don’t think that should influence the result much.
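To make question 2 concrete, here is a standalone toy repro (arbitrary sizes, separate from my actual model) of what that .data call seems to be working around:

import torch
import torch.nn as nn

# Toy repro: carry the hidden state across batches. Without the detach at
# the bottom of the loop, the second backward() raises "Trying to backward
# through the graph a second time", because the carried-over hidden state
# still references the first batch's (already freed) graph.
lstm = nn.LSTM(input_size=4, hidden_size=8, num_layers=1, batch_first=True)
hidden = (torch.zeros(1, 2, 8), torch.zeros(1, 2, 8))

for step in range(2):
    x = torch.randn(2, 5, 4)  # fake batch: (batch, seq_len, features)
    out, hidden = lstm(x, hidden)
    out[:, -1, :].sum().backward()
    hidden = tuple(h.detach() for h in hidden)  # same effect as the .data trick

Back to the real code, here is my training loop: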
model = lstm_sentiment()
model.cuda()
model.train()

lr = 0.05
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
epochs = 50
print_every = 1000
clip = 2  # gradient clipping

training_losses = []

for epoch in range(epochs):
    temp_loss = []
    hidden = model.init_hidden()
    for step, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        inputs = batch[0].cuda()
        labels = batch[1].cuda()
        hidden = tuple([each.data for each in hidden])
        output, hidden = model(inputs, hidden)
        loss = criterion(output, labels)
        loss.backward()
        temp_loss.append(loss.detach().cpu().numpy())
        nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
    training_losses.append(sum(temp_loss) / len(temp_loss))
    print("Epoch: {}/{}...".format(epoch + 1, epochs),
         "Loss: {:.6f}...".format(sum(temp_loss) / len(temp_loss)))

Here’s the output result:

Epoch: 1/50... Loss: 1.098612...
Epoch: 2/50... Loss: 1.098612...
Epoch: 3/50... Loss: 1.098612...
Epoch: 4/50... Loss: 1.098612...
Epoch: 5/50... Loss: 1.098612...
Epoch: 6/50... Loss: 1.098612...
Epoch: 7/50... Loss: 1.098612...
Epoch: 8/50... Loss: 1.098613...
Epoch: 9/50... Loss: 1.104512...
Epoch: 10/50... Loss: 1.098612...
Epoch: 11/50... Loss: 1.098612...
Epoch: 12/50... Loss: 1.098612...
Epoch: 13/50... Loss: 1.098612...
Epoch: 14/50... Loss: 1.098612...
Epoch: 15/50... Loss: 1.098612...
Epoch: 16/50... Loss: 1.098612...
Epoch: 17/50... Loss: 1.098612...
Epoch: 18/50... Loss: 1.098612...
Epoch: 19/50... Loss: 1.098612...
Epoch: 20/50... Loss: 1.098612...
Epoch: 21/50... Loss: 1.098612...
Epoch: 22/50... Loss: 1.098612...
Epoch: 23/50... Loss: 1.098612...
Epoch: 24/50... Loss: 1.098612...
Epoch: 25/50... Loss: 1.098612...
Epoch: 26/50... Loss: 1.098612...
Epoch: 27/50... Loss: 1.098612...
Epoch: 28/50... Loss: 1.098612...
Epoch: 29/50... Loss: 1.098612...
Epoch: 30/50... Loss: 1.098612...
Epoch: 31/50... Loss: 1.098612...
Epoch: 32/50... Loss: 1.098612...
Epoch: 33/50... Loss: 1.098612...
Epoch: 34/50... Loss: 1.098612...
Epoch: 35/50... Loss: 1.098612...
Epoch: 36/50... Loss: 1.098612...
Epoch: 37/50... Loss: 1.098612...
Epoch: 38/50... Loss: 1.098612...
Epoch: 39/50... Loss: 1.098612...
Epoch: 40/50... Loss: 1.098612...
Epoch: 41/50... Loss: 1.098612...
Epoch: 42/50... Loss: 1.098612...
Epoch: 43/50... Loss: 1.098612...
Epoch: 44/50... Loss: 1.098612...
Epoch: 45/50... Loss: 1.098612...
Epoch: 46/50... Loss: 1.098612...
Epoch: 47/50... Loss: 1.098612...
Epoch: 48/50... Loss: 1.098612...
Epoch: 49/50... Loss: 1.098612...
Epoch: 50/50... Loss: 1.098612...

It turns out that the loss doesn’t change at all. I thought I fully understood LSTMs, but it seems I was wrong. Can anyone help me, please? I find PyTorch much better than TensorFlow, and I don’t want to give it up.

Something even stranger: I slightly modified a few places, and the loss started to change, but it still doesn’t converge. I wonder if someone could help me with this.

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size = len(vocab_ids) + 1
output_size = 3
embedding_dim = 100
hidden_dim = 128
num_layers = 1

class sentiment_lstm(nn.Module):
    """An LSTM model for text classification."""
    def __init__(self):
        super(sentiment_lstm, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True, bidirectional=False)
        # fully connected
        self.linear = nn.Linear(hidden_dim, output_size)
        # softmax
        self.softmax = nn.LogSoftmax(dim=1)
        # dropout layer
        self.dropout = nn.Dropout(0.3)
        
    def forward(self, x, hidden):
        embeds = self.embedding(x)
        out, hidden = self.lstm(embeds, hidden)
        # fully connected
        out = self.linear(out[:, -1, :])
        out = F.relu(out)
        # softmax
        out = self.softmax(out)
        # dropout
        out = self.dropout(out)
        return out, hidden
    
    def init_hidden(self):
        if (train_on_gpu):
            hidden = (torch.zeros(num_layers, batch_size, hidden_dim).cuda(), 
                      torch.zeros(num_layers, batch_size, hidden_dim).cuda())
        else:
            hidden = (torch.zeros(num_layers, batch_size, hidden_dim), 
                      torch.zeros(num_layers, batch_size, hidden_dim))
        return hidden

net = sentiment_lstm()
train_on_gpu = torch.cuda.is_available()
print("train_on_gpu:", train_on_gpu)

lr = 0.00005
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
epochs = 50
print_every = 100
clip = 5  # gradient clipping

training_losses = []

if (train_on_gpu):
    net.cuda()
    
net.train()
for epoch in range(epochs):
    hidden = net.init_hidden()
    temp_loss = []
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        if (train_on_gpu):
            inputs, labels = inputs.cuda(), labels.cuda()
        hidden = tuple([each.data for each in hidden])
        output, hidden = net(inputs, hidden)
        loss = criterion(output, labels)
        loss.backward()
        temp_loss.append(loss.detach().cpu().numpy())
        # prevent exploding gradient problem
        nn.utils.clip_grad_norm_(net.parameters(), clip)
        optimizer.step()
    training_losses.append(sum(temp_loss) / len(temp_loss))
    print("Epoch: {}/{}...".format(epoch + 1, epochs),
         "Loss: {:.6f}...".format(sum(temp_loss) / len(temp_loss)))

Here’s the output:

Epoch: 1/50... Loss: 0.862115...
Epoch: 2/50... Loss: 0.854747...
Epoch: 3/50... Loss: 0.841312...
Epoch: 4/50... Loss: 0.855828...
Epoch: 5/50... Loss: 0.878623...
Epoch: 6/50... Loss: 0.836831...
Epoch: 7/50... Loss: 0.856545...
Epoch: 8/50... Loss: 0.861678...
Epoch: 9/50... Loss: 0.855534...
Epoch: 10/50... Loss: 0.853875...
Epoch: 11/50... Loss: 0.842848...
Epoch: 12/50... Loss: 0.841553...
Epoch: 13/50... Loss: 0.868869...
Epoch: 14/50... Loss: 0.874237...
Epoch: 15/50... Loss: 0.860539...
Epoch: 16/50... Loss: 0.850290...
Epoch: 17/50... Loss: 0.872914...
Epoch: 18/50... Loss: 0.861847...
Epoch: 19/50... Loss: 0.854308...
Epoch: 20/50... Loss: 0.859626...
Epoch: 21/50... Loss: 0.864332...
Epoch: 22/50... Loss: 0.848013...
Epoch: 23/50... Loss: 0.863113...
Epoch: 24/50... Loss: 0.852868...
Epoch: 25/50... Loss: 0.859522...
Epoch: 26/50... Loss: 0.849938...
Epoch: 27/50... Loss: 0.854970...
Epoch: 28/50... Loss: 0.855755...
Epoch: 29/50... Loss: 0.860601...
Epoch: 30/50... Loss: 0.873638...
Epoch: 31/50... Loss: 0.866688...
Epoch: 32/50... Loss: 0.843742...
Epoch: 33/50... Loss: 0.849657...
Epoch: 34/50... Loss: 0.856069...
Epoch: 35/50... Loss: 0.859831...
Epoch: 36/50... Loss: 0.849212...
Epoch: 37/50... Loss: 0.852001...
Epoch: 38/50... Loss: 0.860625...
Epoch: 39/50... Loss: 0.878451...
Epoch: 40/50... Loss: 0.862692...
Epoch: 41/50... Loss: 0.856125...
Epoch: 42/50... Loss: 0.865457...
Epoch: 43/50... Loss: 0.866801...
Epoch: 44/50... Loss: 0.854730...
Epoch: 45/50... Loss: 0.854684...
Epoch: 46/50... Loss: 0.862919...
Epoch: 47/50... Loss: 0.858217...
Epoch: 48/50... Loss: 0.854350...
Epoch: 49/50... Loss: 0.864497...
Epoch: 50/50... Loss: 0.863351...

You only have one linear layer after the nn.LSTM layer, so that linear layer is your output layer, and there should be no further activation (ReLU) or Dropout layers after it. Try:

def forward(self, x, hidden):
    embeds = self.embedding(x)
    out, hidden = self.lstm(embeds, hidden)
    # fully connected
    out = self.linear(out[:, -1, :])
    # softmax
    out = self.softmax(out)
    return out, hidden
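As a side note, since nn.CrossEntropyLoss combines LogSoftmax and NLLLoss in a single step, an equivalent setup would be to drop self.softmax, return the raw logits from forward, and swap the loss:

# Equivalent alternative: the model returns raw logits (no LogSoftmax)
# and the criterion applies log-softmax internally.
criterion = nn.CrossEntropyLoss()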

You also might want to call hidden = net.init_hidden() for each batch, not just once per epoch. While there are setups that carry the state across batches, by default the hidden state is initialized for each batch:

for epoch in range(epochs):
    temp_loss = []
    for inputs, labels in train_loader:
        hidden = net.init_hidden()
        ...

With this, you also shouldn’t need the line hidden = tuple([each.data for each in hidden]). That was probably your workaround to detach the hidden state from the previous batch’s computation graph so that backward() wouldn’t complain :).
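Putting both fixes together, the inner loop could look roughly like this (just a sketch, reusing the names from your code):

net.train()
for epoch in range(epochs):
    temp_loss = []
    for inputs, labels in train_loader:
        if train_on_gpu:
            inputs, labels = inputs.cuda(), labels.cuda()
        hidden = net.init_hidden()  # fresh zero state for every batch
        optimizer.zero_grad()
        output, hidden = net(inputs, hidden)
        loss = criterion(output, labels)
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), clip)
        optimizer.step()
        temp_loss.append(loss.item())
    training_losses.append(sum(temp_loss) / len(temp_loss))

One caveat: your init_hidden() assumes a fixed batch_size, so the last (possibly smaller) batch will break it unless you create the DataLoader with drop_last=True or pass the current batch size into init_hidden().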

Hi Chris,

Thanks for your kind help.

I’ve made two changes:

  1. Added one more linear layer, applying the activation only after the first linear layer (roughly as sketched below).
  2. Moved my init_hidden() call inside each batch.
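Concretely, my forward pass now looks roughly like this (linear1/linear2 replace the single linear layer; the intermediate width of 64 and the dropout placement are just my choices):

def forward(self, x, hidden):
    embeds = self.embedding(x)
    out, hidden = self.lstm(embeds, hidden)
    out = self.linear1(out[:, -1, :])  # nn.Linear(hidden_dim, 64)
    out = F.relu(out)                  # activation only after the first linear
    out = self.dropout(out)
    out = self.linear2(out)            # nn.Linear(64, output_size), no activation
    out = self.softmax(out)
    return out, hidden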
However, the loss still doesn’t make much sense. It seems I could say either that it doesn’t converge, or that it converges right after the first epoch, even when I make the learning rate fairly small:

Epoch: 1/50... Loss: 1.055414...
Epoch: 2/50... Loss: 0.912548...
Epoch: 3/50... Loss: 0.910565...
Epoch: 4/50... Loss: 0.908848...
Epoch: 5/50... Loss: 0.908775...
Epoch: 6/50... Loss: 0.910205...
Epoch: 7/50... Loss: 0.909043...
Epoch: 8/50... Loss: 0.908118...
Epoch: 9/50... Loss: 0.908512...
Epoch: 10/50... Loss: 0.909467...
Epoch: 11/50... Loss: 0.908574...
Epoch: 12/50... Loss: 0.910017...
Epoch: 13/50... Loss: 0.908047...
Epoch: 14/50... Loss: 0.909087...
Epoch: 15/50... Loss: 0.908903...
Epoch: 16/50... Loss: 0.909994...
Epoch: 17/50... Loss: 0.908831...
Epoch: 18/50... Loss: 0.907960...
Epoch: 19/50... Loss: 0.908065...
Epoch: 20/50... Loss: 0.908613...
Epoch: 21/50... Loss: 0.908029...
Epoch: 22/50... Loss: 0.908592...
Epoch: 23/50... Loss: 0.909221...
Epoch: 24/50... Loss: 0.909286...
Epoch: 25/50... Loss: 0.908711...
Epoch: 26/50... Loss: 0.909153...
Epoch: 27/50... Loss: 0.908208...
Epoch: 28/50... Loss: 0.908390...
Epoch: 29/50... Loss: 0.909025...
Epoch: 30/50... Loss: 0.908518...
Epoch: 31/50... Loss: 0.908976...
Epoch: 32/50... Loss: 0.907897...
Epoch: 33/50... Loss: 0.908882...
Epoch: 34/50... Loss: 0.908293...
Epoch: 35/50... Loss: 0.910187...
Epoch: 36/50... Loss: 0.908621...
Epoch: 37/50... Loss: 0.908641...
Epoch: 38/50... Loss: 0.908749...
Epoch: 39/50... Loss: 0.907741...
Epoch: 40/50... Loss: 0.909765...
Epoch: 41/50... Loss: 0.907573...
Epoch: 42/50... Loss: 0.908850...
Epoch: 43/50... Loss: 0.910151...
Epoch: 44/50... Loss: 0.908220...
Epoch: 45/50... Loss: 0.908110...
Epoch: 46/50... Loss: 0.907462...
Epoch: 47/50... Loss: 0.908315...
Epoch: 48/50... Loss: 0.907731...
Epoch: 49/50... Loss: 0.908232...
Epoch: 50/50... Loss: 0.908183...

In my previous posts I shared two implementations, which should be quite similar. On the other version, after I made the same changes, all the loss values became NaN.

I wonder if my dataset simply isn’t suitable for training with an LSTM? But in my understanding, even if the training performance isn’t good, the loss should still converge if my implementation is correct.
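One sanity check I still plan to run (a sketch reusing the names from my training loop): a correct implementation should be able to overfit a single batch and drive the loss toward zero, regardless of how hard the full dataset is.

# Try to overfit one batch. If the loss won't drop even here, the bug is
# in the model or the training loop, not in the dataset.
tiny_inputs, tiny_labels = next(iter(train_loader))
if train_on_gpu:
    tiny_inputs, tiny_labels = tiny_inputs.cuda(), tiny_labels.cuda()

for step in range(500):
    optimizer.zero_grad()
    hidden = net.init_hidden()
    output, hidden = net(tiny_inputs, hidden)
    loss = criterion(output, tiny_labels)
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(step, loss.item())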

I’m looking forward to receiving more advice.

Thanks!