I think there is a typo in the LSTMCell example on the API doc page: https://pytorch.org/docs/stable/generated/torch.nn.LSTMCell.html
The current snippet is:
>>> rnn = nn.LSTMCell(10, 20)
>>> input = torch.randn(3, 10)
>>> hx = torch.randn(3, 20)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
...     hx, cx = rnn(input[i], (hx, cx))
...     output.append(hx)
Since the loop indexes input[i] for 6 steps, I think the input should instead have a leading time dimension:
input = torch.randn(6, 3, 10)
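
For reference, here is a minimal sketch of the snippet with that change applied, assuming the intended layout is (time_steps, batch, input_size). The final `torch.stack` call and the use of `input.size(0)` instead of a hard-coded 6 are my additions, not part of the doc example. With the (3, 10) input, `input[i]` would be a 1-D tensor of size 10 rather than a (3, 10) batch, and the index would run out of range at i = 3.

```python
import torch
import torch.nn as nn

rnn = nn.LSTMCell(10, 20)        # input_size=10, hidden_size=20
input = torch.randn(6, 3, 10)    # assumed layout: (time_steps, batch, input_size)
hx = torch.randn(3, 20)          # (batch, hidden_size)
cx = torch.randn(3, 20)          # (batch, hidden_size)

output = []
for i in range(input.size(0)):   # iterate over the 6 time steps
    # input[i] has shape (3, 10): one batch per time step
    hx, cx = rnn(input[i], (hx, cx))
    output.append(hx)

# Not in the doc example: stack the per-step hidden states into one tensor
output = torch.stack(output, dim=0)
print(output.shape)
```

If the layout assumption holds, each step feeds a (3, 10) batch to the cell and the stacked result has one (3, 20) hidden state per time step.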
Could someone please confirm this?