lstm.weight_hh_l0 doesn't update after a gradient descent step (but the other parameters do!)

I was trying to write some tests for my model using the idea from this blog post of checking whether parameters change after performing a gradient descent step. However, when I tried this, I found that some parameters wouldn't change.

I have reproduced a simple example below in which lstm.weight_ih_l0, lstm.bias_ih_l0, lstm.bias_hh_l0, fc.weight, and fc.bias all change after a gradient descent step, while lstm.weight_hh_l0 does not.

Shouldn’t all of the trainable parameters be updated after every gradient descent step? Is there something wrong with the example code?

#!/usr/bin/env python3
# torch==0.4.0
# numpy==1.15.0

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np


class MyLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(MyLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Set initial states
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)

        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))

        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])

        return out


model = MyLSTM(input_size=5, hidden_size=8, num_layers=1, num_classes=1)
optimizer = optim.Adagrad(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

params = list(model.state_dict().keys())
assert len(params) == len(list(model.parameters()))
# ['lstm.weight_ih_l0', 'lstm.weight_hh_l0', 'lstm.bias_ih_l0',
# 'lstm.bias_hh_l0', 'fc.weight', 'fc.bias']

# Create some dummy data
inputs = torch.FloatTensor([
    [0.2, 0.3, 0.4, 0.5, 0.6],
    [0.2, 0.3, 0.4, 0.5, 0.6],
    [0.2, 0.3, 0.4, 0.5, 0.6],
]).unsqueeze(1)  # shape: (batch=3, seq_len=1, input_size=5)
target = torch.FloatTensor([1, 0, 1]).unsqueeze(1)  # shape: (batch=3, 1)

# Save parameter values before gradient descent
before = [model.state_dict()[param].clone() for param in params]

output = model(inputs)
loss = criterion(output, target)
optimizer.zero_grad()  # a no-op for a fresh model, but good practice before backward()
loss.backward()
optimizer.step()

# Save parameter values after gradient descent
after = [model.state_dict()[param].clone() for param in params]

for i, (b, a) in enumerate(zip(before, after)):
    # lstm.weight_hh_l0 doesn't change, but all the other ones do
    # (including lstm.bias_hh_l0!)
    if np.allclose(b, a):
        print('Unchanged:', params[i])

If you check the gradients of the parameters:

for p in model.parameters():
    print(p.grad)

You can see that the gradients for one of the parameters are all 0, which means that parameter will not be updated by the optimizer step.

You need to check why the gradients are zero for that parameter.
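
In the snippet above there is a likely culprit: the input has sequence length 1 (inputs has shape (3, 1, 5) after unsqueeze(1)) and the initial hidden state h0 is all zeros. weight_hh_l0 only ever multiplies the previous hidden state, which here is the zero h0, so its gradient is exactly zero, while bias_hh_l0 is added unconditionally and still receives a gradient. A minimal sketch to verify this, reusing the MyLSTM model defined above (the seq_len values are just illustrative):

# With seq_len == 1 the only hidden state fed through weight_hh_l0 is the
# all-zero h0, so its gradient is exactly zero; with seq_len > 1 the
# recurrent weights see nonzero hidden states and get nonzero gradients.
for seq_len in (1, 5):
    m = MyLSTM(input_size=5, hidden_size=8, num_layers=1, num_classes=1)
    x = torch.randn(3, seq_len, 5)  # (batch, seq_len, input_size)
    y = torch.randn(3, 1)
    nn.MSELoss()(m(x), y).backward()
    g = dict(m.named_parameters())['lstm.weight_hh_l0'].grad
    print('seq_len=%d, grad abs sum=%e' % (seq_len, g.abs().sum().item()))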

I have a similar LSTM model for IMDB sentiment classification and noticed that one of the LSTM parameters doesn't change at all.

# trainable parameters
['embedding.weight', 'lstm.weight_ih_l0', 'lstm.weight_hh_l0', 'lstm.bias_ih_l0', 'lstm.bias_hh_l0', 'linear.weight', 'linear.bias']
# unchanged parameter
lstm.weight_ih_l0 

Edit: I've checked the gradients. They are very small for lstm.weight_ih_l0 but not zero, in the range of e^-42.
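
For reference, here is a small sketch for inspecting gradient magnitudes by name right after loss.backward() (model stands for whichever module is being trained):

# Print the largest absolute gradient entry per parameter, so an
# exactly-zero gradient and a tiny-but-nonzero one are easy to tell apart.
for name, p in model.named_parameters():
    if p.grad is not None:  # parameters not touched by the loss have no grad
        print(name, p.grad.abs().max().item())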


Any update on this?

The gradients do change; they are just very small.
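
That also matters for the change-detection test in the original snippet: np.allclose uses rtol=1e-05 and atol=1e-08 by default, so a parameter that moved by something on the order of e^-42 is still reported as "Unchanged". A stricter check is possible with torch.equal; a sketch reusing the params/before/after lists from the snippet above:

# torch.equal compares element-wise with zero tolerance, so it separates
# parameters that truly did not move from ones that moved by a tiny amount.
for name, b, a in zip(params, before, after):
    if torch.equal(b, a):
        print('Truly unchanged:', name)
    elif np.allclose(b, a):
        print('Changed, but within np.allclose tolerance:', name)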

I have the same issue. The gradients for weight_hh_l0 in my one-layer LSTM network are all zeros. I have checked whether they are really zeros, and yes, they are exactly 0. They never become nonzero during the backward pass. What could be the reason for such behaviour?
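
If your inputs have sequence length 1 (with the default zero initial hidden state), that alone would explain it, as in the sketch earlier in the thread: weight_hh_l0 only ever multiplies the all-zero h0, so its gradient is exactly zero. A quick check, where x is a stand-in name for your input batch:

print(x.shape)  # with batch_first=True, a shape of (batch, 1, input_size)
                # means weight_hh_l0 never sees a nonzero hidden state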