Gradient of hidden state in LSTM

I have a simple LSTM where I want to get the gradient of the hidden state at a particular time step with respect to the hidden states at all other time steps. A simple example is as follows:

import torch

x = torch.randn(5, 4, 2)
lstm = torch.nn.LSTM(2, 2, batch_first=True, bidirectional=True)

h, _ = lstm(x)              # h: (batch=5, seq=4, num_directions * hidden=4)
h.retain_grad()             # h is a non-leaf tensor, so keep its gradient

h[:, 1].sum().backward()
print(h.grad)

Output:

tensor([[[0., 0., 0., 0.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]])

This only returns the gradient of the hidden state at the chosen time step (1 in this case) with respect to itself, which is of course 1. But how do I get the gradient with respect to the other time steps?

For this you need to spell out the time loop yourself using LSTMCell. With nn.LSTM the whole output sequence is a single tensor, so its .grad is just the direct gradient of your loss (the indicator you see above); the dependencies between time steps live inside the op and are not exposed as separate graph nodes. If you run the loop manually, the hidden state at each time step is its own tensor in the autograd graph, and you can call hx.retain_grad() on each of them and read the gradients off after backward().
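
A minimal sketch of that unrolling, assuming a unidirectional, single-layer cell with the same sizes as above (the bidirectional case would need a second loop running over the sequence in reverse):

import torch

torch.manual_seed(0)
x = torch.randn(5, 4, 2)               # (batch, seq, input_size)
cell = torch.nn.LSTMCell(2, 2)

hx = torch.zeros(5, 2)                 # initial hidden state
cx = torch.zeros(5, 2)                 # initial cell state
hs = []                                # hidden state at every time step
for t in range(x.size(1)):
    hx, cx = cell(x[:, t], (hx, cx))
    hx.retain_grad()                   # keep the gradient on this intermediate tensor
    hs.append(hx)

# gradient of the step-1 hidden state w.r.t. the hidden state at every step
hs[1].sum().backward()
for t, h_t in enumerate(hs):
    print(t, h_t.grad)

Note that steps after the chosen one come out as None, since the step-1 state does not depend on them, while step 0 receives the gradient that flows back through the recurrence.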

Best regards

Thomas