Why does my LSTM stop learning if I do not set the hidden state to zero?

If I remove the two lines h0 = torch.zeros... and c0 = torch.zeros..., as well as batch_first=True, my network stops learning.

I thought nn.LSTM uses a zero initial hidden state by default if you don't pass one in.

class ModelLSTMFSM(nn.Module):
    def __init__(self, input_size=MAX_STRING_SIZE, hidden_size=256, num_layers=2, states_size=MAX_STATES_SIZE):
        super(ModelLSTMFSM, self).__init__()
        self.states_size = states_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, self.states_size * self.states_size * 2)
    
    def forward(self, x):
        x = x.reshape(-1, INPUT_SIZE, MAX_STRING_SIZE)
        
        # Set initial hidden and cell states 
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) 
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        
        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size)
        
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out.reshape(-1, self.states_size, self.states_size, 2)
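The assumption in the question can be checked directly: the nn.LSTM documentation states that h_0 and c_0 default to zeros when they are not provided. Below is a minimal sketch that runs the same LSTM once with explicit zero states and once with none, then compares the outputs. The sizes are arbitrary stand-ins, not the thread's real MAX_STRING_SIZE / MAX_STATES_SIZE constants. Note that h0 and c0 keep the shape (num_layers, batch, hidden_size) even when batch_first=True.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary sizes, loosely modeled on the thread (batch 64, 21 rows, 8 features)
num_layers, batch, seq_len, input_size, hidden_size = 2, 64, 21, 8, 16
lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
x = torch.randn(batch, seq_len, input_size)

# Explicit zero states, as in the original forward();
# batch_first does not change the layout of h0/c0
h0 = torch.zeros(num_layers, batch, hidden_size)
c0 = torch.zeros(num_layers, batch, hidden_size)
out_explicit, _ = lstm(x, (h0, c0))

# No states passed: nn.LSTM falls back to zeros
out_default, _ = lstm(x)

print(torch.allclose(out_explicit, out_default))  # True
```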

If you remove batch_first=True, it is of course batch_first=False by default. In that case the time dimension comes first, so you would need to change out = self.fc(out[:, -1, :]) to out = self.fc(out[-1]).
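To see that both indexing styles pick the same last time step, here is a minimal sketch: it copies the weights of a batch_first=True LSTM into a batch_first=False one (the parameter names do not depend on batch_first) and feeds the second one the transposed input. The sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, input_size, hidden_size = 4, 21, 8, 16

lstm_bf = nn.LSTM(input_size, hidden_size, num_layers=2, batch_first=True)
lstm_sf = nn.LSTM(input_size, hidden_size, num_layers=2)  # batch_first=False
# Same parameter names either way, so the weights can be copied over
lstm_sf.load_state_dict(lstm_bf.state_dict())

x = torch.randn(batch, seq_len, input_size)  # (batch, seq, feature)

out_bf, _ = lstm_bf(x)                  # output: (batch, seq, hidden)
out_sf, _ = lstm_sf(x.transpose(0, 1))  # input (seq, batch, feature) -> output (seq, batch, hidden)

last_bf = out_bf[:, -1, :]  # last time step when batch_first=True
last_sf = out_sf[-1]        # last time step when batch_first=False

print(last_bf.shape, last_sf.shape)  # torch.Size([4, 16]) torch.Size([4, 16])
print(torch.allclose(last_bf, last_sf, atol=1e-6))  # True
```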

I don't know what you're trying to learn or what your data looks like, but x = x.reshape(-1, INPUT_SIZE, MAX_STRING_SIZE) looks a bit suspicious. Firstly, why do you need to infer the batch size? (I assume batch_first=True, so the first dimension is the batch.) Secondly, input_size is expected to be the 3rd dimension of the input tensor for an LSTM/GRU.

I'm not sure about your reshape() in the return line either. If you're not careful, reshape() and view() can quickly mess up your data.
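Not every reshape is dangerous, though. A useful rule of thumb: view()/reshape() keeps the data intact when it only splits or merges neighbouring dimensions without reordering them. The return line appears to do exactly that, assuming self.fc outputs (batch, states_size * states_size * 2): it only splits the trailing dimension. A sketch with a toy arange tensor standing in for the fc output:

```python
import torch

batch, states_size = 64, 3
# Stand-in for the output of self.fc: (batch, states_size * states_size * 2)
flat = torch.arange(batch * states_size * states_size * 2).reshape(batch, -1)

# Splitting only the trailing dimension keeps every sample's values together
out = flat.view(batch, states_size, states_size, 2)

# Each sample's block is exactly that sample's fc output, in the same order
print(torch.equal(out[5].flatten(), flat[5]))     # True
print(torch.equal(out.reshape(batch, -1), flat))  # True
```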


Thank you for your suggestions. I refactored the code and now have 2 simple models: the one with batch_first=True trains, and the one with batch_first=False does not (no errors, the loss just does not decrease). What could be wrong?

It seems there is no issue with the zero hidden state.

I use CrossEntropyLoss. The target shape is (states_size, 2).

class ModelLSTMFSM_TRAINS(nn.Module):
    def __init__(self, input_size=MAX_STRING_SIZE, hidden_size=256, num_layers=2, states_size=MAX_STATES_SIZE):
        super(ModelLSTMFSM_TRAINS, self).__init__()
        self.states_size = states_size
        self.input_size = input_size
        
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, self.states_size * self.states_size * 2)
    
    def forward(self, x):
        x = x.view(x.size(0), -1, self.input_size)
        
        out, _ = self.lstm(x)
        
        out = self.fc(out[:, -1, :])
    
        return out.view(out.size(0), self.states_size, self.states_size, 2)
    
    
class ModelLSTMFSM_NOT_TRAINS(nn.Module):
    def __init__(self, input_size=MAX_STRING_SIZE, hidden_size=256, num_layers=2, states_size=MAX_STATES_SIZE):
        super(ModelLSTMFSM_NOT_TRAINS, self).__init__()
        self.states_size = states_size
        self.input_size = input_size
        
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        self.fc = nn.Linear(hidden_size, self.states_size * self.states_size * 2)
    
    def forward(self, x):
        x = x.view(-1, x.size(0), self.input_size)
        
        out, _ = self.lstm(x)

        out = self.fc(out[-1])
        
        return out.view(out.size(0), self.states_size, self.states_size, 2)
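For reference, here is a sketch of how the 4-D model output lines up with CrossEntropyLoss, assuming a batched target of shape (batch, states_size, 2) holding next-state indices (like the [[0, 1], [2, 1], [2, 1]] example later in the thread). CrossEntropyLoss treats dim 1 of its input as the class dimension, so dim 1 of the output plays that role here. Random tensors stand in for real data.

```python
import torch
import torch.nn as nn

batch, states_size = 64, 3
criterion = nn.CrossEntropyLoss()

# Model output as returned by forward(): (batch, states_size, states_size, 2).
# CrossEntropyLoss reads dim 1 as the class dimension (which next state).
logits = torch.randn(batch, states_size, states_size, 2, requires_grad=True)

# Assumed target: (batch, states_size, 2) with integer classes in [0, states_size)
target = torch.randint(0, states_size, (batch, states_size, 2))

loss = criterion(logits, target)  # averaged over all (batch, state, symbol) entries
loss.backward()
print(loss.shape)  # torch.Size([]), i.e. a scalar
```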

Have you initialized the hidden state of your LSTM layer? Also, you probably need to get the hidden state at each iteration of your training and use it in the next forward() call; otherwise you might be starting from a new hidden state at each iteration (and your model will not converge).

I would suggest taking a look at the word language model example in the PyTorch GitHub repository.
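Carrying the hidden state across iterations only makes sense when consecutive batches continue the same sequences (truncated backpropagation through time, as in the word language model example). In that setting the state is usually detached between batches so that backprop does not reach into earlier batches. A minimal sketch; detach_hidden, the random data and the linear head are hypothetical. If every sample is an independent sequence, starting from the default zero state each time is fine.

```python
import torch
import torch.nn as nn

def detach_hidden(hidden):
    """Hypothetical helper: cut the autograd graph at the batch boundary."""
    if isinstance(hidden, torch.Tensor):
        return hidden.detach()
    return tuple(detach_hidden(h) for h in hidden)  # (h, c) tuple for LSTM

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
head = nn.Linear(16, 1)
optimizer = torch.optim.SGD(list(lstm.parameters()) + list(head.parameters()), lr=0.1)

# Hypothetical stream: consecutive chunks continue the same 4 sequences
chunks = [torch.randn(4, 21, 8) for _ in range(3)]
targets = [torch.randn(4, 1) for _ in range(3)]

hidden = None  # None -> nn.LSTM starts from zeros
for x, y in zip(chunks, targets):
    if hidden is not None:
        hidden = detach_hidden(hidden)  # keep the values, drop the old graph
    out, hidden = lstm(x, hidden)
    loss = nn.functional.mse_loss(head(out[:, -1, :]), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```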

In my previous comment, I provided 2 example networks. The only difference is batch_first=False/True in the initialization. One learns but the other does not.

Without knowing what exactly you're trying to learn, it's basically impossible to tell for sure what is off. Also, what does your input x look like, i.e., what kind of data is it and what is its shape? At the moment I can only make the vaguely informed guess that your view() commands cause the problem. Something like x.view(batch_size, ...) should not be needed, and it only does the right thing by chance.

Debugging neural nets can be tricky:

  • Errors are most commonly thrown only when the shapes do not match the expected input. This is the easy case since, well, nothing works :)

  • If there are no errors but the network is not training, i.e., the loss is not going down, then there can be many issues, but at least there's an indicator that something is wrong.

  • However, even if the loss does go down, that doesn't mean the training is done correctly! The network only looks for patterns, which it can find even in messed-up data/batches. And this can easily happen when using view() or reshape() carelessly.

I would almost bet 5 bucks that your network that is learning something is still not necessarily learning correctly. But again, without knowing what you're trying to do and what your input x looks like, it's impossible to tell for sure.

For both networks, can you change your forward methods to:

def forward(self, x):
    print(x.shape)
    x = x.view(...)
    print(x.shape)
    ...

and report the output of the print statements? Also, what do the dimensions of the input x mean?
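An alternative to editing forward() is a forward pre-hook on the LSTM submodule, which logs exactly what the LSTM receives. A sketch using register_forward_pre_hook, with the 21 x 64 x 8 sizes reported below as stand-ins. Keep in mind that shapes alone cannot reveal a view() that scrambles data, since the scrambled tensor can have exactly the right shape.

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2)  # batch_first=False

seen_shapes = []

def log_input_shape(module, args):
    # args is the tuple of positional inputs passed to module.forward()
    shape = tuple(args[0].shape)
    seen_shapes.append(shape)
    print(f"{module.__class__.__name__} input: {shape}")

handle = lstm.register_forward_pre_hook(log_input_shape)
lstm(torch.randn(21, 64, 8))  # prints: LSTM input: (21, 64, 8)
handle.remove()  # detach the hook once done debugging
```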


You are right. There is no need for reshaping in this model (ModelLSTMFSM_TRAINS).

trains:

before torch.Size([64, 21, 8])
after torch.Size([64, 21, 8])

not trains:

before torch.Size([64, 21, 8])
after torch.Size([21, 64, 8])

batch_size=64

x is a matrix with 21 rows, and each row has size 8. Each value is 0 or 1, and rows are padded with 2: padded binary strings.

tensor([[0., 2., 2., 2., 2., 2., 2., 2.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 2., 2., 2.],
        [0., 0., 1., 1., 0., 1., 0., 0.], ....

The output is a 3-by-2 matrix. In my case it is a finite-state machine:
[[0, 1], [2, 1], [2, 1]]

I have found the issue: Do not use view() or reshape() to swap dimensions of tensors!

x = x.view(-1, x.size(0), self.input_size) # very bad idea, do not use it
x = x.permute(1, 0, 2) # fixes everything
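Why view() went wrong while permute() works: both produce shape (21, 64, 8), but view() only re-splits the same memory in the same order, whereas permute() actually swaps the batch and time axes. A toy tensor with unique values (sizes taken from the printout above) makes the difference visible:

```python
import torch

batch, seq_len, input_size = 64, 21, 8
# Toy tensor where every element is unique, so misplaced data is easy to spot
x = torch.arange(batch * seq_len * input_size).view(batch, seq_len, input_size)

wrong = x.view(-1, x.size(0), input_size)  # (21, 64, 8): same memory, re-split
right = x.permute(1, 0, 2)                 # (21, 64, 8): axes actually swapped

print(wrong.shape == right.shape)  # True: the shapes match...
print(torch.equal(wrong, right))   # False: ...but the contents do not

# right[t, b] is row t of sample b, as nn.LSTM(batch_first=False) expects
print(torch.equal(right[0, 1], x[1, 0]))  # True
# wrong[0, 1] is actually row 1 of sample 0
print(torch.equal(wrong[0, 1], x[0, 1]))  # True
```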


Glad it worked out in the end. You'd be surprised how many people use view() or reshape() to squash their tensors into the right shape without double-checking whether the data keeps its integrity.

Your post about this issue is super helpful. Many thanks.

What is the right way to check if the data preserves integrity after view() or reshape()?

I don't really work with PyTorch that often, and it's usually just basic stuff. The first question is always: do I really need view() or reshape()? And as someone doing NLP with RNNs, the answer is basically always no. The only exception is when I need to resolve the output of, say, a GRU, where the hidden state h_n has a shape of (num_layers * num_directions, batch, hidden_size) and I need to "dissolve" num_layers and num_directions with

h_n = h_n.view(num_layers, num_directions, batch, hidden_size) 

But this is well documented, and I trust that the PyTorch people are smart.
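For this particular view(), the layout can even be verified against the GRU's own output: for an unpacked sequence, the last layer's forward-direction final state equals output at the last time step, and its backward-direction final state equals output at time step 0. A sketch with arbitrary sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_layers, num_directions = 2, 2
seq_len, batch, input_size, hidden_size = 5, 3, 8, 4

gru = nn.GRU(input_size, hidden_size, num_layers, bidirectional=True)
x = torch.randn(seq_len, batch, input_size)
# output: (seq, batch, 2 * hidden) from the last layer; h_n: (num_layers * 2, batch, hidden)
output, h_n = gru(x)

h_n = h_n.view(num_layers, num_directions, batch, hidden_size)

# Last layer, forward direction: finishes at the last time step
print(torch.allclose(h_n[-1, 0], output[-1, :, :hidden_size], atol=1e-6))  # True
# Last layer, backward direction: finishes at time step 0
print(torch.allclose(h_n[-1, 1], output[0, :, hidden_size:], atol=1e-6))  # True
```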

I once had to use it to flatten and unflatten an LSTM/GRU hidden state for a (Variational) Autoencoder. There the important part was to ensure that the tensor before flattening and after unflattening was indeed the same, which I double-checked with a toy tensor to simplify testing. Again, the GitHub code I found for help did it incorrectly, i.e., the tensor looked different after unflattening.

That's the problem: once the shapes are correct, the network throws no errors. And even if the tensor is messed up, the network still might learn something.
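The toy-tensor check described above can be sketched as follows, assuming a hidden state of shape (num_layers, batch, hidden_size) that is flattened to one vector per sample; the sizes are arbitrary. It also shows why a round trip alone may not be enough: a naive view() restores the original tensor perfectly while mixing samples in between, so it is worth also checking that each flattened row belongs to a single sample.

```python
import torch

num_layers, batch, hidden_size = 2, 4, 3
# Toy hidden state (num_layers, batch, hidden) with unique values
h = torch.arange(num_layers * batch * hidden_size).view(num_layers, batch, hidden_size)

# Flatten per sample: move batch to the front, then merge layers and features
flat = h.permute(1, 0, 2).reshape(batch, num_layers * hidden_size)
# Unflatten: split layers and features again, then move batch back to dim 1
restored = flat.view(batch, num_layers, hidden_size).permute(1, 0, 2)

print(torch.equal(restored, h))  # True: the round trip is lossless
# Each row holds exactly one sample's states from every layer
print(torch.equal(flat[1], torch.cat([h[0, 1], h[1, 1]])))  # True

# Tempting shortcut: right shape, and the round trip also "passes"...
naive = h.view(batch, -1)
print(torch.equal(naive.view(num_layers, batch, hidden_size), h))  # True
# ...but row 1 is built from samples 2 and 3 of layer 0, not from sample 1
print(torch.equal(naive[1], torch.cat([h[0, 2], h[0, 3]])))  # True
```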
