I have a simple LSTM, and I want the gradient of the hidden state at a particular time step with respect to all other time steps. A minimal example:
import torch

x = torch.randn(5, 4, 2)  # (batch=5, seq_len=4, input_size=2)
lstm = torch.nn.LSTM(2, 2, batch_first=True, bidirectional=True)
h, (_, _) = lstm(x)  # h: (5, 4, 2 * hidden_size) = (5, 4, 4)
h.retain_grad()  # h is a non-leaf tensor, so its grad must be retained explicitly
h[:, 1].sum().backward()
print(h.grad)
Output:
tensor([[[0., 0., 0., 0.],
[1., 1., 1., 1.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[1., 1., 1., 1.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[1., 1., 1., 1.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[1., 1., 1., 1.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[1., 1., 1., 1.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]])
This only gives me the gradient of the hidden state at the selected time step (1 in this case) with respect to that same time step, which is of course 1. But how do I get the gradient with respect to the other time steps?
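For reference, here is a hedged sketch of one related thing I can already compute (not necessarily what I ultimately want): the gradient with respect to the *input* `x` at every time step, by making `x` a leaf tensor with `requires_grad=True`. Since the LSTM here is bidirectional, the output at step 1 depends on inputs at all steps, so `x.grad` comes out nonzero everywhere. The `manual_seed` call is my own addition for reproducibility.

```python
import torch

torch.manual_seed(0)  # assumption: seeded only for reproducibility
x = torch.randn(5, 4, 2, requires_grad=True)  # leaf tensor -> x.grad is populated
lstm = torch.nn.LSTM(2, 2, batch_first=True, bidirectional=True)
h, _ = lstm(x)

# Same scalar as before: sum of the hidden state at time step 1.
h[:, 1].sum().backward()

# x.grad has shape (5, 4, 2); with a bidirectional LSTM, every time step
# contributes to h[:, 1] (steps 0-1 via the forward pass, steps 1-3 via
# the backward pass), so no time step's gradient is identically zero.
print(x.grad.shape)
print(x.grad)
```

This shows the dependence of step 1's output on the inputs at other steps, but it is not the hidden-state-to-hidden-state gradient I am asking about, since `nn.LSTM` does not expose the intermediate hidden states as separate tensors.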