# Gradient of hidden state in LSTM

I have a simple LSTM where I want to get the gradient of the hidden state at a particular time step with respect to all other time steps. A simple example follows:

``````
import torch

x = torch.randn(5, 4, 2)
lstm = torch.nn.LSTM(2, 2, batch_first=True, bidirectional=True)

h, (_, _) = lstm(x)
h.retain_grad()

h[:, 1].sum().backward()
print(h.grad)
``````

Output:

``````
tensor([[[0., 0., 0., 0.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]])
``````

This only returns the gradient of the hidden state at the current time step (1 in this case) with respect to that time step only, which is 1 of course. But how do I get the gradient with respect to the other time steps?

For this you need to spell out the time loop using `LSTMCell`. That way you handle the hidden state at each time step yourself, so you can call `hx.retain_grad()` on it (and similar) before moving on to the next step.
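A minimal sketch of what this could look like, using a single unidirectional `LSTMCell` for simplicity (the bidirectional case would need a second cell running over the reversed sequence):

```python
import torch

x = torch.randn(5, 4, 2)  # (batch, seq_len, input_size)
cell = torch.nn.LSTMCell(2, 2)

hx = torch.zeros(5, 2)
cx = torch.zeros(5, 2)
hs = []  # hidden state at every time step
for t in range(x.size(1)):
    hx, cx = cell(x[:, t], (hx, cx))
    hx.retain_grad()  # keep the gradient of this intermediate tensor
    hs.append(hx)

# Backprop from the hidden state at time step 1 ...
hs[1].sum().backward()

# ... and read off its gradient w.r.t. each step's hidden state.
for t, h in enumerate(hs):
    print(t, h.grad)
```

With this, `hs[0].grad` is populated (step 1 depends on step 0 through the recurrence), while steps after 1 show `None` since they don't contribute to the quantity being differentiated.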

Best regards

Thomas